# server.py (8.78 kB)
import os
import asyncio
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import AsyncIterator, Dict, Any, List, Optional
from starlette.applications import Starlette
from starlette.routing import Mount
from starlette.middleware.cors import CORSMiddleware
from .tws_client import TWSClient
from .models import ContractRequest, OrderRequest
# Load environment variables from a .env file into os.environ. This must run
# before the tool definitions below, because their parameter defaults
# (TWS_HOST, TWS_PORT, TWS_CLIENT_ID) are read with os.getenv at import time.
load_dotenv()
@dataclass
class AppContext:
    """State yielded by ``app_lifespan`` and shared with the tool handlers.

    Tools reach it through ``ctx.request_context.lifespan_context``.
    """

    # The TWS/IB Gateway client created by app_lifespan (initially unconnected).
    tws: TWSClient
@asynccontextmanager
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
    """Create the TWS client for the lifespan and clean it up afterwards.

    The client is intentionally left unconnected here; a connection is only
    opened when the ``ibkr_connect`` tool is called. When the lifespan ends
    (normally or via an exception), any open connection is closed.

    Args:
        server: The FastMCP instance owning this lifespan (not used).

    Yields:
        An ``AppContext`` holding the TWS client.
    """
    client = TWSClient()
    try:
        context = AppContext(tws=client)
        yield context
    finally:
        # Disconnect only if ibkr_connect actually established a connection.
        if client.is_connected():
            client.disconnect()
# The MCP server instance; app_lifespan supplies the AppContext (TWS client)
# that every tool below reads from its request context.
mcp = FastMCP(name="IBKR TWS MCP Server", lifespan=app_lifespan)
# --- Connection Management Tools ---
@mcp.tool()
async def ibkr_connect(
    ctx: Context[ServerSession, AppContext],
    host: str = os.getenv("TWS_HOST", "127.0.0.1"),
    port: int = int(os.getenv("TWS_PORT", 7496)),
    clientId: int = int(os.getenv("TWS_CLIENT_ID", 1)),
) -> Dict[str, Any]:
    """Connect to TWS/IB Gateway.

    The defaults for host, port and clientId come from the TWS_HOST, TWS_PORT
    and TWS_CLIENT_ID environment variables (read once, at import time). When
    those are unset, the defaults are 127.0.0.1, 7496 and 1.
    """
    client = ctx.request_context.lifespan_context.tws
    await client.connect(host, port, clientId)
    return dict(status="connected", host=host, port=port, clientId=clientId)
@mcp.tool()
async def ibkr_disconnect(
    ctx: Context[ServerSession, AppContext],
) -> Dict[str, Any]:
    """Disconnect from TWS/IB Gateway.

    The client's disconnect() is called unconditionally. The tool always
    reports a "disconnected" status afterwards.
    """
    ctx.request_context.lifespan_context.tws.disconnect()
    return {"status": "disconnected"}
@mcp.tool()
async def ibkr_get_status(
    ctx: Context[ServerSession, AppContext],
) -> Dict[str, Any]:
    """Report whether the shared TWS client is currently connected."""
    is_up = ctx.request_context.lifespan_context.tws.is_connected()
    return {"is_connected": is_up}
# --- Contract and Market Data Tools ---
@mcp.tool()
async def ibkr_get_contract_details(
    ctx: Context[ServerSession, AppContext],
    symbol: str,
    secType: str = "STK",
    exchange: str = "SMART",
    currency: str = "USD",
) -> List[Dict[str, Any]]:
    """Look up contract details for a symbol.

    Args:
        symbol: The instrument's symbol.
        secType: Security type (default "STK").
        exchange: Exchange / routing destination (default "SMART").
        currency: Contract currency (default "USD").
    """
    client = ctx.request_context.lifespan_context.tws
    contract = ContractRequest(
        symbol=symbol,
        secType=secType,
        exchange=exchange,
        currency=currency,
    )
    return await client.get_contract_details(contract)
@mcp.tool()
async def ibkr_get_historical_data(
    ctx: Context[ServerSession, AppContext],
    symbol: str,
    secType: str = "STK",
    exchange: str = "SMART",
    currency: str = "USD",
    durationStr: str = "1 Y",
    barSizeSetting: str = "1 day",
    whatToShow: str = "TRADES",
) -> List[Dict[str, Any]]:
    """Fetch historical market data bars for a contract.

    symbol/secType/exchange/currency describe the contract. The remaining
    arguments are forwarded unchanged to the TWS client:
    - durationStr: how far back to look (default "1 Y").
    - barSizeSetting: bar size (default "1 day").
    - whatToShow: the data type (default "TRADES").
    """
    client = ctx.request_context.lifespan_context.tws
    contract = ContractRequest(
        symbol=symbol,
        secType=secType,
        exchange=exchange,
        currency=currency,
    )
    return await client.get_historical_data(
        contract, durationStr, barSizeSetting, whatToShow
    )
