"""HTTP transport server with REST API for MCP tool search.
This module provides:
1. Streamable HTTP transport for MCP clients
2. REST API for tool catalog management and tool search (BM25, regex, semantic)
Usage:
uvicorn src.streamable_http_server:app --host=0.0.0.0 --port=8000
"""
import logging
from contextlib import asynccontextmanager
from typing import Any

from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, ValidationError
from mcp.server.streamable_http import StreamableHTTPServerTransport

from .tools import mcp_server, bm25_search, regex_search, embeddings_search
from .catalog import catalog, ToolDefinition, InputSchema
from .api_key_middleware import APIKeyMiddleware
from .config import config
# Configure the root logger once, at import time: INFO level with a
# timestamp / logger-name / level prefix on every line. The module logger
# created below propagates to it.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Pydantic models for REST API
class ToolInput(BaseModel):
    """Request body for registering (or re-registering) a single tool."""

    # min_length=1: a tool with an empty name could never be fetched or
    # deleted through the /tools/{name} routes, so reject it up front.
    name: str = Field(..., min_length=1, description="Unique name of the tool")
    description: str = Field(..., description="Description of what the tool does")
    # Free-form at this layer; it is converted into an InputSchema model
    # (and validated there) when the tool is registered.
    input_schema: dict[str, Any] = Field(..., description="JSON Schema for tool inputs")
    tags: list[str] = Field(default_factory=list, description="Tags for categorization")
class ToolResponse(BaseModel):
    """Serialized view of a catalog tool, used as the REST response model."""

    name: str
    description: str
    # The tool's InputSchema flattened to a plain dict via model_dump().
    input_schema: dict[str, Any]
    tags: list[str]
    defer_loading: bool = Field(default=True)
class BulkToolInput(BaseModel):
    """Request body for POST /tools/bulk: a batch of tools registered together."""

    tools: list[ToolInput] = Field(...)
class SearchQuery(BaseModel):
    """Request body shared by the BM25 and semantic search endpoints."""

    query: str = Field(description="Search query")
    # Result count; values outside 1..20 are rejected by validation.
    top_k: int = Field(5, ge=1, le=20)
class RegexQuery(BaseModel):
    """Request body for the regex search endpoint."""

    # User-supplied pattern, capped at 200 characters by validation.
    pattern: str = Field(max_length=200)
    # Result count; values outside 1..20 are rejected by validation.
    top_k: int = Field(5, ge=1, le=20)
