# test_tws_client.py (12.1 kB)
import pytest
import json
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from ib_async import IB, Position, Contract, util
from src.tws_client import TWSClient
from src.models import ContractRequest
# Fixture to load JSON data
@pytest.fixture
def load_fixture():
    """Provide a loader that reads a JSON fixture file by name.

    Returns:
        A callable ``_loader(filename)`` that opens
        ``tests/fixtures/<filename>`` (relative to the current working
        directory) and returns the parsed JSON content.
    """
    def _loader(filename):
        # Pin the encoding: without it open() falls back to the platform's
        # locale encoding, so the same fixture could decode differently
        # (or fail) depending on where the tests run.
        with open(f"tests/fixtures/{filename}", "r", encoding="utf-8") as f:
            return json.load(f)
    return _loader
# Test double for the IB client class imported from ib_async
class MockIB:
    """Stand-in for the IB client, assembled from unittest.mock objects.

    Each instance gets its own fresh set of mocks, so configuring one
    instance (return values, side effects) never leaks into another.
    """

    def __init__(self):
        # Connection lifecycle: a new instance reports "not connected".
        self.isConnected = MagicMock(return_value=False)
        self.connectAsync = AsyncMock()
        self.disconnect = MagicMock()

        # Synchronous calls that return an empty list by default.
        for sync_list_call in ("positions", "accountValues"):
            setattr(self, sync_list_call, MagicMock(return_value=[]))

        # Awaitable calls that resolve to an empty list by default.
        for async_list_call in (
            "positionsAsync",
            "reqContractDetailsAsync",
            "qualifyContractsAsync",
            "reqHistoricalDataAsync",
            "accountSummaryAsync",
            "reqAllOpenOrdersAsync",
            "reqExecutionsAsync",
        ):
            setattr(self, async_list_call, AsyncMock(return_value=[]))

        # Synchronous calls with no preset return value.
        # reqMarketDataType sets the market data type
        # (1=live, 2=frozen, 3=delayed).
        for sync_call in (
            "placeOrder",
            "reqMktData",
            "reqMarketDataType",
            "cancelMktData",
            "reqAccountUpdates",
        ):
            setattr(self, sync_call, MagicMock())

        # Awaitable calls with no preset return value.
        for async_call in ("placeOrderAsync", "cancelOrderAsync", "waitOnUpdate"):
            setattr(self, async_call, AsyncMock())

        # errorEvent has to support the += and -= operators: both in-place
        # magic methods hand back the event mock itself.
        error_event = MagicMock()
        error_event.__iadd__ = MagicMock(return_value=error_event)
        error_event.__isub__ = MagicMock(return_value=error_event)
        self.errorEvent = error_event
# Swap the IB name inside src.tws_client for MockIB around every test
@pytest.fixture(autouse=True)
def mock_ib_patch():
    """Replace ``src.tws_client.IB`` with ``MockIB`` while a test runs.

    Declared ``autouse`` so every test in this module gets the patch; the
    original name is restored when the ``with`` block exits.
    """
    ib_patcher = patch("src.tws_client.IB", new=MockIB)
    with ib_patcher:
        yield
@pytest.mark.asyncio
async def test_connect_success():
    """connect() reports True and calls connectAsync when not yet connected."""
    host, port, client_id = "127.0.0.1", 7496, 2

    tws = TWSClient()
    # The constructor no longer builds an IB instance, so attach the mock here.
    tws.ib = MockIB()
    tws.ib.connectAsync.return_value = None
    tws.ib.isConnected.return_value = False

    connected = await tws.connect(host, port, client_id)

    assert connected is True
    tws.ib.connectAsync.assert_called_once_with(
        host, port, clientId=client_id, timeout=5
    )
@pytest.mark.asyncio
async def test_connect_already_connected():
    """connect() returns True without calling connectAsync if already connected."""
    tws = TWSClient()
    # Attach a mock IB that claims an existing connection.
    tws.ib = MockIB()
    tws.ib.isConnected.return_value = True

    outcome = await tws.connect("127.0.0.1", 7496, 1)

    assert outcome is True
    tws.ib.connectAsync.assert_not_called()
@pytest.mark.asyncio
async def test_get_positions(load_fixture):
    """get_positions() returns the positions served by the mocked IB client.

    Position objects are rebuilt from ``sample_positions.json`` and fed to
    the synchronous ``positions`` mock; the test then checks the count and
    selected fields of the returned entries.
    """
    tws = TWSClient()
    tws.ib = MockIB()
    tws.ib.isConnected.return_value = True

    # Turn each JSON record into an ib_async Position (account, contract,
    # position, avgCost) and serve the list through the sync positions() mock.
    tws.ib.positions.return_value = [
        Position(
            record['account'],
            Contract(**record['contract']),
            record['position'],
            record['avgCost'],
        )
        for record in load_fixture('sample_positions.json')
    ]

    result = await tws.get_positions()

    assert len(result) == 2
    assert result[0]['contract']['symbol'] == 'VTI'
    assert result[0]['position'] == 100.0
    assert result[1]['contract']['symbol'] == 'TLT'
