"""
MCP Server - Main FastAPI application for medical conversational AI.
"""
import os
import json
import asyncio
from typing import Dict, Any, Optional, List
from fastapi import FastAPI, Depends, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from config import DEFAULT_OPENAI_MODEL
from logger.logger_setup import get_logger
from server.health import router as health_router
from server.constants import (
    SESSION_BUFFER_MAX_SIZE,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_SYSTEM_PROMPT,
    STREAM_TIMEOUT_SECONDS,
    STREAM_POLL_INTERVAL,
    STATUS_PENDING,
    STATUS_RUNNING,
    STATUS_FINISHED,
    STATUS_ERROR,
    STATUS_CANCELLED,
    STATUS_STARTED,
    EVENT_TYPE_PARTIAL,
    EVENT_TYPE_FINAL,
    EVENT_TYPE_ERROR,
    EVENT_TYPE_CANCELLED,
    SESSION_EVENT_TOOL_STARTED,
    SESSION_EVENT_TOOL_FINISHED,
    SESSION_EVENT_TOOL_ERROR,
    SESSION_EVENT_TOOL_CANCELLED,
)
from server.exceptions import (
    ToolNotFoundError,
    CallNotFoundError,
    MCPException
)
from server.services import CallRegistryService, OpenAIService
from server.dependencies import verify_auth, get_openai_service
from server.middleware import RequestLoggingMiddleware
logger = get_logger("mcp_server")
# FastAPI app initialization
app = FastAPI(
    title="MedX MCP Server",
    description="AI-powered clinical agentic platform featuring our MedX-powered AI Agents and HealthOS, delivering advanced diagnostic support and personalized healthcare.",
    version="0.1.0"
)
# Include health check router
app.include_router(health_router)
# Add request logging middleware
app.add_middleware(RequestLoggingMiddleware)
# Add exception handlers
@app.exception_handler(MCPException)
async def mcp_exception_handler(request: Request, exc: MCPException):
    """Handle custom MCP exceptions."""
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "message": exc.detail,
                "error_code": exc.error_code,
                "type": type(exc).__name__
            }
        },
        headers=exc.headers
    )
# Global state (in-memory for POC; would use Redis/database in production)
CALLS: Dict[str, Dict[str, Any]] = {}
EVENT_QUEUES: Dict[str, asyncio.Queue] = {}
REQUEST_ID_MAP: Dict[str, Dict[str, Any]] = {}
SESSION_BUFFERS: Dict[str, List[Dict[str, Any]]] = {}
TASKS: Dict[str, asyncio.Task] = {}
# Initialize services
_call_registry_service = CallRegistryService(CALLS, EVENT_QUEUES, REQUEST_ID_MAP, TASKS)
def _append_session_event(session_id: Optional[str], event: Dict[str, Any]) -> None:
    """Append event to session buffer, keeping only last N events."""
    if not session_id:
        return
    buf = SESSION_BUFFERS.setdefault(session_id, [])
    buf.append(event)
    # Keep only last N events
    if len(buf) > SESSION_BUFFER_MAX_SIZE:
        del buf[:-SESSION_BUFFER_MAX_SIZE]
class ExecuteRequest(BaseModel):
    """Request model for tool execution.
    Backwards compatible: supports either legacy shape with `tool` and `input`,
    or simplified shape providing `messages` at top-level.
    """
    # Legacy fields (ignored for tool selection; input remains supported)
    tool: Optional[str] = None
    input: Optional[Dict[str, Any]] = None
    # Simplified fields
    messages: Optional[List[Dict[str, Any]]] = None
    session_id: Optional[str] = None
    request_id: Optional[str] = None  # For idempotency
    metadata: Dict[str, Any] = {}
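# Illustrative request bodies for /mcp/execute (values here are examples, not
# defaults):
#
#   Simplified shape:
#     {"messages": [{"role": "user", "content": "What is hypertension?"}],
#      "session_id": "sess-123", "request_id": "req-abc"}
#
#   Legacy shape (any `tool`/`model` values are overridden by server defaults):
#     {"tool": "openai_chat",
#      "input": {"messages": [{"role": "user", "content": "..."}]}}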
@app.get("/mcp/manifest")
async def manifest(token: str = Depends(verify_auth)):
    """
    Get MCP server manifest (available tools and role).
    
    The manifest includes:
    - role: The server's primary role/purpose
    - description: Detailed description
    - capabilities: List of capabilities
    - tools: Available tools
    
    Returns:
        JSONResponse: Server manifest with role, description, capabilities, and tools
    """
    logger.info("Manifest requested")
    # manifest.json is in project root, not in server/
    manifest_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "manifest.json"
    )
    with open(manifest_path, "r", encoding="utf-8") as f:
        manifest_data = json.load(f)
    return JSONResponse(manifest_data)
@app.post("/mcp/execute")
async def execute(
    req: ExecuteRequest,
    token: str = Depends(verify_auth),
    ai_service: OpenAIService = Depends(get_openai_service)
):
    """
    Execute a tool asynchronously.
    
    Returns:
        JSONResponse: Call ID and status
    """
    logger.info(
        "Execute requested: tool=%s session_id=%s request_id=%s",
        req.tool,
        req.session_id,
        req.request_id
    )
    # Idempotency check
    if req.request_id:
        existing = _call_registry_service.get_existing_call(req.request_id)
        if existing:
            logger.info(
                "Idempotent execute: returning existing call_id=%s status=%s",
                existing["call_id"],
                existing["status"]
            )
            return JSONResponse({
                "call_id": existing["call_id"],
                "status": existing["status"]
            })
    # Force tool and model to server defaults
    forced_tool = "openai_chat"
    # Build input payload from either legacy `input` or simplified `messages`
    if req.input and isinstance(req.input, dict):
        input_payload = dict(req.input)
    else:
        input_payload = {"messages": (req.messages or [])}
    # Strip user-provided model if present; model is enforced later
    input_payload.pop("model", None)
    # Create new call
    call_id = _call_registry_service.create_call(
        forced_tool,
        input_payload,
        req.session_id,
        req.request_id
    )
    EVENT_QUEUES[call_id] = asyncio.Queue()
    
    # Register request_id for idempotency
    if req.request_id:
        _call_registry_service.register_request_id(req.request_id, call_id)
    # Kick off tool execution as an asyncio task
    task = asyncio.create_task(
        run_tool_call(call_id, forced_tool, input_payload, req.session_id, req.metadata, ai_service)
    )
    TASKS[call_id] = task
    
    logger.info("Execute started: call_id=%s task=%s", call_id, id(task))
    return JSONResponse({"call_id": call_id, "status": STATUS_STARTED})
async def run_tool_call(
    call_id: str,
    tool: str,
    input_data: Dict[str, Any],
    session_id: Optional[str],
    metadata: Dict[str, Any],
    ai_service: OpenAIService
):
    """
    Background runner: calls tool, writes partial results to queue.
    
    Args:
        call_id: Unique call identifier
        tool: Tool name to execute
        input_data: Tool input parameters
        session_id: Optional session identifier
        metadata: Optional metadata
        ai_service: OpenAI service instance
    """
    try:
        CALLS[call_id]["status"] = STATUS_RUNNING
        logger.info("Tool running: call_id=%s tool=%s", call_id, tool)
        
        # Record start in session buffer
        _append_session_event(session_id, {
            "call_id": call_id,
            "event": SESSION_EVENT_TOOL_STARTED,
            "tool": tool,
            "input": input_data,
        })
        if tool == "openai_chat":
            # Extract parameters (force model to server default)
            model = DEFAULT_OPENAI_MODEL
            messages = input_data.get("messages", [])
            # Inject default system prompt if not provided
            has_system = any(
                isinstance(m, dict) and m.get("role") == "system" for m in messages
            )
            if not has_system:
                messages = [{
                    "role": "system",
                    "content": DEFAULT_SYSTEM_PROMPT,
                }] + messages
            max_tokens = input_data.get("max_tokens", DEFAULT_MAX_TOKENS)
            # Stream from OpenAI
            partial_text = ""
            token_count = 0
            
            # NOTE: the OpenAI stream iterator is synchronous, so each next()
            # call blocks the event loop between tokens, and cancellation can
            # only take effect at the await points below; acceptable for a POC
            for token in ai_service.stream_chat_completion(
                messages=messages,
                model=model,
                max_tokens=max_tokens,
                temperature=DEFAULT_TEMPERATURE
            ):
                partial_text += token
                await EVENT_QUEUES[call_id].put({
                    "type": EVENT_TYPE_PARTIAL,
                    "text": token
                })
                token_count += 1
            
            # After stream completes, register final
            await EVENT_QUEUES[call_id].put({
                "type": EVENT_TYPE_FINAL,
                "text": partial_text
            })
            CALLS[call_id]["status"] = STATUS_FINISHED
            CALLS[call_id]["result"] = partial_text
            
            logger.info(
                "Tool finished: call_id=%s tool=%s tokens=%s chars=%s",
                call_id,
                tool,
                token_count,
                len(partial_text)
            )
            # Persist final to session buffer
            _append_session_event(session_id, {
                "call_id": call_id,
                "event": SESSION_EVENT_TOOL_FINISHED,
                "tool": tool,
                "output": partial_text,
            })
            # Mark idempotency mapping completed
            for mapping in REQUEST_ID_MAP.values():
                if mapping.get("call_id") == call_id:
                    mapping["status"] = STATUS_FINISHED
        else:
            raise ToolNotFoundError(tool)
            
    except asyncio.CancelledError:
        # Handle cooperative cancellation
        CALLS[call_id]["status"] = STATUS_CANCELLED
        queue = EVENT_QUEUES.get(call_id)
        if queue is not None:
            await queue.put({
                "type": EVENT_TYPE_CANCELLED,
                "message": "cancelled"
            })
        _append_session_event(session_id, {
            "call_id": call_id,
            "event": SESSION_EVENT_TOOL_CANCELLED,
            "message": "cancelled",
        })
        logger.info("Tool task cancelled: call_id=%s", call_id)
        raise
        
    except ToolNotFoundError:
        queue = EVENT_QUEUES.get(call_id)
        if queue is not None:
            await queue.put({
                "type": EVENT_TYPE_ERROR,
                "message": f"Unknown tool: {tool}"
            })
        CALLS[call_id]["status"] = STATUS_ERROR
        logger.warning("Unknown tool requested: tool=%s call_id=%s", tool, call_id)
        # Do not re-raise: this runs inside a fire-and-forget task, so a raised
        # exception would only surface as "Task exception was never retrieved"
        
    except Exception as e:
        CALLS[call_id]["status"] = STATUS_ERROR
        CALLS[call_id]["error"] = str(e)
        
        queue = EVENT_QUEUES.get(call_id)
        if queue is not None:
            await queue.put({
                "type": EVENT_TYPE_ERROR,
                "message": str(e)
            })
        logger.exception("Tool execution failed: call_id=%s", call_id)
        
        # Record error in session buffer
        _append_session_event(session_id, {
            "call_id": call_id,
            "event": SESSION_EVENT_TOOL_ERROR,
            "tool": tool,
            "error": str(e),
        })
        
    finally:
        # Cleanup finished/cancelled task from registry
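        # NOTE: CALLS, EVENT_QUEUES, and SESSION_BUFFERS entries are never
        # evicted in this POC, so memory grows with call volume.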
        TASKS.pop(call_id, None)
@app.get("/mcp/stream/{call_id}")
async def stream(call_id: str, token: str = Depends(verify_auth)):
    """
    Stream events for a call via Server-Sent Events (SSE).
    
    Args:
        call_id: Call identifier to stream
        
    Returns:
        EventSourceResponse: SSE stream of events
    """
    queue = EVENT_QUEUES.get(call_id)
    if queue is None:
        raise CallNotFoundError(call_id)
    
    async def event_generator():
        """Generate SSE events from queue."""
        elapsed = 0.0
        logger.info("Stream opened: call_id=%s", call_id)
        
        while True:
            try:
                data = await asyncio.wait_for(
                    queue.get(),
                    timeout=STREAM_POLL_INTERVAL
                )
                elapsed = 0.0  # reset the inactivity timer whenever data arrives
                yield {
                    "event": data.get("type", "message"),
                    "data": json.dumps(data)
                }
                if data.get("type") in (EVENT_TYPE_FINAL, EVENT_TYPE_ERROR):
                    logger.info(
                        "Stream closed: call_id=%s reason=%s",
                        call_id,
                        data.get("type")
                    )
                    break
                    
            except asyncio.TimeoutError:
                elapsed += STREAM_POLL_INTERVAL
                if elapsed > STREAM_TIMEOUT_SECONDS:
                    logger.warning("Stream timeout: call_id=%s", call_id)
                    yield {
                        "event": EVENT_TYPE_ERROR,
                        "data": json.dumps({
                            "type": EVENT_TYPE_ERROR,
                            "message": "timeout"
                        })
                    }
                    break
    
    return EventSourceResponse(event_generator())
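# Illustrative client usage (assuming bearer auth and a local server):
#   curl -N -H "Authorization: Bearer $TOKEN" \
#     http://localhost:8000/mcp/stream/<call_id>
# Each SSE message carries a JSON payload whose "type" matches one of the
# EVENT_TYPE_* constants; the stream closes after the terminal event.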
@app.post("/mcp/cancel/{call_id}")
async def cancel(call_id: str, token: str = Depends(verify_auth)):
    """
    Cancel a running call.
    
    Args:
        call_id: Call identifier to cancel
        
    Returns:
        dict: Cancellation status
    """
    if call_id not in CALLS:
        raise CallNotFoundError(call_id)
    
    CALLS[call_id]["status"] = STATUS_CANCELLED
    
    # Cancel running task if present
    task = TASKS.get(call_id)
    if task and not task.done():
        task.cancel()
    
    queue = EVENT_QUEUES.get(call_id)
    if queue is not None:
        await queue.put({
            "type": EVENT_TYPE_CANCELLED,
            "message": "cancelled by client"
        })
    
    logger.info("Cancel requested: call_id=%s", call_id)
    
    # Record cancel in session buffer
    session_id = CALLS.get(call_id, {}).get("session_id")
    _append_session_event(session_id, {
        "call_id": call_id,
        "event": SESSION_EVENT_TOOL_CANCELLED,
        "message": "cancelled by client",
    })
    
    return {"status": STATUS_CANCELLED}
@app.post("/mcp/cancel_all")
async def cancel_all(token: str = Depends(verify_auth)):
    """
    Cancel all active calls.
    
    Returns:
        dict: Cancellation status with count
    """
    cancelled = []
    
    for call_id, task in list(TASKS.items()):
        try:
            if task and not task.done():
                task.cancel()
            CALLS.setdefault(call_id, {})["status"] = STATUS_CANCELLED
            
            queue = EVENT_QUEUES.get(call_id)
            if queue is not None:
                await queue.put({
                    "type": EVENT_TYPE_CANCELLED,
                    "message": "cancelled by server"
                })
            
            session_id = CALLS.get(call_id, {}).get("session_id")
            _append_session_event(session_id, {
                "call_id": call_id,
                "event": SESSION_EVENT_TOOL_CANCELLED,
                "message": "cancelled by server",
            })
            cancelled.append(call_id)
            
        except Exception as e:
            logger.exception("Failed to cancel call_id=%s: %s", call_id, str(e))
    
    logger.info("Cancel all requested: count=%s", len(cancelled))
    return {
        "status": STATUS_CANCELLED,
        "count": len(cancelled),
        "call_ids": cancelled
    }