#!/usr/bin/env python3
"""
MCP Server - A Model Context Protocol server for LLM routing
"""
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from urllib.parse import urljoin

import httpx
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from mcp.types import (
    GetPromptResult,
    Prompt,
    PromptArgument,
    PromptMessage,
    Resource,
    TextContent,
    Tool,
)
from pydantic import BaseModel, Field
from config import MCPConfig
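
# MCPConfig is project-local (defined in config.py, not shown here). From its
# usage below it is assumed to expose: SERVER_NAME, SERVER_VERSION,
# LOCAL_LLM_SERVICE_URL, LLM_REQUEST_TIMEOUT, HEALTH_CHECK_TIMEOUT, and a
# get_headers() method returning the HTTP headers (e.g. auth) for the LLM service.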
# Configure logging (basicConfig writes to stderr by default, which keeps stdout
# free for the stdio transport)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Models for LLM requests
class Message(BaseModel):
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
class LLMRequest(BaseModel):
messages: List[Message] = Field(..., description="List of messages")
model: str = Field(default="default", description="Model name")
temperature: float = Field(default=0.7, ge=0, le=2, description="Temperature for generation")
max_tokens: Optional[int] = Field(default=None, description="Maximum tokens to generate")
stream: bool = Field(default=False, description="Whether to stream the response")
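
# Illustrative example of a minimal, valid chat_completion tool-call payload:
#   {"messages": [{"role": "user", "content": "Hello"}], "model": "default"}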
# Initialize the MCP server (low-level Server API, whose decorator-based
# registration matches the handlers below)
mcp = Server(MCPConfig.SERVER_NAME)
@mcp.list_tools()
async def handle_list_tools() -> List[Tool]:
"""
List available tools for LLM interaction.
"""
return [
Tool(
name="chat_completion",
description="Send a chat completion request to the local LLM service",
            inputSchema={
                "type": "object",
                "properties": {
"messages": {
"type": "array",
"description": "Array of message objects",
"items": {
"type": "object",
"properties": {
"role": {
"type": "string",
"enum": ["system", "user", "assistant"],
"description": "Role of the message sender"
},
"content": {
"type": "string",
"description": "Content of the message"
}
},
"required": ["role", "content"]
}
},
"model": {
"type": "string",
"description": "Model name to use",
"default": "default"
},
"temperature": {
"type": "number",
"description": "Temperature for generation (0-2)",
"minimum": 0,
"maximum": 2,
"default": 0.7
},
"max_tokens": {
"type": "integer",
"description": "Maximum tokens to generate",
"minimum": 1
},
"stream": {
"type": "boolean",
"description": "Whether to stream the response",
"default": False
}
                },
                "required": ["messages"],
            },
        ),
Tool(
name="list_models",
description="List available models from the local LLM service",
            inputSchema={"type": "object", "properties": {}},
),
Tool(
name="health_check",
description="Check the health of the local LLM service",
            inputSchema={"type": "object", "properties": {}},
)
]
@mcp.call_tool()
async def handle_call_tool(name: str, arguments: Optional[Dict[str, Any]]) -> List[TextContent]:
    """
    Handle tool calls for LLM operations. The low-level MCP server expects the
    handler to return a list of content items, not a CallToolResult.
    """
    try:
        if name == "chat_completion":
            return await _handle_chat_completion(arguments or {})
        elif name == "list_models":
            return await _handle_list_models()
        elif name == "health_check":
            return await _handle_health_check()
        else:
            raise ValueError(f"Unknown tool: {name}")
    except Exception as e:
        logger.error(f"Error in tool call {name}: {e}")
        return [TextContent(type="text", text=f"Error: {e}")]
async def _handle_chat_completion(arguments: Dict[str, Any]) -> List[TextContent]:
"""Handle chat completion requests."""
try:
# Validate request
request = LLMRequest(**arguments)
# Prepare payload for local LLM service
payload = {
"model": request.model,
"messages": [msg.dict() for msg in request.messages],
"temperature": request.temperature,
"stream": request.stream
}
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
# Make request to local LLM service
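        # Note: urljoin with a leading "/" path replaces any path component on the
        # base URL, so LOCAL_LLM_SERVICE_URL is assumed to be a bare origin such as
        # "http://localhost:8000".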
url = urljoin(MCPConfig.LOCAL_LLM_SERVICE_URL, "/v1/chat/completions")
async with httpx.AsyncClient(
timeout=MCPConfig.LLM_REQUEST_TIMEOUT,
headers=MCPConfig.get_headers()
) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
data = response.json()
# Extract response content
content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
model_used = data.get("model", "unknown")
usage = data.get("usage", {})
# Format response
result_text = f"**Model:** {model_used}\n\n**Response:**\n{content}"
if usage:
result_text += f"\n\n**Usage:**\n"
result_text += f"- Prompt tokens: {usage.get('prompt_tokens', 'N/A')}\n"
result_text += f"- Completion tokens: {usage.get('completion_tokens', 'N/A')}\n"
result_text += f"- Total tokens: {usage.get('total_tokens', 'N/A')}"
        return [TextContent(type="text", text=result_text)]
    except httpx.HTTPStatusError as e:
        error_msg = f"LLM service error (HTTP {e.response.status_code}): {e.response.text}"
        return [TextContent(type="text", text=error_msg)]
    except Exception as e:
        error_msg = f"Error calling LLM service: {e}"
        return [TextContent(type="text", text=error_msg)]
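
# The handler above assumes an OpenAI-compatible /v1/chat/completions response,
# roughly of this shape (only the fields actually read are listed):
#   {"model": "...",
#    "choices": [{"message": {"content": "..."}}],
#    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}}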
async def _handle_list_models() -> List[TextContent]:
"""List available models from the LLM service."""
try:
url = urljoin(MCPConfig.LOCAL_LLM_SERVICE_URL, "/v1/models")
async with httpx.AsyncClient(
timeout=30.0,
headers=MCPConfig.get_headers()
) as client:
response = await client.get(url)
response.raise_for_status()
data = response.json()
models = data.get("data", [])
if models:
model_list = "\n".join([f"- {model.get('id', 'unknown')}" for model in models])
result_text = f"**Available Models:**\n{model_list}"
else:
result_text = "No models available or models endpoint not supported."
        return [TextContent(type="text", text=result_text)]
    except Exception as e:
        return [TextContent(type="text", text=f"Error listing models: {e}")]
async def _handle_health_check() -> List[TextContent]:
"""Check health of the LLM service."""
try:
# Try multiple common health check endpoints
health_endpoints = ["/health", "/v1/health", "/status"]
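        # These paths are common conventions among OpenAI-compatible servers, not
        # a guarantee; adjust the list for your specific backend.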
async with httpx.AsyncClient(
timeout=MCPConfig.HEALTH_CHECK_TIMEOUT,
headers=MCPConfig.get_headers()
) as client:
for endpoint in health_endpoints:
try:
url = urljoin(MCPConfig.LOCAL_LLM_SERVICE_URL, endpoint)
response = await client.get(url)
if response.status_code == 200:
                        return [TextContent(
                            type="text",
                            text=f"✅ LLM service is healthy (checked {endpoint})",
                        )]
                except httpx.HTTPError:
                    continue
# If no health endpoint works, try a simple connection test
try:
                await client.get(MCPConfig.LOCAL_LLM_SERVICE_URL)
                return [TextContent(
                    type="text",
                    text="⚠️ LLM service is reachable but no standard health endpoint found",
                )]
            except httpx.HTTPError:
                return [TextContent(
                    type="text",
                    text=f"❌ LLM service is not reachable at {MCPConfig.LOCAL_LLM_SERVICE_URL}",
                )]
except Exception as e:
        return [TextContent(type="text", text=f"Error checking health: {e}")]
@mcp.list_prompts()
async def handle_list_prompts() -> List[Prompt]:
"""
List available prompts for common LLM tasks.
"""
return [
Prompt(
name="chat_assistant",
description="A helpful AI assistant prompt",
arguments=[
PromptArgument(
name="task",
description="The task or question for the assistant",
required=True
)
]
),
Prompt(
name="code_review",
description="Code review and analysis prompt",
arguments=[
PromptArgument(
name="code",
description="The code to review",
required=True
),
PromptArgument(
name="language",
description="Programming language",
required=False
)
]
),
Prompt(
name="summarize",
description="Summarize text content",
arguments=[
PromptArgument(
name="text",
description="Text to summarize",
required=True
),
PromptArgument(
name="length",
description="Desired summary length (short, medium, long)",
required=False
)
]
)
]
@mcp.get_prompt()
async def handle_get_prompt(name: str, arguments: Optional[Dict[str, str]]) -> GetPromptResult:
    """
    Get a specific prompt with arguments filled in.
    """
    arguments = arguments or {}
if name == "chat_assistant":
task = arguments.get("task", "")
content = f"You are a helpful AI assistant. Please help with the following task:\n\n{task}"
elif name == "code_review":
code = arguments.get("code", "")
language = arguments.get("language", "")
lang_info = f" (Language: {language})" if language else ""
content = f"Please review the following code{lang_info} and provide feedback on:\n"
content += "1. Code quality and best practices\n"
content += "2. Potential bugs or issues\n"
content += "3. Suggestions for improvement\n\n"
content += f"Code to review:\n```\n{code}\n```"
elif name == "summarize":
text = arguments.get("text", "")
length = arguments.get("length", "medium")
content = f"Please provide a {length} summary of the following text:\n\n{text}"
else:
raise ValueError(f"Unknown prompt: {name}")
return GetPromptResult(
description=f"Generated prompt: {name}",
messages=[
PromptMessage(
role="user",
content=TextContent(type="text", text=content)
)
]
)
@mcp.list_resources()
async def handle_list_resources() -> List[Resource]:
"""
List available resources (configuration, status, etc.).
"""
return [
Resource(
uri="config://server",
name="Server Configuration",
description="Current server configuration and settings",
mimeType="application/json"
),
Resource(
uri="status://llm-service",
name="LLM Service Status",
description="Status and health information of the connected LLM service",
mimeType="application/json"
)
]
@mcp.read_resource()
async def handle_read_resource(uri: str) -> str:
    """
    Read a specific resource. The low-level MCP server expects the handler to
    return the resource body (here, a JSON string).
    """
    uri = str(uri)  # the SDK may pass an AnyUrl object; normalize for comparison
    if uri == "config://server":
config_data = {
"server_name": MCPConfig.SERVER_NAME,
"server_version": MCPConfig.SERVER_VERSION,
"llm_service_url": MCPConfig.LOCAL_LLM_SERVICE_URL,
"capabilities": ["tools", "prompts", "resources"]
}
        return json.dumps(config_data, indent=2)
elif uri == "status://llm-service":
# Check LLM service status
try:
async with httpx.AsyncClient(
timeout=5.0,
headers=MCPConfig.get_headers()
) as client:
response = await client.get(f"{MCPConfig.LOCAL_LLM_SERVICE_URL}/health")
status = "healthy" if response.status_code == 200 else "unhealthy"
        except Exception:
            status = "unreachable"
status_data = {
"llm_service_url": MCPConfig.LOCAL_LLM_SERVICE_URL,
"status": status,
"last_checked": "now"
}
        return json.dumps(status_data, indent=2)
else:
raise ValueError(f"Unknown resource: {uri}")
async def main():
"""Main entry point for the MCP server."""
# Import here to avoid issues with event loop
from mcp.server.stdio import stdio_server
logger.info(f"Starting {MCPConfig.SERVER_NAME} v{MCPConfig.SERVER_VERSION}")
logger.info(f"LLM Service URL: {MCPConfig.LOCAL_LLM_SERVICE_URL}")
async with stdio_server() as (read_stream, write_stream):
await mcp.run(
read_stream,
write_stream,
            InitializationOptions(
                server_name=MCPConfig.SERVER_NAME,
                server_version=MCPConfig.SERVER_VERSION,
                capabilities=mcp.get_capabilities(
                    notification_options=NotificationOptions(),
                    experimental_capabilities={},
                ),
            ),
        )
if __name__ == "__main__":
asyncio.run(main())
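
# To exercise the server manually, register it with an MCP-capable client via a
# stdio command. For example (server name and paths are illustrative), in Claude
# Desktop's claude_desktop_config.json:
#   {"mcpServers": {"llm-router": {"command": "python", "args": ["mcp_server.py"]}}}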