@pytest.mark.asyncio
async def test_get_historical_data():
    """get_historical_data() returns one entry per bar from the mocked IB client."""
    tws = TWSClient()
    tws.ib = MockIB()
    tws.ib.isConnected.return_value = True

    # Contract qualification resolves to a single known contract.
    tws.ib.qualifyContractsAsync.return_value = [Contract(conId=123)]

    # Simplified bars: (date, open, high, low, close, volume).
    bar_rows = (
        ("20240101", 100.0, 101.0, 99.0, 100.5, 1000),
        ("20240102", 100.5, 102.0, 100.0, 101.5, 1500),
    )
    tws.ib.reqHistoricalDataAsync.return_value = [
        MagicMock(date=day, open=opn, high=hi, low=lo, close=cls, volume=vol)
        for day, opn, hi, lo, cls, vol in bar_rows
    ]

    request = ContractRequest(symbol="AAPL")
    bars = await tws.get_historical_data(request, "1 D", "1 min", "TRADES")

    assert len(bars) == 2
    assert bars[0]['open'] == 100.0
    assert bars[1]['close'] == 101.5
    tws.ib.reqHistoricalDataAsync.assert_called_once()
@pytest.mark.asyncio
async def test_place_market_order():
    """place_order() reports the id, status and action of the mocked trade."""
    tws = TWSClient()
    tws.ib = MockIB()
    tws.ib.isConnected.return_value = True

    # Contract qualification resolves to a single known contract.
    aapl = Contract(conId=123, symbol='AAPL')
    tws.ib.qualifyContractsAsync.return_value = [aapl]

    # Simplified trade handed back by the synchronous placeOrder() mock.
    fake_trade = MagicMock()
    fake_trade.contract = aapl
    fake_trade.orderStatus.status = 'Submitted'
    fake_order = fake_trade.order
    fake_order.orderId = 1
    fake_order.action = 'BUY'
    fake_order.totalQuantity = 10
    tws.ib.placeOrder.return_value = fake_trade

    order_request = MagicMock(
        contract=ContractRequest(symbol="AAPL"),
        action="BUY",
        totalQuantity=10,
        orderType="MKT",
        lmtPrice=None,
    )

    placed = await tws.place_order(order_request)

    assert placed['orderId'] == 1
    assert placed['status'] == 'Submitted'
    assert placed['action'] == 'BUY'
    # The order must have gone through the sync placeOrder exactly once.
    tws.ib.placeOrder.assert_called_once()
@pytest.mark.asyncio
async def test_stream_market_data_generator():
    """stream_market_data() yields ticker updates and cancels on close.

    The mocked ticker's ``updateEvent.timeout()`` signals two updates; the
    test collects two items, closes the generator and checks that the market
    data subscription was cancelled exactly once.
    """
    tws = TWSClient()
    tws.ib = MockIB()
    tws.ib.isConnected.return_value = True

    # Contract qualification resolves to a single known contract.
    tws.ib.qualifyContractsAsync.return_value = [Contract(conId=123, symbol='SPY')]

    # Number of timeout() iterables that have started iterating so far.
    started = 0

    async def fake_updates(timeout_val):
        """Yield once for the first two iterables, then block on a pending Future."""
        nonlocal started
        started += 1
        if started > 2:
            # No more updates: wait on a Future that never completes, so this
            # only ends if the awaiting code is cancelled or closed.
            await asyncio.Future()
            return
        yield None  # Signals that an update is ready

    def fake_timeout(t):
        # Stands in for the event's timeout(), which returns an async iterable.
        return fake_updates(t)

    update_event = MagicMock()
    update_event.timeout.side_effect = fake_timeout

    tick_time = MagicMock()
    tick_time.isoformat.return_value = "2024-01-01T10:00:00"
    tws.ib.reqMktData.return_value = MagicMock(
        time=tick_time,
        last=100.0,
        bid=99.9,
        ask=100.1,
        volume=1000,
        bidSize=10,
        askSize=10,
        close=99.5,
        updateEvent=update_event,
    )

    collected = []
    stream = tws.stream_market_data(ContractRequest(symbol="SPY"))
    try:
        async for item in stream:
            collected.append(item)
            if len(collected) < 2:
                continue
            # Two updates are enough: shut the generator down and stop.
            await stream.aclose()
            break
    except asyncio.CancelledError:
        pass

    # Exactly two updates were taken before the stream was closed.
    assert len(collected) == 2
    assert collected[0]['last'] == 100.0
    tws.ib.cancelMktData.assert_called_once()
