"""
MCP Gateway Implementation
Gateway that wraps Lambda functions and REST APIs as MCP tools
"""
import asyncio
import json
import logging
from typing import Any, Dict, List, Optional

import boto3
import httpx

import protocol
# Module-wide logger named after this module. Its level is INFO so the
# info-level registration/invocation messages of this module are not filtered
# out at the logger itself (handlers and global config still apply).
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
class ToolDefinition:
    """Definition of a tool that can be called via MCP

    A tool is backed either by an AWS Lambda function (``lambda_arn``) or by a
    REST endpoint (``rest_endpoint`` / ``rest_method`` / ``headers``);
    ``tool_type`` records which backend applies. ``input_schema`` is the JSON
    schema describing the tool's arguments.
    """

    def __init__(
        self,
        name: str,
        tool_type: protocol.ToolType,
        description: str,
        input_schema: Dict[str, Any],
        lambda_arn: Optional[str] = None,
        rest_endpoint: Optional[str] = None,
        rest_method: str = "POST",
        headers: Optional[Dict[str, str]] = None
    ):
        """
        Args:
            name: Tool name
            tool_type: Backend kind (Lambda or REST)
            description: Human-readable tool description
            input_schema: JSON schema for tool inputs
            lambda_arn: ARN of the backing Lambda function (Lambda tools)
            rest_endpoint: URL of the backing endpoint (REST tools)
            rest_method: HTTP method for REST tools (default "POST")
            headers: HTTP headers for REST tools. The dict is copied, so later
                changes to the caller's dict do not affect this tool.
        """
        self.name = name
        self.tool_type = tool_type
        self.description = description
        self.input_schema = input_schema
        self.lambda_arn = lambda_arn
        self.rest_endpoint = rest_endpoint
        self.rest_method = rest_method
        # Copy instead of aliasing: a caller reusing or mutating one headers dict
        # for several tools must not silently change this tool's headers.
        self.headers = dict(headers) if headers else {}

    def __repr__(self) -> str:
        # Headers are deliberately left out: they may hold credentials
        # (e.g. auth tokens) that should not end up in logs.
        return (
            f"{type(self).__name__}(name={self.name!r}, tool_type={self.tool_type!r}, "
            f"lambda_arn={self.lambda_arn!r}, rest_endpoint={self.rest_endpoint!r}, "
            f"rest_method={self.rest_method!r})"
        )
class ToolRegistry:
    """In-memory store of MCP tool definitions, keyed by tool name."""

    def __init__(self):
        # tool name -> definition; registering the same name again replaces it
        self.tools: Dict[str, ToolDefinition] = {}

    def register_tool(self, tool: ToolDefinition):
        """Store ``tool`` under ``tool.name``, warning if it replaces an existing entry."""
        key = tool.name
        if key in self.tools:
            logger.warning(f"Tool {key} already registered, overwriting")
        self.tools[key] = tool
        logger.info(f"Registered tool: {key} ({tool.tool_type.value})")

    def get_tool(self, name: str) -> Optional[ToolDefinition]:
        """Return the tool registered as ``name``, or None if there is none."""
        return self.tools.get(name, None)

    def list_tools(self) -> List[ToolDefinition]:
        """Return a new list of every registered tool, in registration order."""
        return [definition for definition in self.tools.values()]

    def remove_tool(self, name: str):
        """Drop the tool called ``name``; unknown names are silently ignored."""
        if name not in self.tools:
            return
        del self.tools[name]
        logger.info(f"Removed tool: {name}")
