mcp_server.py20.1 kB
#!/usr/bin/env python3
"""
股票分析MCP服务器
完全兼容CherryStudio的MCP协议标准
"""
import json
import sys
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
import pandas as pd
import akshare as ak
# 配置日志
logging.basicConfig(level=logging.INFO, stream=sys.stderr)
logger = logging.getLogger(__name__)
class StockAnalyzer:
"""股票分析器类"""
def __init__(self):
self.cache = {}
self.cache_timeout = 300 # 5分钟缓存
def _get_cache_key(self, func_name: str, **kwargs) -> str:
"""生成缓存键"""
return f"{func_name}_{hash(str(sorted(kwargs.items())))}"
def _is_cache_valid(self, timestamp: datetime) -> bool:
"""检查缓存是否有效"""
return (datetime.now() - timestamp).seconds < self.cache_timeout
def _get_from_cache(self, key: str) -> Optional[Any]:
"""从缓存获取数据"""
if key in self.cache:
data, timestamp = self.cache[key]
if self._is_cache_valid(timestamp):
return data
else:
del self.cache[key]
return None
def _set_cache(self, key: str, data: Any):
"""设置缓存"""
self.cache[key] = (data, datetime.now())
def get_stock_realtime_data(self, symbol: str) -> Dict[str, Any]:
"""获取股票实时行情数据"""
try:
cache_key = self._get_cache_key("realtime", symbol=symbol)
cached_data = self._get_from_cache(cache_key)
if cached_data:
return cached_data
# 获取实时数据
df = ak.stock_zh_a_spot_em()
stock_data = df[df['代码'] == symbol]
if stock_data.empty:
return {"error": f"未找到股票代码 {symbol} 的数据"}
result = {
"股票代码": symbol,
"股票名称": stock_data.iloc[0]['名称'],
"最新价": float(stock_data.iloc[0]['最新价']),
"涨跌幅": float(stock_data.iloc[0]['涨跌幅']),
"涨跌额": float(stock_data.iloc[0]['涨跌额']),
"成交量": int(stock_data.iloc[0]['成交量']),
"成交额": float(stock_data.iloc[0]['成交额']),
"振幅": float(stock_data.iloc[0]['振幅']),
"最高": float(stock_data.iloc[0]['最高']),
"最低": float(stock_data.iloc[0]['最低']),
"今开": float(stock_data.iloc[0]['今开']),
"昨收": float(stock_data.iloc[0]['昨收']),
"更新时间": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
self._set_cache(cache_key, result)
return result
except Exception as e:
logger.error(f"获取实时数据失败: {e}")
return {"error": f"获取实时数据失败: {str(e)}"}
def get_stock_history_data(self, symbol: str, period: str = "daily", start_date: str = "", end_date: str = "") -> Dict[str, Any]:
"""获取股票历史数据"""
try:
cache_key = self._get_cache_key("history", symbol=symbol, period=period, start_date=start_date, end_date=end_date)
cached_data = self._get_from_cache(cache_key)
if cached_data:
return cached_data
# 设置默认日期范围
if not end_date:
end_date = datetime.now().strftime("%Y%m%d")
if not start_date:
start_date = (datetime.now() - timedelta(days=90)).strftime("%Y%m%d")
# 获取历史数据
if period == "daily":
df = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date)
elif period == "weekly":
df = ak.stock_zh_a_hist(symbol=symbol, period="weekly", start_date=start_date, end_date=end_date)
elif period == "monthly":
df = ak.stock_zh_a_hist(symbol=symbol, period="monthly", start_date=start_date, end_date=end_date)
else:
return {"error": "不支持的周期类型"}
if df.empty:
return {"error": f"未找到股票代码 {symbol} 的历史数据"}
# 转换数据格式
history_data = []
for _, row in df.tail(30).iterrows(): # 限制返回最近30条记录
history_data.append({
"日期": row['日期'].strftime("%Y-%m-%d") if pd.notna(row['日期']) else "",
"开盘": float(row['开盘']) if pd.notna(row['开盘']) else 0,
"收盘": float(row['收盘']) if pd.notna(row['收盘']) else 0,
"最高": float(row['最高']) if pd.notna(row['最高']) else 0,
"最低": float(row['最低']) if pd.notna(row['最低']) else 0,
"成交量": int(row['成交量']) if pd.notna(row['成交量']) else 0,
"涨跌幅": float(row['涨跌幅']) if pd.notna(row['涨跌幅']) else 0
})
result = {
"股票代码": symbol,
"数据周期": period,
"开始日期": start_date,
"结束日期": end_date,
"数据条数": len(history_data),
"历史数据": history_data
}
self._set_cache(cache_key, result)
return result
except Exception as e:
logger.error(f"获取历史数据失败: {e}")
return {"error": f"获取历史数据失败: {str(e)}"}
def calculate_technical_indicators(self, symbol: str, indicators: List[str] = None) -> Dict[str, Any]:
"""计算技术指标"""
try:
if indicators is None:
indicators = ['ma', 'rsi']
cache_key = self._get_cache_key("indicators", symbol=symbol, indicators=str(indicators))
cached_data = self._get_from_cache(cache_key)
if cached_data:
return cached_data
# 获取历史数据用于计算指标
end_date = datetime.now().strftime("%Y%m%d")
start_date = (datetime.now() - timedelta(days=365)).strftime("%Y%m%d")
df = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date)
if df.empty:
return {"error": f"未找到股票代码 {symbol} 的数据"}
result = {"股票代码": symbol, "技术指标": {}}
# 计算移动平均线
if 'ma' in indicators:
df['MA5'] = df['收盘'].rolling(window=5).mean()
df['MA10'] = df['收盘'].rolling(window=10).mean()
df['MA20'] = df['收盘'].rolling(window=20).mean()
latest_data = df.iloc[-1]
result["技术指标"]["移动平均线"] = {
"MA5": float(latest_data['MA5']) if pd.notna(latest_data['MA5']) else None,
"MA10": float(latest_data['MA10']) if pd.notna(latest_data['MA10']) else None,
"MA20": float(latest_data['MA20']) if pd.notna(latest_data['MA20']) else None,
"当前价格": float(latest_data['收盘'])
}
# 计算RSI
if 'rsi' in indicators:
delta = df['收盘'].diff()
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
rs = gain / loss
rsi = 100 - (100 / (1 + rs))
result["技术指标"]["RSI"] = {
"RSI14": float(rsi.iloc[-1]) if pd.notna(rsi.iloc[-1]) else None,
"信号": "超买" if rsi.iloc[-1] > 70 else "超卖" if rsi.iloc[-1] < 30 else "正常"
}
self._set_cache(cache_key, result)
return result
except Exception as e:
logger.error(f"计算技术指标失败: {e}")
return {"error": f"计算技术指标失败: {str(e)}"}
def get_market_sentiment(self) -> Dict[str, Any]:
"""获取市场情绪分析"""
try:
cache_key = self._get_cache_key("sentiment")
cached_data = self._get_from_cache(cache_key)
if cached_data:
return cached_data
result = {"分析时间": datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
# 市场整体情绪
try:
# 获取A股市场总体情况
df_market = ak.stock_zh_a_spot_em()
if not df_market.empty:
up_count = len(df_market[df_market['涨跌幅'] > 0])
down_count = len(df_market[df_market['涨跌幅'] < 0])
total_count = len(df_market)
result["市场情绪"] = {
"上涨股票数": up_count,
"下跌股票数": down_count,
"平盘股票数": total_count - up_count - down_count,
"上涨比例": round(up_count / total_count * 100, 2),
"下跌比例": round(down_count / total_count * 100, 2),
"市场情绪": "乐观" if up_count > down_count else "悲观" if down_count > up_count else "中性"
}
except Exception as e:
result["市场情绪"] = {"错误": f"无法获取市场整体数据: {str(e)}"}
self._set_cache(cache_key, result)
return result
except Exception as e:
logger.error(f"获取市场情绪失败: {e}")
return {"error": f"获取市场情绪失败: {str(e)}"}
def search_stock_info(self, keyword: str) -> Dict[str, Any]:
"""搜索股票信息"""
try:
cache_key = self._get_cache_key("search", keyword=keyword)
cached_data = self._get_from_cache(cache_key)
if cached_data:
return cached_data
# 获取所有A股列表
df = ak.stock_zh_a_spot_em()
# 搜索匹配的股票
matches = df[
df['名称'].str.contains(keyword, na=False) |
df['代码'].str.contains(keyword, na=False)
].head(10) # 限制返回10个结果
if matches.empty:
return {"error": f"未找到包含关键词 '{keyword}' 的股票"}
results = []
for _, row in matches.iterrows():
results.append({
"股票代码": row['代码'],
"股票名称": row['名称'],
"最新价": float(row['最新价']),
"涨跌幅": float(row['涨跌幅']),
"成交量": int(row['成交量'])
})
result = {
"搜索关键词": keyword,
"匹配数量": len(results),
"搜索结果": results
}
self._set_cache(cache_key, result)
return result
except Exception as e:
logger.error(f"搜索股票信息失败: {e}")
return {"error": f"搜索股票信息失败: {str(e)}"}
# 创建分析器实例
analyzer = StockAnalyzer()
def handle_mcp_request():
"""处理MCP请求"""
try:
# 读取标准输入
for line in sys.stdin:
try:
request = json.loads(line.strip())
response = process_mcp_request(request)
print(json.dumps(response, ensure_ascii=False))
sys.stdout.flush()
except json.JSONDecodeError:
error_response = {
"jsonrpc": "2.0",
"error": {"code": -32700, "message": "Parse error"},
"id": None
}
print(json.dumps(error_response))
sys.stdout.flush()
except Exception as e:
logger.error(f"处理请求时出错: {e}")
error_response = {
"jsonrpc": "2.0",
"error": {"code": -32603, "message": f"Internal error: {str(e)}"},
"id": request.get('id') if 'request' in locals() else None
}
print(json.dumps(error_response))
sys.stdout.flush()
except KeyboardInterrupt:
logger.info("MCP服务器已停止")
except Exception as e:
logger.error(f"MCP服务器错误: {e}")
def process_mcp_request(request):
"""处理单个MCP请求"""
method = request.get('method')
params = request.get('params', {})
request_id = request.get('id')
if method == 'initialize':
return {
"jsonrpc": "2.0",
"result": {
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"serverInfo": {
"name": "股票分析工具",
"version": "1.0.0"
}
},
"id": request_id
}
elif method == 'tools/list':
tools = [
{
"name": "get_stock_realtime_data",
"description": "获取股票实时行情数据",
"inputSchema": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "股票代码,如 '000001'"
}
},
"required": ["symbol"]
}
},
{
"name": "get_stock_history_data",
"description": "获取股票历史数据",
"inputSchema": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "股票代码"
},
"period": {
"type": "string",
"description": "数据周期 (daily/weekly/monthly)",
"default": "daily"
},
"start_date": {
"type": "string",
"description": "开始日期 (YYYYMMDD格式)",
"default": ""
},
"end_date": {
"type": "string",
"description": "结束日期 (YYYYMMDD格式)",
"default": ""
}
},
"required": ["symbol"]
}
},
{
"name": "calculate_technical_indicators",
"description": "计算技术指标",
"inputSchema": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "股票代码"
},
"indicators": {
"type": "array",
"description": "指标列表",
"items": {"type": "string"},
"default": ["ma", "rsi"]
}
},
"required": ["symbol"]
}
},
{
"name": "get_market_sentiment",
"description": "获取市场情绪分析",
"inputSchema": {
"type": "object",
"properties": {},
"required": []
}
},
{
"name": "search_stock_info",
"description": "搜索股票信息",
"inputSchema": {
"type": "object",
"properties": {
"keyword": {
"type": "string",
"description": "搜索关键词(股票名称或代码)"
}
},
"required": ["keyword"]
}
}
]
return {
"jsonrpc": "2.0",
"result": {"tools": tools},
"id": request_id
}
elif method == 'tools/call':
tool_name = params.get('name')
arguments = params.get('arguments', {})
try:
if tool_name == 'get_stock_realtime_data':
symbol = arguments.get('symbol')
if not symbol:
raise ValueError("缺少股票代码参数")
result = analyzer.get_stock_realtime_data(symbol)
elif tool_name == 'get_stock_history_data':
symbol = arguments.get('symbol')
period = arguments.get('period', 'daily')
start_date = arguments.get('start_date', '')
end_date = arguments.get('end_date', '')
if not symbol:
raise ValueError("缺少股票代码参数")
result = analyzer.get_stock_history_data(symbol, period, start_date, end_date)
elif tool_name == 'calculate_technical_indicators':
symbol = arguments.get('symbol')
indicators = arguments.get('indicators', ['ma', 'rsi'])
if not symbol:
raise ValueError("缺少股票代码参数")
result = analyzer.calculate_technical_indicators(symbol, indicators)
elif tool_name == 'get_market_sentiment':
result = analyzer.get_market_sentiment()
elif tool_name == 'search_stock_info':
keyword = arguments.get('keyword')
if not keyword:
raise ValueError("缺少搜索关键词参数")
result = analyzer.search_stock_info(keyword)
else:
raise ValueError(f"未知的工具: {tool_name}")
return {
"jsonrpc": "2.0",
"result": {
"content": [
{
"type": "text",
"text": json.dumps(result, ensure_ascii=False, indent=2)
}
]
},
"id": request_id
}
except Exception as e:
return {
"jsonrpc": "2.0",
"error": {
"code": -32603,
"message": f"工具执行失败: {str(e)}"
},
"id": request_id
}
else:
return {
"jsonrpc": "2.0",
"error": {
"code": -32601,
"message": f"未知方法: {method}"
},
"id": request_id
}
if __name__ == "__main__":
logger.info("启动股票分析MCP服务器...")
logger.info("等待MCP协议请求...")
handle_mcp_request()