@pytest.mark.asyncio
async def test_stream_market_data_with_warning_codes():
    """Test that warning codes (like 10167 for delayed data) don't stop streaming.

    A warning is fired into the error handler at the moment it is added to
    ``errorEvent`` with ``+=``. The stream is still expected to yield two
    updates and to cancel the market data subscription exactly once.
    """
    client = TWSClient()
    # Initialize mock IB instance
    client.ib = MockIB()
    client.ib.isConnected.return_value = True

    # Mock return values
    mock_contract = Contract(conId=456, symbol='AAPL')
    client.ib.qualifyContractsAsync.return_value = [mock_contract]

    # Counter to control how many updates we get
    update_count = {'count': 0}

    # Track error handler calls
    error_handler_calls = []
    original_iadd = client.ib.errorEvent.__iadd__

    def track_error_handler(handler):
        """Record the handler being added and immediately feed it a warning."""
        error_handler_calls.append(handler)
        # Simulate a warning (10167 = delayed data warning)
        # This should NOT stop the stream
        handler(4, 10167, "Requested market data is not subscribed. Displaying delayed market data.", mock_contract)
        return original_iadd(handler)

    # Wrap the tracker in a MagicMock instead of assigning the function
    # directly: unittest.mock calls a plain function installed as a magic
    # method with the mock itself as a leading ``self`` argument, which would
    # make ``errorEvent += handler`` raise TypeError for this one-argument
    # callable. A mock-valued magic method receives only the operand.
    client.ib.errorEvent.__iadd__ = MagicMock(side_effect=track_error_handler)

    # Mock ticker object with updateEvent
    async def mock_timeout_gen(timeout_val):
        """Yield once for the first two iterables, then block on a pending Future."""
        update_count['count'] += 1
        if update_count['count'] <= 2:
            yield None  # Yield to indicate update is ready
        else:
            await asyncio.Future()  # Never completes; ends only via cancel/close

    mock_update_event = MagicMock()
    mock_update_event.timeout = MagicMock(side_effect=lambda t: mock_timeout_gen(t))
    mock_ticker = MagicMock(
        time=MagicMock(isoformat=MagicMock(return_value="2024-01-01T10:00:00")),
        last=150.0, bid=149.9, ask=150.1, volume=5000,
        bidSize=20, askSize=20, close=149.5,
        updateEvent=mock_update_event
    )
    client.ib.reqMktData.return_value = mock_ticker

    req = ContractRequest(symbol="AAPL")
    updates = []
    gen = client.stream_market_data(req)
    try:
        async for data in gen:
            updates.append(data)
            if len(updates) >= 2:
                await gen.aclose()
                break
    except asyncio.CancelledError:
        pass

    # Despite the warning (error code 10167), streaming should continue
    assert len(updates) == 2
    assert updates[0]['last'] == 150.0
    # Error handler should have been registered
    assert len(error_handler_calls) == 1
    client.ib.cancelMktData.assert_called_once()
@pytest.mark.asyncio
async def test_stream_market_data_with_real_error():
    """Test that real errors (not warnings) properly stop streaming.

    Error 10089 is fired into the error handler at the moment it is added to
    ``errorEvent`` with ``+=``. Iterating the stream is then expected to raise
    a RuntimeError whose message carries the code and the error text.
    """
    client = TWSClient()
    # Initialize mock IB instance
    client.ib = MockIB()
    client.ib.isConnected.return_value = True

    # Mock return values
    mock_contract = Contract(conId=789, symbol='MSFT')
    client.ib.qualifyContractsAsync.return_value = [mock_contract]

    # Track error handler calls
    error_handler_calls = []
    original_iadd = client.ib.errorEvent.__iadd__

    def track_error_handler(handler):
        """Record the handler being added and immediately feed it a real error."""
        error_handler_calls.append(handler)
        # Simulate a real error (e.g., 10089 = subscription required)
        # This SHOULD stop the stream during the initial check
        handler(5, 10089, "Requested market data requires additional subscription.", mock_contract)
        return original_iadd(handler)

    # Wrap the tracker in a MagicMock instead of assigning the function
    # directly: unittest.mock calls a plain function installed as a magic
    # method with the mock itself as a leading ``self`` argument, which would
    # make ``errorEvent += handler`` raise TypeError (not the RuntimeError
    # under test) for this one-argument callable.
    client.ib.errorEvent.__iadd__ = MagicMock(side_effect=track_error_handler)

    # Mock ticker
    mock_ticker = MagicMock()
    client.ib.reqMktData.return_value = mock_ticker

    req = ContractRequest(symbol="MSFT")
    gen = client.stream_market_data(req)

    # Should raise RuntimeError due to the error code 10089
    with pytest.raises(RuntimeError) as exc_info:
        async for _ in gen:
            pass

    assert "10089" in str(exc_info.value)
    assert "additional subscription" in str(exc_info.value)