class MCPGateway:
    """Main MCP Gateway class that handles protocol translation

    Incoming MCP requests are parsed with a ``protocol.MCPProtocolHandler``.
    ``tools/call`` requests are dispatched to registered AWS Lambda functions
    or REST endpoints, and the results are wrapped back into MCP responses.

    The gateway owns an ``httpx.AsyncClient``. Call :meth:`aclose` (or use
    ``async with MCPGateway(...) as gateway:``) when finished so the client's
    connections are released.
    """

    def __init__(self, region: str = "us-east-1", http_timeout: float = 30.0):
        """
        Create the gateway and its backend clients.

        Args:
            region: AWS region for the boto3 Lambda client
            http_timeout: Timeout in seconds for REST tool calls (default 30.0)
        """
        self.protocol_handler = protocol.MCPProtocolHandler()
        self.tool_registry = ToolRegistry()
        self.lambda_client = boto3.client('lambda', region_name=region)
        self.http_client = httpx.AsyncClient(timeout=http_timeout)
        self.region = region

    async def aclose(self) -> None:
        """Close the underlying HTTP client and release its connections."""
        await self.http_client.aclose()

    async def __aenter__(self) -> "MCPGateway":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    def register_lambda_tool(
        self,
        name: str,
        lambda_arn: str,
        description: str,
        input_schema: Dict[str, Any]
    ):
        """
        Register a Lambda function as an MCP tool
        Args:
            name: Tool name
            lambda_arn: ARN of the Lambda function
            description: Tool description
            input_schema: JSON schema for tool inputs
        """
        tool = ToolDefinition(
            name=name,
            tool_type=protocol.ToolType.LAMBDA,
            description=description,
            input_schema=input_schema,
            lambda_arn=lambda_arn
        )
        self.tool_registry.register_tool(tool)

    def register_rest_tool(
        self,
        name: str,
        endpoint: str,
        description: str,
        input_schema: Dict[str, Any],
        method: str = "POST",
        headers: Optional[Dict[str, str]] = None
    ):
        """
        Register a REST API endpoint as an MCP tool
        Args:
            name: Tool name
            endpoint: REST API endpoint URL
            description: Tool description
            input_schema: JSON schema for tool inputs
            method: HTTP method (GET, POST, etc.)
            headers: Optional HTTP headers
        """
        tool = ToolDefinition(
            name=name,
            tool_type=protocol.ToolType.REST,
            description=description,
            input_schema=input_schema,
            rest_endpoint=endpoint,
            rest_method=method,
            headers=headers
        )
        self.tool_registry.register_tool(tool)

    def _invoke_lambda_sync(self, lambda_arn: Optional[str], payload: str) -> tuple:
        """Blocking helper: invoke the Lambda synchronously and read its payload.

        Returns:
            ``(response, raw_payload)``: the boto3 ``invoke`` response dict and
            the bytes read from its ``Payload`` stream.
        """
        response = self.lambda_client.invoke(
            FunctionName=lambda_arn,
            InvocationType='RequestResponse',
            Payload=payload
        )
        # Reading the streaming body is blocking I/O as well, so do it here.
        return response, response['Payload'].read()

    async def invoke_lambda_tool(self, tool: ToolDefinition, arguments: Dict[str, Any]) -> Any:
        """
        Invoke a Lambda function tool

        The boto3 client is synchronous. The call therefore runs in the event
        loop's default executor instead of stalling every other in-flight request.

        Args:
            tool: Tool definition (``tool.lambda_arn`` is the function invoked)
            arguments: Tool arguments, sent as the JSON invocation payload
        Returns:
            The decoded Lambda result. If the payload is a dict with a ``body``
            key, that body is used instead. String results holding JSON are
            decoded; any other string is returned verbatim.
        Raises:
            RuntimeError: If the function reported an error (``errorMessage``
                in the payload, or boto3's ``FunctionError`` flag).
            Exception: Any other failure is logged and re-raised.
        """
        try:
            logger.info(f"Invoking Lambda: {tool.lambda_arn}")
            # Arguments may carry user data; keep them out of INFO-level logs.
            logger.debug(f"Lambda {tool.lambda_arn} args: {arguments}")
            loop = asyncio.get_running_loop()
            response, raw_payload = await loop.run_in_executor(
                None, self._invoke_lambda_sync, tool.lambda_arn, json.dumps(arguments)
            )
            # Decode an empty payload as None instead of failing in json.loads.
            response_payload = json.loads(raw_payload) if raw_payload else None

            # Handle Lambda errors. The payload can be any JSON value, so only
            # look for 'errorMessage' when it is an object.
            if isinstance(response_payload, dict) and 'errorMessage' in response_payload:
                raise RuntimeError(f"Lambda error: {response_payload['errorMessage']}")
            if response.get('FunctionError'):
                raise RuntimeError(
                    f"Lambda error ({response['FunctionError']}): {response_payload}"
                )

            # Extract result from Lambda response
            result = response_payload
            if isinstance(result, dict):
                result = result.get('body', result)
            if isinstance(result, str):
                try:
                    result = json.loads(result)
                except ValueError:
                    # Plain-text body: return the string unchanged.
                    pass
            return result
        except Exception as e:
            logger.error(f"Error invoking Lambda {tool.lambda_arn}: {e}")
            raise

    async def invoke_rest_tool(self, tool: ToolDefinition, arguments: Dict[str, Any]) -> Any:
        """
        Invoke a REST API tool

        GET sends the arguments as query parameters. Every other method sends
        them as a JSON body.

        Args:
            tool: Tool definition
            arguments: Tool arguments
        Returns:
            The decoded JSON response, or the raw response text when the body
            is empty or not valid JSON.
        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx status (logged and re-raised).
        """
        try:
            logger.info(f"Invoking REST API: {tool.rest_endpoint} with method: {tool.rest_method}")
            if tool.rest_method.upper() == "GET":
                response = await self.http_client.get(
                    tool.rest_endpoint,
                    params=arguments,
                    headers=tool.headers
                )
            else:
                response = await self.http_client.request(
                    tool.rest_method,
                    tool.rest_endpoint,
                    json=arguments,
                    headers=tool.headers
                )
            response.raise_for_status()
            try:
                return response.json()
            except ValueError:
                # A successful call with an empty or non-JSON body (e.g. 204)
                # is still a success; return the text instead of failing.
                return response.text
        except Exception as e:
            logger.error(f"Error invoking REST API {tool.rest_endpoint}: {e}")
            raise

    async def _invoke_tool(self, tool: ToolDefinition, arguments: Any) -> Any:
        """Dispatch to the backend matching ``tool.tool_type``."""
        if tool.tool_type == protocol.ToolType.LAMBDA:
            return await self.invoke_lambda_tool(tool, arguments)
        if tool.tool_type == protocol.ToolType.REST:
            return await self.invoke_rest_tool(tool, arguments)
        raise ValueError(f"Unknown tool type: {tool.tool_type}")

    async def _handle_tools_call(
        self, params: Any, request_id: Any, session_id: Optional[str]
    ) -> Dict[str, Any]:
        """Resolve the requested tool, run it and build the MCP response."""
        tool_name, arguments = self.protocol_handler.extract_tool_arguments(params)
        tool = self.tool_registry.get_tool(tool_name)
        if tool is None:
            return self.protocol_handler.create_error_response(
                code=-32601,
                message=f"Tool '{tool_name}' not found",
                request_id=request_id
            )
        try:
            result = await self._invoke_tool(tool, arguments)
            # protocol_handler is shared, and a concurrent handle_request may
            # have replaced session_id while we awaited the backend. Restore
            # ours in case the handler uses it when building responses.
            self.protocol_handler.session_id = session_id
            return self.protocol_handler.create_tool_call_response(result, request_id)
        except Exception as e:
            logger.error(f"Error executing tool {tool_name}: {e}")
            self.protocol_handler.session_id = session_id
            return self.protocol_handler.create_error_response(
                code=-32603,
                message=f"Internal error executing tool: {str(e)}",
                request_id=request_id
            )

    async def handle_request(self, request_body: Dict[str, Any], session_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Handle incoming MCP request

        Supported methods are ``initialize``, ``ping``, ``tools/list`` and
        ``tools/call``. Any other method gets a -32601 "Method not found" error.
        Unexpected failures become a -32603 error response instead of being raised.

        Args:
            request_body: MCP request body
            session_id: Optional session ID
        Returns:
            MCP response
        """
        # Initialised before the try so the error path can always use it.
        request_id = None
        self.protocol_handler.session_id = session_id
        try:
            parsed_request = self.protocol_handler.parse_mcp_request(request_body)
            # Read the id first so later failures are still tied to the request.
            request_id = parsed_request["id"]
            method = parsed_request["method"]
            params = parsed_request["params"]

            if method == "tools/list":
                tools = [
                    self.protocol_handler.format_tool_definition(
                        name=tool.name,
                        description=tool.description,
                        input_schema=tool.input_schema
                    )
                    for tool in self.tool_registry.list_tools()
                ]
                return self.protocol_handler.create_tool_list_response(tools, request_id)

            if method == "tools/call":
                return await self._handle_tools_call(params, request_id, session_id)

            if method == "initialize":
                return self.protocol_handler.create_mcp_response(
                    request_id=request_id,
                    result={
                        "protocolVersion": "2024-11-05",
                        "capabilities": {
                            "tools": {}
                        },
                        "serverInfo": {
                            "name": "mcp-gateway",
                            "version": "1.0.0"
                        }
                    }
                )

            if method == "ping":
                # MCP liveness check: answer with an empty result.
                return self.protocol_handler.create_mcp_response(
                    request_id=request_id,
                    result={}
                )

            # Unknown method
            return self.protocol_handler.create_error_response(
                code=-32601,
                message=f"Method not found: {method}",
                request_id=request_id
            )
        except Exception as e:
            logger.exception(f"Error handling request: {e}")
            # An await during tools/call may have let another request replace
            # the shared session_id; restore ours for the error reply.
            self.protocol_handler.session_id = session_id
            return self.protocol_handler.create_error_response(
                code=-32603,
                message=f"Internal error: {str(e)}",
                request_id=request_id
            )