# Lifespan context manager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before serving requests and log when shutting down.

    Startup logs the configured embedding model and, if the catalog holds no
    tools yet, seeds it with ``load_sample_tools`` from ``.stdio_server``.
    """
    logger.info("Starting Tool Search MCP Server (HTTP mode)")
    embedding_model = config.EMBEDDING_MODEL
    logger.info(f"Embedding model: {embedding_model}")

    catalog_is_empty = catalog.count() == 0
    if catalog_is_empty:
        # Imported lazily inside the function; presumably to avoid an import
        # cycle with .stdio_server — TODO confirm.
        from .stdio_server import load_sample_tools

        load_sample_tools()

    yield  # the application serves requests while suspended here

    logger.info("Shutting down server")
# Create FastAPI app
# Module-level ASGI application (the `app` target in the uvicorn command in the
# module docstring). It must exist before the route decorators below run.
app = FastAPI(
    title="MCP Tool Search Server",
    description="MCP server for dynamic tool discovery with REST API for tool management",
    version="0.1.0",
    lifespan=lifespan,  # startup logging + sample-tool seeding, shutdown log
)
# Add authentication middleware
# Which paths APIKeyMiddleware guards (or exempts) is decided in
# .api_key_middleware, not here.
app.add_middleware(APIKeyMiddleware)
# MCP transport
# Single module-level transport shared by every request to /mcp.
# NOTE(review): in the official `mcp` Python SDK the first positional parameter
# of StreamableHTTPServerTransport is `mcp_session_id`, not a mount path —
# confirm "/mcp" is what the installed SDK version expects here.
transport = StreamableHTTPServerTransport("/mcp")
# ============================================================================
# Health & Info Endpoints
# ============================================================================
@app.get("/")
async def root():
    """Describe the server: name, version, endpoint paths and catalog size."""
    info = {
        "name": "MCP Tool Search Server",
        "version": "0.1.0",
        "mcp_endpoint": "/mcp",
        "tools_endpoint": "/tools",
    }
    info["tool_count"] = catalog.count()
    return info
@app.get("/health")
async def health():
    """Report a static healthy status together with the current catalog size."""
    tool_count = catalog.count()
    return {"status": "healthy", "tool_count": tool_count}
# ============================================================================
# MCP Endpoint
# ============================================================================
@app.post("/mcp")
@app.get("/mcp")
async def mcp_handler(request: Request):
    """Forward GET and POST traffic on /mcp to the shared MCP transport.

    NOTE(review): the official `mcp` SDK's
    ``StreamableHTTPServerTransport.handle_request`` takes raw ASGI
    ``(scope, receive, send)`` rather than ``(request, server)`` — confirm the
    installed SDK (or a local wrapper) supports this call. The Streamable HTTP
    spec also lets clients send DELETE to end a session, which is not routed here.
    """
    response = await transport.handle_request(request, mcp_server)
    return response
# ============================================================================
# Tool Management REST API
# ============================================================================
@app.get("/tools", response_model=list[ToolResponse])
async def list_tools():
    """Return every tool in the catalog as a ToolResponse."""

    def to_response(entry):
        # Flatten the InputSchema model into a plain dict for the API payload.
        return ToolResponse(
            name=entry.name,
            description=entry.description,
            input_schema=entry.input_schema.model_dump(),
            tags=entry.tags,
            defer_loading=entry.defer_loading,
        )

    return [to_response(entry) for entry in catalog.list_tools()]
@app.get("/tools/{name}", response_model=ToolResponse)
async def get_tool(name: str):
    """Fetch one tool by name; respond 404 if the catalog lookup comes back empty."""
    found = catalog.get_tool(name)
    if not found:
        raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")

    payload = dict(
        name=found.name,
        description=found.description,
        input_schema=found.input_schema.model_dump(),
        tags=found.tags,
        defer_loading=found.defer_loading,
    )
    return ToolResponse(**payload)
@app.post("/tools", response_model=ToolResponse, status_code=201)
async def register_tool(tool_input: ToolInput):
    """Register a new tool or update an existing one.

    ``tool_input.input_schema`` is a free-form dict at the request layer. It is
    validated here by building the catalog's ``InputSchema`` and
    ``ToolDefinition`` models.

    Returns:
        The registered tool as a ``ToolResponse``.

    Raises:
        HTTPException: 422 when the tool definition fails model validation.
            Without this, the pydantic ``ValidationError`` would reach the
            global exception handler and surface as a 500.
    """
    try:
        tool = ToolDefinition(
            name=tool_input.name,
            description=tool_input.description,
            input_schema=InputSchema(**tool_input.input_schema),
            tags=tool_input.tags,
        )
    except ValidationError as e:
        # Report only JSON-safe fields. An error's "ctx" can hold exception
        # objects, which the HTTPException JSON response cannot serialize.
        detail = [
            {"loc": list(err["loc"]), "msg": err["msg"], "type": err["type"]}
            for err in e.errors()
        ]
        raise HTTPException(status_code=422, detail=detail) from e

    catalog.register_tool(tool)
    return ToolResponse(
        name=tool.name,
        description=tool.description,
        input_schema=tool.input_schema.model_dump(),
        tags=tool.tags,
        defer_loading=tool.defer_loading,
    )
@app.post("/tools/bulk", status_code=201)
async def register_tools_bulk(bulk_input: BulkToolInput):
    """Register multiple tools at once.

    Every entry is validated before anything is registered, so a batch that
    contains an invalid tool registers none of them.

    Returns:
        ``{"registered": <number of tools in this batch>,
        "total": <catalog size afterwards>}``.

    Raises:
        HTTPException: 422 listing the validation errors of every invalid
            entry. Each error's ``loc`` is prefixed with ``["tools", <index>]``
            so the client can tell which entry failed.
    """
    tools: list[ToolDefinition] = []
    errors: list[dict[str, Any]] = []
    for index, t in enumerate(bulk_input.tools):
        try:
            tools.append(
                ToolDefinition(
                    name=t.name,
                    description=t.description,
                    input_schema=InputSchema(**t.input_schema),
                    tags=t.tags,
                )
            )
        except ValidationError as e:
            # Keep only JSON-safe fields. An error's "ctx" can hold exception
            # objects, which the HTTPException JSON response cannot serialize.
            errors.extend(
                {
                    "loc": ["tools", index, *err["loc"]],
                    "msg": err["msg"],
                    "type": err["type"],
                }
                for err in e.errors()
            )

    if errors:
        raise HTTPException(status_code=422, detail=errors)

    catalog.register_tools(tools)
    return {"registered": len(tools), "total": catalog.count()}
@app.delete("/tools/{name}")
async def delete_tool(name: str):
    """Delete the named tool, or respond 404 if the catalog reports nothing removed."""
    removed = catalog.remove_tool(name)
    if removed:
        return {"deleted": name}
    raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
@app.delete("/tools")
async def clear_tools():
    """Empty the catalog and report how many tools it held beforehand."""
    # Snapshot the size first; it is no longer available once cleared.
    removed_count = catalog.count()
    catalog.clear()
    return {"deleted": removed_count}
# ============================================================================
# Search API (alternative to MCP tools)
# ============================================================================
@app.post("/search/bm25")
async def search_bm25(query: SearchQuery):
    """Keyword (BM25) search over the catalog; echoes the query with name/description hits."""
    hits = bm25_search.search(query.query, query.top_k)
    summaries = [{"name": hit.name, "description": hit.description} for hit in hits]
    return {"query": query.query, "results": summaries}
@app.post("/search/regex")
async def search_regex(query: RegexQuery):
    """Search tools using regex pattern matching.

    Returns:
        The pattern echoed back plus a name/description entry per match.

    Raises:
        HTTPException: 400 carrying the error message and ``error_code`` when
            the regex search backend raises ``RegexSearchError``.
    """
    from .search.regex import RegexSearchError

    try:
        # Materialize inside the try so an error raised lazily while the
        # results are iterated is still reported as a 400.
        tools = list(regex_search.search(query.pattern, query.top_k))
    except RegexSearchError as e:
        # Chain the cause so the original search error stays in tracebacks.
        raise HTTPException(
            status_code=400,
            detail={"error": str(e), "error_code": e.error_code},
        ) from e

    return {
        "pattern": query.pattern,
        "results": [{"name": t.name, "description": t.description} for t in tools],
    }
@app.post("/search/semantic")
async def search_semantic(query: SearchQuery):
    """Search tools using semantic similarity.

    Returns:
        The query echoed back plus one entry per ``(tool, score)`` pair from
        ``embeddings_search.search_with_scores``. Each entry holds the tool's
        name, description and score as a plain float.

    NOTE(review): the search runs synchronously inside this async handler, so
    a slow query embedding blocks the event loop. Consider offloading it
    (e.g. ``asyncio.to_thread``) once the backend's thread-safety is confirmed.
    """
    results = embeddings_search.search_with_scores(query.query, query.top_k)
    return {
        "query": query.query,
        "results": [
            # float() converts NumPy/torch scalar scores (e.g. numpy.float32,
            # which FastAPI's JSON encoder rejects) into plain Python floats.
            # It is a no-op for scores that are already floats.
            {"name": t.name, "description": t.description, "score": float(score)}
            for t, score in results
        ],
    }
# ============================================================================
# Error Handlers
# ============================================================================
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an otherwise-unhandled exception and return a generic 500 response.

    The response body carries no exception details. The log record carries
    the request method and path plus the full traceback.
    """
    # Lazy %-formatting, plus the exception passed explicitly as exc_info,
    # so the traceback does not depend on sys.exc_info() still being set.
    logger.error(
        "Unhandled exception on %s %s: %s",
        request.method,
        request.url.path,
        exc,
        exc_info=exc,
    )
    return JSONResponse(
        status_code=500,
        content={"error": "Internal server error"},
    )