@mcp.tool()
async def ibkr_stream_market_data(
    ctx: Context[ServerSession, AppContext],
    symbol: str,
    secType: str = "STK",
    exchange: str = "SMART",
    currency: str = "USD",
    duration_seconds: int = 60  # Added for a practical streaming limit
) -> Dict[str, Any]:
    """Stream real-time market data for a symbol for a given duration.

    Each non-empty update yielded by the TWS client is forwarded to the MCP
    client as an info log message. Streaming stops when either of these
    happens first:
    - ``duration_seconds`` have elapsed, even if no update arrives in the
      meantime;
    - the underlying stream ends.

    Returns:
        A dict with the symbol, the number of forwarded updates and a
        completion message.
    """
    tws = ctx.request_context.lifespan_context.tws
    req = ContractRequest(symbol=symbol, secType=secType, exchange=exchange, currency=currency)
    loop = asyncio.get_running_loop()
    deadline = loop.time() + duration_seconds
    updates_count = 0
    stream = tws.stream_market_data(req).__aiter__()
    try:
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                # Bound every wait by the time left. Checking the clock only
                # after an update arrives would let a quiet stream block forever.
                data = await asyncio.wait_for(stream.__anext__(), timeout=remaining)
            except StopAsyncIteration:
                break  # the stream ended on its own
            except asyncio.TimeoutError:
                break  # deadline reached while waiting for the next update
            if data:
                await ctx.info(f"Market data update: {data}")
                updates_count += 1
    finally:
        # Close the stream now (when it supports aclose) so its cleanup runs
        # immediately instead of whenever the generator is garbage-collected.
        aclose = getattr(stream, "aclose", None)
        if aclose is not None:
            await aclose()
    return {
        "symbol": symbol,
        "updates_count": updates_count,
        "message": f"Streaming completed after {duration_seconds} seconds."
    }
# --- Account and Portfolio Tools ---
@mcp.tool()
async def ibkr_get_account_summary(
    ctx: Context[ServerSession, AppContext],
) -> List[Dict[str, Any]]:
    """Return the account summary metrics reported by the TWS client."""
    client = ctx.request_context.lifespan_context.tws
    summary = await client.get_account_summary()
    return summary
@mcp.tool()
async def ibkr_get_positions(
    ctx: Context[ServerSession, AppContext],
) -> List[Dict[str, Any]]:
    """Return the current portfolio positions as reported by the TWS client."""
    return await ctx.request_context.lifespan_context.tws.get_positions()
@mcp.tool()
async def ibkr_stream_account_updates(
    ctx: Context[ServerSession, AppContext],
    account: str,
    duration_seconds: int = 60  # Added for a practical streaming limit
) -> Dict[str, Any]:
    """Stream real-time account and position updates for a given duration.

    Each non-empty update yielded by the TWS client for ``account`` is
    forwarded to the MCP client as an info log message. Streaming stops when
    either of these happens first:
    - ``duration_seconds`` have elapsed, even if no update arrives in the
      meantime;
    - the underlying stream ends.

    Returns:
        A dict with the account, the number of forwarded updates and a
        completion message.
    """
    tws = ctx.request_context.lifespan_context.tws
    loop = asyncio.get_running_loop()
    deadline = loop.time() + duration_seconds
    updates_count = 0
    stream = tws.stream_account_updates(account).__aiter__()
    try:
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                # Bound every wait by the time left. Checking the clock only
                # after an update arrives would let a quiet stream block forever.
                data = await asyncio.wait_for(stream.__anext__(), timeout=remaining)
            except StopAsyncIteration:
                break  # the stream ended on its own
            except asyncio.TimeoutError:
                break  # deadline reached while waiting for the next update
            if data:
                await ctx.info(f"Account update: {data}")
                updates_count += 1
    finally:
        # Close the stream now (when it supports aclose) so its cleanup runs
        # immediately instead of whenever the generator is garbage-collected.
        aclose = getattr(stream, "aclose", None)
        if aclose is not None:
            await aclose()
    return {
        "account": account,
        "updates_count": updates_count,
        "message": f"Account streaming completed after {duration_seconds} seconds."
    }
@mcp.tool()
async def ibkr_get_pnl(
    ctx: Context[ServerSession, AppContext],
    account: str,
    modelCode: str = '',
) -> Dict[str, Any]:
    """Get overall profit and loss for an account.

    ``modelCode`` (default empty string) is passed through to the TWS client
    unchanged.
    """
    client = ctx.request_context.lifespan_context.tws
    pnl = await client.get_pnl(account, modelCode)
    return pnl
