import json
import logging
import asyncio
from typing import Any, Dict
from flask import Flask, request, jsonify, Response
from pydantic import ValidationError
from app.protocol.models import (
JSONRPCEnvelope, InitializeRequest, InitializeResponse,
ResourceListRequest, ResourceReadRequest, ToolCallRequest,
CompletionRequest
)
from app.protocol.enums import MCPCapability
from app.core.errors import (
MCPError, InvalidRequestError, MethodNotFoundError,
InvalidParamsError, JSONRPCErrorCode
)
from app.services.resources import resource_service
from app.services.tools import tool_registry
from app.services.llm import llm_service
from app.auth import require_auth
from app.metrics import record_tool_call, record_llm_request, record_resource_operation
logger = logging.getLogger(__name__)
class MCPServer:
"""MCP Protocol Server Implementation"""
def __init__(self):
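        # Capabilities advertised to clients during the MCP initialize handshake.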
self.capabilities = {
MCPCapability.RESOURCES: {
"subscribe": True,
"listChanged": True
},
MCPCapability.TOOLS: {
"listChanged": True
},
MCPCapability.SAMPLING: True
}
self.server_info = {
"name": "MCP Flask Server",
"version": "1.0.0"
}
async def handle_initialize(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Handle MCP initialize request"""
try:
init_request = InitializeRequest(**params)
response = InitializeResponse(
capabilities=self.capabilities,
serverInfo=self.server_info
)
logger.info(f"Initialized MCP server for client: {init_request.clientInfo}")
return response.dict()
except ValidationError as e:
raise InvalidParamsError(f"Invalid initialize parameters: {e}")
async def handle_resources_list(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Handle resources/list request"""
try:
list_params = ResourceListRequest(**params)
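            # The base path is read from the raw params rather than the validated model.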
base_path = params.get("base", ".")
resources = await resource_service.list_resources(base_path, list_params.cursor)
record_resource_operation("list", True)
return {
"resources": [r.dict() for r in resources]
}
except Exception as e:
record_resource_operation("list", False)
logger.error(f"Resources list error: {str(e)}")
if isinstance(e, MCPError):
raise e
raise MCPError(f"Failed to list resources: {str(e)}")
async def handle_resources_read(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Handle resources/read request"""
try:
read_params = ResourceReadRequest(**params)
content = await resource_service.read_resource(read_params.uri)
record_resource_operation("read", True)
return {
"contents": [content.dict()]
}
except Exception as e:
record_resource_operation("read", False)
logger.error(f"Resource read error: {str(e)}")
if isinstance(e, MCPError):
raise e
raise MCPError(f"Failed to read resource: {str(e)}")
async def handle_tools_list(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Handle tools/list request"""
try:
tools = tool_registry.get_tool_schemas()
return {"tools": tools}
except Exception as e:
logger.error(f"Tools list error: {str(e)}")
raise MCPError(f"Failed to list tools: {str(e)}")
async def handle_tools_call(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Handle tools/call request"""
try:
call_params = ToolCallRequest(**params)
result = await tool_registry.call_tool(call_params.name, call_params.arguments)
record_tool_call(call_params.name, not result.isError)
return result.dict()
except Exception as e:
record_tool_call(params.get("name", "unknown"), False)
logger.error(f"Tool call error: {str(e)}")
if isinstance(e, MCPError):
raise e
raise MCPError(f"Tool execution failed: {str(e)}")
async def handle_sampling_request(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Handle sampling/createMessage request (LLM completion)"""
try:
# Extract prompt from the sampling request format
messages = params.get("messages", [])
if not messages:
raise InvalidParamsError("No messages provided")
# Combine messages into a single prompt
prompt_parts = []
for message in messages:
content = message.get("content", {})
if isinstance(content, dict) and content.get("type") == "text":
prompt_parts.append(content.get("text", ""))
elif isinstance(content, str):
prompt_parts.append(content)
prompt = "\n".join(prompt_parts)
completion_request = CompletionRequest(
prompt=prompt,
max_tokens=params.get("maxTokens", 1000),
temperature=params.get("temperature", 0.2)
)
result = await llm_service.complete(completion_request)
record_llm_request(llm_service.model, False)
return {
"role": "assistant",
"content": {
"type": "text",
"text": result["content"]
},
"model": result["model"],
"stopReason": "endTurn"
}
except Exception as e:
record_llm_request(llm_service.model, False)
logger.error(f"Sampling request error: {str(e)}")
raise MCPError(f"LLM request failed: {str(e)}")
async def process_request(self, envelope: JSONRPCEnvelope) -> JSONRPCEnvelope:
"""Process a single JSON-RPC request"""
try:
method = envelope.method
params = envelope.params or {}
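            # Example of a request routed below (illustrative payload):
            #   {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}
            # The returned envelope carries either "result" or "error".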
# Route to appropriate handler
if method == "initialize":
result = await self.handle_initialize(params)
elif method == "resources/list":
result = await self.handle_resources_list(params)
elif method == "resources/read":
result = await self.handle_resources_read(params)
elif method == "tools/list":
result = await self.handle_tools_list(params)
elif method == "tools/call":
result = await self.handle_tools_call(params)
elif method == "sampling/createMessage":
result = await self.handle_sampling_request(params)
else:
raise MethodNotFoundError(method)
return JSONRPCEnvelope(
jsonrpc="2.0",
id=envelope.id,
result=result
)
except MCPError as e:
return JSONRPCEnvelope(
jsonrpc="2.0",
id=envelope.id,
error=e.to_dict()
)
except Exception as e:
logger.error(f"Unexpected error processing request: {str(e)}")
return JSONRPCEnvelope(
jsonrpc="2.0",
id=envelope.id,
error={
"code": JSONRPCErrorCode.INTERNAL_ERROR.value,
"message": f"Internal server error: {str(e)}"
}
)
def register_routes(app: Flask):
"""Register HTTP routes for MCP server"""
mcp_server = MCPServer()
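    # One server instance is shared by every route registered below.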
@app.route('/rpc', methods=['POST'])
@require_auth
def handle_rpc():
"""Handle JSON-RPC batch requests"""
try:
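            # Illustrative call (assuming require_auth accepts a bearer token):
            #   curl -X POST http://localhost:5000/rpc \
            #        -H "Authorization: Bearer <token>" -H "Content-Type: application/json" \
            #        -d '{"jsonrpc": "2.0", "id": 1, "method": "resources/list", "params": {}}'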
# Parse request body
if not request.is_json:
raise InvalidRequestError("Content-Type must be application/json")
            data = request.get_json(silent=True)  # silent=True so malformed JSON falls through to the error below
if data is None:
raise InvalidRequestError("Invalid JSON")
# Handle batch or single request
is_batch = isinstance(data, list)
requests = data if is_batch else [data]
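            # JSON-RPC 2.0 batches arrive as a JSON array; single requests are wrapped
            # so both cases flow through the same processing path.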
async def process_batch():
tasks = []
for req_data in requests:
try:
envelope = JSONRPCEnvelope(**req_data)
tasks.append(mcp_server.process_request(envelope))
                    except ValidationError as e:
                        # asyncio.coroutine() was removed in Python 3.11, and the old
                        # lambda captured req_data and e late; bind them as defaults so
                        # each error envelope keeps its own request id and message.
                        async def _invalid_request(req=req_data, exc=e):
                            return JSONRPCEnvelope(
                                jsonrpc="2.0",
                                id=req.get("id") if isinstance(req, dict) else None,
                                error={
                                    "code": JSONRPCErrorCode.INVALID_REQUEST.value,
                                    "message": f"Invalid request format: {exc}"
                                }
                            )
                        tasks.append(_invalid_request())
results = await asyncio.gather(*tasks)
return results
            # Run the batch on a fresh event loop for this request
            results = asyncio.run(process_batch())
# Convert results to dict format
response_data = [result.dict(exclude_none=True) for result in results]
return jsonify(response_data if is_batch else response_data[0])
except InvalidRequestError as e:
return jsonify({
"jsonrpc": "2.0",
"error": e.to_dict(),
"id": None
}), 400
except Exception as e:
logger.error(f"RPC handler error: {str(e)}")
return jsonify({
"jsonrpc": "2.0",
"error": {
"code": JSONRPCErrorCode.INTERNAL_ERROR.value,
"message": "Internal server error"
},
"id": None
}), 500
@app.route('/events', methods=['GET'])
def handle_events():
"""Handle Server-Sent Events for streaming"""
# Check authentication via query parameter since EventSource doesn't support headers
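        # Hypothetical client usage: new EventSource('/events?token=<jwt>&prompt=Hello&id=req-1')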
token = request.args.get('token')
if not token:
return Response(
f"event: error\ndata: {json.dumps({'error': 'Authentication token required as query parameter'})}\n\n",
content_type='text/event-stream'
), 401
try:
from app.auth import verify_jwt_token
verify_jwt_token(token)
except Exception as e:
return Response(
f"event: error\ndata: {json.dumps({'error': 'Authentication failed'})}\n\n",
content_type='text/event-stream'
), 401
# Get request parameters outside the generator to avoid context issues
request_id = request.args.get('id', 'stream')
prompt = request.args.get('prompt', 'Hello, how are you?')
def event_stream():
try:
# Create completion request
completion_request = CompletionRequest(
prompt=prompt,
stream=True
)
import threading
import queue
chunk_queue = queue.Queue()
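                # Bridge the async LLM stream into this synchronous generator: a worker
                # thread drives its own event loop and hands chunks back over the queue.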
def run_async():
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def collect_chunks():
async for chunk in llm_service.stream_complete(completion_request):
chunk_queue.put(('chunk', chunk))
if chunk.finished:
break
chunk_queue.put(('done', None))
loop.run_until_complete(collect_chunks())
finally:
loop.close()
except Exception as e:
chunk_queue.put(('error', str(e)))
thread = threading.Thread(target=run_async)
thread.start()
while True:
try:
item_type, item_data = chunk_queue.get(timeout=30)
if item_type == 'chunk':
yield f"event: chunk\ndata: {json.dumps(item_data.dict())}\n\n"
elif item_type == 'error':
yield f"event: error\ndata: {json.dumps({'error': item_data})}\n\n"
break
elif item_type == 'done':
yield f"event: complete\ndata: {json.dumps({'id': request_id})}\n\n"
break
except queue.Empty:
yield f"event: error\ndata: {json.dumps({'error': 'Stream timeout'})}\n\n"
break
                thread.join(timeout=5)  # bounded join so a stalled worker can't hang the response
record_llm_request(llm_service.model, True)
except Exception as e:
logger.error(f"SSE stream error: {str(e)}")
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
return Response(
event_stream(),
content_type='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Headers': 'Content-Type,Authorization'
}
)
@app.route('/test-stream', methods=['GET'])
def test_stream():
"""Test SSE endpoint that doesn't require OpenAI"""
def event_stream():
import time
for i in range(5):
yield f"data: Message {i+1} - Testing SSE stream\n\n"
time.sleep(1)
            yield "data: Stream complete!\n\n"
return Response(
event_stream(),
content_type='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'Access-Control-Allow-Origin': '*'
}
)
logger.info("MCP HTTP routes registered")
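# Example wiring (hypothetical app factory; adjust the import path to your layout):
#
#     from flask import Flask
#     from app.server import register_routes
#
#     app = Flask(__name__)
#     register_routes(app)
#     app.run(port=5000)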