Schwab Model Context Protocol Server
by jkoelker
Verified
#
import inspect
import json
import functools
from dataclasses import dataclass, field
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
TypeVar,
get_type_hints,
get_origin,
Annotated,
)
import httpx
import mcp.types as types
from authlib.integrations.base_client import OAuthError
from mcp.shared.exceptions import McpError
from schwab.client import AsyncClient
# Type variable for the decorated async function
T = TypeVar("T")
def responsify(
data: Any,
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if isinstance(
data, (types.TextContent, types.ImageContent, types.EmbeddedResource)
):
return [data]
if isinstance(data, str):
return [types.TextContent(type="text", text=data)]
if isinstance(data, list):
return [responsify(item) for item in data]
raise ValueError(f"Invalid response type: {type(data)}")
def get_schema_for_type(type_obj: Any, description: str = "") -> dict:
"""Convert a Python type to a JSON schema object"""
param_schema = {"type": "string"} # Default
if type_obj is str:
param_schema = {"type": "string"}
elif type_obj is int:
param_schema = {"type": "integer"}
elif type_obj is float:
param_schema = {"type": "number"}
elif type_obj is bool:
param_schema = {"type": "boolean"}
elif get_origin(type_obj) is list:
param_schema = {"type": "array", "items": {"type": "string"}}
# Add description if available
if description:
param_schema["description"] = description
return param_schema
class SchwabtoolError(McpError):
"""Custom error class for Schwab MCP Tools"""
def __str__(self):
"""Custom string representation to include error details"""
# Include the error data in the string representation
data = getattr(self, "error", None)
if data:
try:
# Format the error details as JSON
return f"{self.error.message} - {json.dumps(self.error.data, indent=2)}"
except Exception:
# Fallback if JSON serialization fails
return f"{self.error.message} - {self.error.data}"
return super().__str__()
@classmethod
def auth_error(cls, original_error=None):
"""Create an authentication error response"""
msg = "Authentication failed. Please run 'schwab-mcp auth' to re-authenticate."
details = {"original_error": str(original_error)} if original_error else {}
return cls(types.ErrorData(code=401, message=msg, data=details))
@classmethod
def api_error(cls, original_error=None, status_code=None):
"""Create an API error response"""
msg = "Schwab API error"
details = {"original_error": str(original_error)} if original_error else {}
if status_code:
details["status_code"] = status_code
# Create an ErrorData object with the details in the data field
error_data = types.ErrorData(code=500, message=msg, data=details)
return cls(error_data)
@classmethod
def validation_error(cls, message, details=None):
"""Create a validation error response"""
return cls(types.ErrorData(code=400, message=message, data=details or {}))
class BaseSchwabTool:
"""Base class for Schwab API tools"""
def __init__(self, client: AsyncClient):
self.client = client
def definition(self) -> types.Tool:
"""Return the tool definition - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement definition()")
async def execute(
self, arguments: dict[str, Any]
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""Execute the tool with the given arguments"""
raise NotImplementedError("Subclasses must implement execute()")
@staticmethod
def handle_api_errors(
func: Callable[..., Awaitable[T]],
) -> Callable[..., Awaitable[T]]:
"""Decorator to handle API errors and convert them to McpErrors"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except OAuthError as e:
# Handle authentication errors
if "refresh_token_authentication_error" in str(e):
raise SchwabtoolError.auth_error(e)
raise SchwabtoolError.api_error(e)
except httpx.HTTPStatusError as e:
# Handle HTTP errors
status_code = e.response.status_code
error_details = None
try:
# Try to parse response body for more details
error_details = e.response.json()
except Exception:
pass
details = {"status_code": status_code, "original_error": str(e)}
if error_details:
details["error_details"] = error_details
# Create a more detailed error response
error_data = types.ErrorData(
code=500, message="Schwab API error", data=details
)
raise SchwabtoolError(error_data)
except ValueError as e:
# Handle validation errors
raise SchwabtoolError.validation_error(str(e))
except Exception as e:
# Handle all other errors
raise SchwabtoolError.api_error(e)
return wrapper
class _FunctionTool(BaseSchwabTool):
"""Tool implementation that wraps a function"""
def __init__(self, client: AsyncClient, func: Callable):
super().__init__(client)
self.func = func
self._client_param = None
self._definition = self._get_definition()
def _get_definition(self) -> types.Tool:
"""Create a tool definition from the function's signature"""
sig = inspect.signature(self.func)
type_hints = get_type_hints(self.func, include_extras=True)
doc = inspect.getdoc(self.func) or ""
# Get function name
name = self.func.__name__
# Create properties for the input schema
properties = {}
required = []
# Process each parameter
for param_name, param in sig.parameters.items():
# Skip client parameter
if param.annotation is AsyncClient:
self._client_param = param_name
continue
if param_name == "client" and param.annotation == inspect.Parameter.empty:
self._client_param = param_name
continue
# Track required parameters
if param.default == inspect.Parameter.empty:
required.append(param_name)
# Process parameter type
if param_name in type_hints:
param_type = type_hints[param_name]
description = ""
# Handle Annotated types
if get_origin(param_type) is Annotated:
description = " ".join(param_type.__metadata__)
param_type = param_type.__origin__
properties[param_name] = get_schema_for_type(param_type, description)
# Create the input schema
input_schema = {
"type": "object",
"properties": properties,
}
if required:
input_schema["required"] = required
return types.Tool(
name=name,
description=doc,
inputSchema=input_schema,
)
def definition(self) -> types.Tool:
"""Return the cached tool definition"""
return self._definition
@BaseSchwabTool.handle_api_errors
async def execute(
self, arguments: dict[str, Any]
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""Execute the wrapped function"""
if self._client_param:
arguments[self._client_param] = self.client
return responsify(await self.func(**arguments))
def FunctionTool(func: Callable) -> Callable[AsyncClient, _FunctionTool]:
"""Factory function for creating a FunctionBasedTool"""
return functools.partial(_FunctionTool, func=func)
@dataclass
class Registry:
"""Registry of available tools with auto-discovery"""
client: AsyncClient
write: bool = False
_tools: List[types.Tool] = field(default_factory=list)
_instances: Dict[str, BaseSchwabTool] = field(default_factory=dict)
@classmethod
def register(cls, tool: Callable[Any, Any] | BaseSchwabTool = None, **kwargs):
"""Class decorator to register a tool class or function for auto-discovery"""
if not hasattr(cls, "_registered_tools"):
cls._registered_tools = []
def _register(tool: Callable[Any, Any] | BaseSchwabTool, write: bool = False):
if inspect.isfunction(tool):
wrapped = FunctionTool(tool)
else:
wrapped = tool
wrapped.__write = write
cls._registered_tools.append(wrapped)
return tool
if tool is not None and len(kwargs) == 0:
return _register(tool, write=False)
def _decorator(func: Callable[Callable[Any, Any] | BaseSchwabTool, Any]):
return _register(func, **kwargs)
return _decorator
def __post_init__(self):
"""Initialize the registry by discovering and registering tools"""
for tool in getattr(Registry, "_registered_tools", []):
if getattr(tool, "__write", False) and not self.write:
continue
instance = tool(self.client)
if not isinstance(instance, BaseSchwabTool):
raise ValueError("Invalid tool class")
definition = instance.definition()
self._tools.append(definition)
self._instances[definition.name] = instance
if len(self._tools) == 0:
raise ValueError("No tools registered")
def get_tools(self) -> list[types.Tool]:
"""Get all registered tools"""
return self._tools
def get_tool(self, name: str) -> BaseSchwabTool:
"""Get a tool instance by name"""
if name not in self._instances:
raise SchwabtoolError.validation_error(f"Unknown tool: {name}")
return self._instances[name]
async def execute_tool(
self, name: str, arguments: Dict[str, Any]
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""Execute a tool by name with arguments"""
tool = self.get_tool(name)
return await tool.execute(arguments)
# Decorator for registering tool classes or functions
register = Registry.register