sampling_router.py•6.69 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MCP 采样服务路由器
处理模型采样相关的 MCP 请求
"""
import json
import logging
import asyncio
from typing import Dict, Any, Optional, List, Union
from .base_router import MCPBaseRouter
logger = logging.getLogger("sampling_router")
class SamplingRouter(MCPBaseRouter):
"""采样服务路由器"""
def __init__(self):
"""初始化路由器"""
super().__init__()
self.models = {
"gpt-3.5-turbo": {
"id": "gpt-3.5-turbo",
"name": "GPT-3.5 Turbo",
"description": "OpenAI 的 GPT-3.5 Turbo 模型",
"capabilities": ["chat", "completion"]
},
"gpt-4": {
"id": "gpt-4",
"name": "GPT-4",
"description": "OpenAI 的 GPT-4 模型",
"capabilities": ["chat", "completion", "vision"]
},
"claude-2": {
"id": "claude-2",
"name": "Claude 2",
"description": "Anthropic 的 Claude 2 模型",
"capabilities": ["chat", "completion"]
}
}
self._register_methods()
def _register_methods(self):
"""注册方法处理器"""
self.register_method("sampling/list_models", self.handle_list_models)
self.register_method("sampling/get_model", self.handle_get_model)
self.register_method("sampling/generate", self.handle_generate)
self.register_method("sampling/stream", self.handle_stream)
self.register_method("sampling/cancel", self.handle_cancel)
def handle_list_models(self, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""处理列出模型请求"""
logger.info(f"处理列出模型请求,参数: {params}")
# 过滤模型(如果有能力要求)
models = list(self.models.values())
if params and "capabilities" in params:
required_capabilities = params["capabilities"]
models = [
model for model in models
if all(cap in model["capabilities"] for cap in required_capabilities)
]
return {
"models": models
}
def handle_get_model(self, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""处理获取模型信息请求"""
if not params or "id" not in params:
raise ValueError("缺少模型 ID 参数")
model_id = params["id"]
logger.info(f"处理获取模型信息请求,模型 ID: {model_id}")
if model_id not in self.models:
raise ValueError(f"模型 ID '{model_id}' 不存在")
return self.models[model_id]
async def handle_generate(self, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""处理生成请求"""
if not params:
raise ValueError("缺少请求参数")
model_id = params.get("model")
if not model_id:
raise ValueError("缺少模型 ID 参数")
if model_id not in self.models:
raise ValueError(f"模型 ID '{model_id}' 不存在")
prompt = params.get("prompt")
if not prompt:
raise ValueError("缺少提示词参数")
logger.info(f"处理生成请求,模型: {model_id}, 提示词长度: {len(str(prompt))}")
# 模拟生成过程
await asyncio.sleep(1)
return {
"id": "gen_" + model_id.replace("-", "_"),
"model": model_id,
"output": f"这是来自 {model_id} 模型的示例输出,基于您的提示词。",
"usage": {
"prompt_tokens": len(str(prompt)) // 4,
"completion_tokens": 20,
"total_tokens": len(str(prompt)) // 4 + 20
}
}
async def handle_stream(self, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""处理流式生成请求"""
if not params:
raise ValueError("缺少请求参数")
model_id = params.get("model")
if not model_id:
raise ValueError("缺少模型 ID 参数")
if model_id not in self.models:
raise ValueError(f"模型 ID '{model_id}' 不存在")
prompt = params.get("prompt")
if not prompt:
raise ValueError("缺少提示词参数")
logger.info(f"处理流式生成请求,模型: {model_id}, 提示词长度: {len(str(prompt))}")
# 在实际实现中,这里应该返回一个流式响应的标识符
# 客户端可以通过 SSE 连接接收流式输出
stream_id = f"stream_{model_id.replace('-', '_')}_{asyncio.get_event_loop().time()}"
# 启动后台任务处理流式生成
asyncio.create_task(self._process_stream(stream_id, model_id, prompt))
return {
"stream_id": stream_id,
"model": model_id,
"status": "started"
}
async def _process_stream(self, stream_id: str, model_id: str, prompt: Any):
"""处理流式生成的后台任务"""
# 在实际实现中,这里应该调用模型 API 并通过 SSE 发送结果
# 这里只是模拟流式输出
logger.info(f"开始流式生成,ID: {stream_id}")
# 模拟流式输出
chunks = [
"这是", "来自", f" {model_id} ", "模型的", "示例", "流式", "输出,",
"基于", "您的", "提示词。"
]
for chunk in chunks:
# 发送通知(在实际实现中,这应该通过 SSE 发送)
# 这里只是记录日志
logger.info(f"流 {stream_id} 输出: {chunk}")
await asyncio.sleep(0.3)
# 发送完成通知
logger.info(f"流 {stream_id} 完成")
def handle_cancel(self, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""处理取消生成请求"""
if not params or "id" not in params:
raise ValueError("缺少生成 ID 参数")
gen_id = params["id"]
logger.info(f"处理取消生成请求,生成 ID: {gen_id}")
# 在实际实现中,这里应该取消正在进行的生成任务
# 这里只是返回成功响应
return {
"id": gen_id,
"status": "cancelled"
}