@mcp.tool()
async def ibkr_get_pnl_single(
    ctx: Context[ServerSession, AppContext],
    account: str,
    modelCode: str = '',
    conId: int = 0
) -> Dict[str, Any]:
    """Get PnL for a single position within an account/model.

    The position is identified by the contract ID ``conId``. ``account``,
    ``modelCode`` (default empty string) and ``conId`` are all forwarded
    unchanged to the TWS client.

    NOTE(review): ``conId`` defaults to 0, which presumably does not identify
    any real contract. Callers should pass the position's actual conId.
    """
    tws = ctx.request_context.lifespan_context.tws
    return await tws.get_pnl_single(account, modelCode, conId)
# --- Order Management Tools ---
@mcp.tool()
async def ibkr_place_order(
    ctx: Context[ServerSession, AppContext],
    symbol: str,
    action: str,
    totalQuantity: int,
    orderType: str = "MKT",
    lmtPrice: Optional[float] = None,
    secType: str = "STK",
    exchange: str = "SMART",
    currency: str = "USD"
) -> Dict[str, Any]:
    """Place an order.

    Args:
        symbol, secType, exchange, currency: Identify the contract to trade.
        action: Order side, passed through unchanged (e.g. "BUY" / "SELL").
        totalQuantity: Number of units to trade; must be positive.
        orderType: Order type (default "MKT").
        lmtPrice: Limit price; required when orderType is "LMT".

    Returns:
        The result dict produced by the TWS client's place_order.

    Raises:
        ValueError: If totalQuantity is not positive, or if an "LMT" order
            has no lmtPrice.
    """
    # This tool sends real orders, so obviously malformed requests are
    # rejected here, before anything is handed to the TWS client.
    if totalQuantity <= 0:
        raise ValueError(f"totalQuantity must be positive, got {totalQuantity}")
    if orderType.strip().upper() == "LMT" and lmtPrice is None:
        raise ValueError("lmtPrice is required for LMT orders")

    tws = ctx.request_context.lifespan_context.tws
    order_req = OrderRequest(
        contract=ContractRequest(
            symbol=symbol,
            secType=secType,
            exchange=exchange,
            currency=currency
        ),
        action=action,
        totalQuantity=totalQuantity,
        orderType=orderType,
        lmtPrice=lmtPrice
    )
    return await tws.place_order(order_req)
@mcp.tool()
async def ibkr_cancel_order(
    ctx: Context[ServerSession, AppContext],
    orderId: int,
) -> Dict[str, Any]:
    """Cancel a previously placed order, identified by its order ID.

    Returns the result dict produced by the TWS client's cancel_order.
    """
    return await ctx.request_context.lifespan_context.tws.cancel_order(orderId)
@mcp.tool()
async def ibkr_get_open_orders(
    ctx: Context[ServerSession, AppContext],
) -> List[Dict[str, Any]]:
    """List all open orders as reported by the TWS client."""
    client = ctx.request_context.lifespan_context.tws
    orders = await client.get_open_orders()
    return orders
@mcp.tool()
async def ibkr_get_executions(
    ctx: Context[ServerSession, AppContext],
) -> List[Dict[str, Any]]:
    """List all executions as reported by the TWS client."""
    return await ctx.request_context.lifespan_context.tws.get_executions()
# --- Starlette App Setup ---
# Use the MCP SSE ASGI app. The streamable-HTTP app expects to be run via
# run_streamable_http_async() or similar; SSE is simpler for HTTP deployment.
mcp_sse_app = mcp.sse_app()

# The SSE app exposes its endpoints at its own root, so mount it under the API
# prefix (API_PREFIX, default "/api/v1") inside a Starlette wrapper app.
# (Mount is already imported at the top of the file.)
wrapped_app = Starlette(
    routes=[
        Mount(os.getenv("API_PREFIX", "/api/v1"), app=mcp_sse_app),
    ]
)

# Allowed CORS origins come from CORS_ALLOW_ORIGINS (comma-separated). An
# empty value allows no cross-origin requests. The default "*" keeps the
# previous behavior.
# NOTE(security): no authentication is visible in this file, so a wildcard
# origin may let any web page the user visits drive these tools, including
# ibkr_place_order. Set an explicit origin list for real deployments.
_cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

# The CORS-wrapped app is the module-level `app` served by uvicorn.
app = CORSMiddleware(
    wrapped_app,
    allow_origins=_cors_allow_origins,
    allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    expose_headers=["Mcp-Session-Id"],
)
# Script entry point: serve the CORS-wrapped ASGI app with uvicorn.
# SERVER_HOST defaults to "0.0.0.0" (all interfaces) and SERVER_PORT to 8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host=os.getenv("SERVER_HOST", "0.0.0.0"),
        port=int(os.getenv("SERVER_PORT", 8000)),
    )