# tws_client.py
from ib_async import IB, Stock, Option, Future, Contract, MarketOrder, LimitOrder, util
from typing import List, Dict, Any, Optional, AsyncGenerator
import asyncio
from src.models import ContractRequest, OrderRequest
def _to_dict(obj: Any) -> Dict[str, Any]:
    """Convert an ib_async result object into a plain ``dict``.

    Conversion is attempted in this order:

    1. Named tuples (ib_async returns several result types, e.g. account
       values and fills, as ``NamedTuple``s) via ``_asdict()``. These have
       no ``__dict__`` and are not dataclasses, so they must be handled
       before the fallbacks below.
    2. Dataclasses via ``util.dataclassAsDict`` (shallow, non-recursive).
    3. Objects exposing ``__dict__``. ``util.dataclassAsDict`` raises
       ``TypeError`` for the MagicMocks used in unit tests, and those land
       here.
    4. A handful of common bar/ticker fields. ``isoformat()`` is called on
       any value that has it.
    5. ``{"repr": repr(obj)}`` as a last resort, so callers never break.
    """
    if isinstance(obj, tuple) and hasattr(obj, '_asdict'):
        return dict(obj._asdict())
    try:
        return util.dataclassAsDict(obj)
    except TypeError:
        pass
    if hasattr(obj, '__dict__'):
        return dict(obj.__dict__)
    result = {}
    for attr in ('date', 'time', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'conId', 'last'):
        if not hasattr(obj, attr):
            continue
        value = getattr(obj, attr)
        # Dates/times become ISO strings. A MagicMock also "has" isoformat;
        # any failure while calling it keeps the raw value.
        try:
            if hasattr(value, 'isoformat'):
                value = value.isoformat()
        except Exception:
            pass
        result[attr] = value
    if result:
        return result
    return {"repr": repr(obj)}
class TWSClient:
    """Async client for TWS / IB Gateway built on an ib_async ``IB`` instance.

    Every request method first checks the connection and raises
    ``RuntimeError("Not connected to TWS")`` when there is none. Results are
    returned as plain dicts/lists, converted with ``_to_dict``.
    """

    def __init__(self):
        # Do NOT create the IB instance here: it would capture whatever event
        # loop is current at construction time. connect() creates it inside
        # the running loop; unit tests assign ``self.ib`` in their fixtures.
        self.ib: Optional[IB] = None
        self._connected = False
        # Active market-data streams: contract conId -> ticker returned by reqMktData.
        self._market_data_subscriptions: Dict[int, Any] = {}

    def is_connected(self) -> bool:
        """Return True when an IB instance exists and reports a live connection."""
        return bool(self.ib and self.ib.isConnected())

    def _require_connection(self) -> None:
        """Raise ``RuntimeError`` unless the client is connected."""
        if not self.is_connected():
            raise RuntimeError("Not connected to TWS")

    def _disconnect_quietly(self, only_if_connected: bool = False) -> None:
        """Best-effort disconnect while unwinding a failed connect().

        Errors are ignored on purpose, so they never mask the original failure.
        """
        try:
            if self.ib and (not only_if_connected or self.ib.isConnected()):
                self.ib.disconnect()
        except Exception:
            pass

    async def connect(self, host: str, port: int, client_id: int) -> bool:
        """Connect to TWS/IB Gateway.

        A fresh ``IB`` instance is created on every (re)connect, so it is bound
        to the event loop that is running now.

        Args:
            host: TWS/Gateway host name or IP.
            port: TWS/Gateway API port.
            client_id: API client id passed to ``connectAsync``.

        Returns:
            True once connected (immediately if already connected).

        Raises:
            ConnectionError: on timeout, cancellation or any other failure.
        """
        if self.is_connected():
            return True

        # Tear down any stale instance so no old event-loop binding survives.
        if self.ib is not None:
            try:
                if hasattr(self.ib, 'disconnect'):
                    self.ib.disconnect()
                # Give the old connection a moment to close.
                await asyncio.sleep(0.1)
            except Exception:
                pass  # best effort: the old instance is dropped either way
            finally:
                self.ib = None

        self.ib = IB()
        try:
            # connectAsync has its own 5 s timeout; the outer wait_for is a
            # slightly longer backstop so this call can never hang indefinitely.
            await asyncio.wait_for(
                self.ib.connectAsync(host, port, clientId=client_id, timeout=5),
                timeout=8.0,
            )
        except asyncio.TimeoutError as exc:
            self._disconnect_quietly()
            raise ConnectionError(
                f"Connection timeout: Could not connect to TWS at {host}:{port} within 8 seconds. "
                f"Please ensure TWS/IB Gateway is running and accepting connections on this port."
            ) from exc
        except asyncio.CancelledError as exc:
            # NOTE(review): this turns task cancellation into a ConnectionError,
            # so the cancellation itself does not propagate to the caller.
            # Kept because callers expect ConnectionError from connect().
            self._disconnect_quietly()
            raise ConnectionError(
                f"Connection cancelled: The connection request to {host}:{port} was cancelled. "
                f"This may indicate a client timeout or disconnection."
            ) from exc
        except Exception as e:
            self._disconnect_quietly(only_if_connected=True)
            raise ConnectionError(
                f"Failed to connect to TWS at {host}:{port}: {type(e).__name__}: {str(e)}"
            ) from e
        self._connected = True
        return True

    def disconnect(self):
        """Disconnect from TWS/IB Gateway (best effort) and mark the client disconnected."""
        if self.is_connected():
            try:
                self.ib.disconnect()
            except Exception:
                pass
        self._connected = False

    def _create_contract(self, req: ContractRequest) -> Contract:
        """Build an ib_async ``Contract`` from a ContractRequest.

        STK/OPT/FUT map to ``Stock``/``Option``/``Future``. OPT and FUT only
        carry symbol/exchange/currency here, so they may need more fields
        before TWS can qualify them. Any other secType becomes a generic
        ``Contract``.
        """
        if req.secType == "STK":
            return Stock(req.symbol, req.exchange, req.currency)
        if req.secType == "OPT":
            return Option(req.symbol, exchange=req.exchange, currency=req.currency)
        if req.secType == "FUT":
            return Future(req.symbol, exchange=req.exchange, currency=req.currency)
        return Contract(symbol=req.symbol, secType=req.secType, exchange=req.exchange, currency=req.currency)

    async def _qualify(self, contract: Contract, error_message: str) -> Contract:
        """Qualify *contract* with TWS; raise ``ValueError(error_message)`` if it can't be.

        An empty result and a ``None`` first entry are both treated as failure.
        """
        qualified = await self.ib.qualifyContractsAsync(contract)
        if not qualified or qualified[0] is None:
            raise ValueError(error_message)
        return qualified[0]

    async def get_contract_details(self, req: ContractRequest) -> List[Dict[str, Any]]:
        """Return contract details for *req*, one dict per match."""
        self._require_connection()
        contract = self._create_contract(req)
        details = await self.ib.reqContractDetailsAsync(contract)
        return [_to_dict(cd) for cd in details]

    async def get_historical_data(self, req: ContractRequest, durationStr: str, barSizeSetting: str, whatToShow: str) -> List[Dict[str, Any]]:
        """Return historical bars (regular trading hours only) ending now.

        Raises:
            RuntimeError: not connected.
            ValueError: the contract could not be created or qualified.
        """
        self._require_connection()
        contract = self._create_contract(req)
        if contract is None:
            raise ValueError(f"Failed to create contract for {req.symbol} (secType={req.secType})")
        qualified = await self.ib.qualifyContractsAsync(contract)
        if not qualified:
            raise ValueError(f"Contract not found or could not be qualified: {req.symbol} (secType={req.secType}, exchange={req.exchange})")
        qualified_contract = qualified[0]
        if qualified_contract is None:
            raise ValueError(f"Contract qualification returned None for {req.symbol} (secType={req.secType}, exchange={req.exchange})")
        bars = await self.ib.reqHistoricalDataAsync(
            qualified_contract,
            endDateTime='',  # '' = up to the present
            durationStr=durationStr,
            barSizeSetting=barSizeSetting,
            whatToShow=whatToShow,
            useRTH=1,
            formatDate=1,
        )
        return [_to_dict(bar) for bar in bars]

    async def get_account_summary(self) -> List[Dict[str, Any]]:
        """Return account summary values, one dict per value."""
        self._require_connection()
        summary = await self.ib.accountSummaryAsync()
        return [_to_dict(item) for item in summary]

    async def get_positions(self) -> List[Dict[str, Any]]:
        """Return current portfolio positions as account/contract/position/avgCost dicts."""
        self._require_connection()
        return [
            {
                "account": pos.account,
                "contract": _to_dict(pos.contract),
                "position": pos.position,
                "avgCost": pos.avgCost,
            }
            for pos in self.ib.positions()
        ]

    async def place_order(self, req: OrderRequest) -> Dict[str, Any]:
        """Qualify the contract and place a MKT or LMT order.

        Waits 0.5 s after placing so the returned status has a chance to update.

        Raises:
            RuntimeError: not connected.
            ValueError: unqualifiable contract, missing limit price for LMT,
                or an unsupported order type.
        """
        self._require_connection()
        contract = await self._qualify(
            self._create_contract(req.contract),
            f"Contract not found or could not be qualified: {req.contract.symbol}",
        )
        if req.orderType == "MKT":
            order = MarketOrder(req.action, req.totalQuantity, transmit=req.transmit)
        elif req.orderType == "LMT":
            if req.lmtPrice is None:
                raise ValueError("limitPrice is required for LMT orders")
            order = LimitOrder(req.action, req.totalQuantity, req.lmtPrice, transmit=req.transmit)
        else:
            raise ValueError(f"Unsupported order type: {req.orderType}")
        trade = self.ib.placeOrder(contract, order)
        await asyncio.sleep(0.5)
        return {
            "orderId": trade.order.orderId,
            "status": trade.orderStatus.status,
            "contract": _to_dict(trade.contract),
            "action": trade.order.action,
            "quantity": trade.order.totalQuantity,
        }

    async def cancel_order(self, orderId: int) -> Dict[str, Any]:
        """Cancel the open order with *orderId*.

        Raises:
            RuntimeError: not connected.
            ValueError: no open order with that id was returned by reqOpenOrdersAsync.
        """
        self._require_connection()
        open_trades = await self.ib.reqOpenOrdersAsync()
        target = next((t for t in open_trades if t.order.orderId == orderId), None)
        if not target:
            raise ValueError(f"Order with ID {orderId} not found among open orders.")
        trade = self.ib.cancelOrder(target.order)
        if trade is None:
            # ib_async returns None when it doesn't track the order locally;
            # report on the open-order trade fetched above instead of crashing.
            trade = target
        await asyncio.sleep(0.5)
        return {
            "orderId": trade.order.orderId,
            "status": trade.orderStatus.status,
            "message": f"Cancellation request sent for order {orderId}"
        }

    async def get_open_orders(self) -> List[Dict[str, Any]]:
        """Return all open orders (reqAllOpenOrdersAsync) as summary dicts."""
        self._require_connection()
        trades = await self.ib.reqAllOpenOrdersAsync()
        return [
            {
                "orderId": t.order.orderId,
                "status": t.orderStatus.status,
                "contract": _to_dict(t.contract),
                "action": t.order.action,
                "quantity": t.order.totalQuantity,
            }
            for t in trades
        ]

    async def get_executions(self) -> List[Dict[str, Any]]:
        """Return all executions, one dict per fill."""
        self._require_connection()
        executions = await self.ib.reqExecutionsAsync()
        return [_to_dict(e) for e in executions]

    async def get_pnl(self, account: str, modelCode: str) -> Dict[str, Any]:
        """Return the live PnL snapshot for *account* / *modelCode*.

        ib_async allows only one PnL subscription per (account, modelCode)
        and fails on a duplicate request. An existing live subscription is
        therefore reused; a new one is started only on first use. Values of a
        brand-new subscription may not be populated yet.
        """
        self._require_connection()
        existing = [
            p for p in self.ib.pnl(account, modelCode)
            if p.account == account and p.modelCode == modelCode
        ]
        pnl = existing[0] if existing else self.ib.reqPnL(account, modelCode)
        return _to_dict(pnl)

    async def get_pnl_single(self, account: str, modelCode: str, conId: int) -> Dict[str, Any]:
        """Return the live PnL snapshot for one position (account, modelCode, conId).

        As with get_pnl, an existing subscription for the same key is reused
        instead of being requested a second time.
        """
        self._require_connection()
        existing = [
            p for p in self.ib.pnlSingle(account, modelCode, conId)
            if p.account == account and p.modelCode == modelCode and p.conId == conId
        ]
        pnl = existing[0] if existing else self.ib.reqPnLSingle(account, modelCode, conId)
        return _to_dict(pnl)

    async def stream_market_data(self, req: ContractRequest) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream quotes for one contract (market data type 3 = delayed).

        Yields a quote dict after each ticker update that carries a real last
        price, and ``{}`` otherwise or when no update arrives within 2 s.
        Only one stream per conId may be active at a time. The subscription
        and error handler are always released when the generator finishes,
        fails, is cancelled or is closed.

        Raises:
            RuntimeError: not connected, already streaming this conId, or TWS
                reported a non-warning error for this contract.
            ValueError: the contract could not be qualified.
        """
        self._require_connection()
        contract = await self._qualify(
            self._create_contract(req),
            f"Contract not found or could not be qualified: {req.symbol}",
        )
        con_id = contract.conId
        if con_id in self._market_data_subscriptions:
            raise RuntimeError(f"Market data already streaming for contract ID {con_id}")

        self.ib.reqMarketDataType(3)  # 3 = delayed data
        ticker = self.ib.reqMktData(contract, '', False, False)
        self._market_data_subscriptions[con_id] = ticker

        # Informational codes (per ib_async's wrapper.py) that must not stop
        # the stream, e.g. 10167 "Requested market data is not subscribed.
        # Displaying delayed market data."; 2100-2199 is treated as a range.
        warning_codes = frozenset({105, 110, 165, 321, 329, 399, 404, 434, 492, 10167})
        errors: List[Dict[str, Any]] = []

        def on_error(reqId, errorCode, errorString, err_contract):
            """Record non-warning TWS errors that concern this stream's contract."""
            if not (err_contract and hasattr(err_contract, 'conId') and err_contract.conId == con_id):
                return
            if errorCode in warning_codes or 2100 <= errorCode < 2200:
                return
            errors.append({
                'reqId': reqId,
                'errorCode': errorCode,
                'errorString': errorString,
                'contract': str(err_contract)
            })

        def raise_first_error() -> None:
            """Re-raise the first recorded TWS error as a RuntimeError."""
            if errors:
                first = errors[0]
                raise RuntimeError(
                    f"TWS Error {first['errorCode']}: {first['errorString']} "
                    f"(reqId: {first['reqId']}, contract: {first['contract']})"
                )

        self.ib.errorEvent += on_error
        try:
            # Give TWS time to report immediate failures (e.g. a missing
            # market-data subscription) before the first check below.
            await asyncio.sleep(0.5)
            while True:
                raise_first_error()
                try:
                    # Wait up to 2 s for the next ticker update.
                    async for _ in ticker.updateEvent.timeout(2.0):
                        break
                except asyncio.TimeoutError:
                    yield {}  # keep-alive: no update in time
                    continue
                last = ticker.last
                # ib_async uses NaN for prices not received yet; NaN is truthy,
                # so it must be excluded explicitly.
                if ticker.time and last and not util.isNan(last):
                    yield {
                        "time": ticker.time.isoformat(),
                        "last": last,
                        "bid": ticker.bid,
                        "ask": ticker.ask,
                        "volume": ticker.volume,
                        "bidSize": ticker.bidSize,
                        "askSize": ticker.askSize,
                        "close": ticker.close,
                    }
                else:
                    yield {}
        finally:
            # Runs on errors, task cancellation and aclose() alike. Must not await.
            try:
                self.ib.errorEvent -= on_error
            except Exception:
                pass
            # Skip the cancel request if the connection is gone; sending it
            # would raise and mask the original exception.
            if self._market_data_subscriptions.pop(con_id, None) is not None and self.is_connected():
                self.ib.cancelMktData(contract)

    async def stream_account_updates(self, account: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream position and account-value updates for *account*.

        After each IB update event, yields ``{"type": "positions", ...}`` when
        positions exist, then ``{"type": "account_values", ...}`` when account
        values exist (``{}`` otherwise). Yields ``{}`` as a keep-alive when no
        update arrives within 5 s. The account-updates subscription is
        cancelled when the generator finishes, fails, is cancelled or closed.
        """
        self._require_connection()
        # Use the non-blocking client request. IB.reqAccountUpdates() is
        # ib_async's blocking wrapper, which runs the event loop itself and
        # cannot be used from inside an already running loop.
        self.ib.client.reqAccountUpdates(True, account)
        try:
            while True:
                try:
                    # Wait up to 5 s for any IB update.
                    async for _ in self.ib.updateEvent.timeout(5.0):
                        break
                except asyncio.TimeoutError:
                    yield {}  # keep-alive: no update in time
                    continue
                positions = self.ib.positions()
                if positions:
                    yield {
                        "type": "positions",
                        "data": [
                            {
                                "account": pos.account,
                                "contract": _to_dict(pos.contract),
                                "position": pos.position,
                                "avgCost": pos.avgCost,
                            }
                            for pos in positions
                        ]
                    }
                account_values = self.ib.accountValues()
                if account_values:
                    yield {
                        "type": "account_values",
                        "data": [_to_dict(item) for item in account_values]
                    }
                else:
                    yield {}
        finally:
            # subscribe=False cancels the subscription made above. Skip it when
            # disconnected, since sending would raise and mask the original error.
            # NOTE(review): ib_async may also keep its own account-updates
            # subscription from connect(); cancelling here ends that too — confirm.
            if self.is_connected():
                self.ib.client.reqAccountUpdates(False, account)