#!/usr/bin/env python3
"""
股票市场数据和技术指标计算 MCP 服务
结合了:
1. Go 项目的指标计算方式(序列增量计算,O(n)时间复杂度)
2. Python main.py 的 LongPort API 数据获取
3. 清晰的数据结构和返回格式
使用方式:
python stock_indicators_mcp.py
"""
import json
import logging
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime
from zoneinfo import ZoneInfo
import pandas as pd
import numpy as np
import talib
from longport.openapi import QuoteContext, Config, Period as LongportPeriod, AdjustType
from dotenv import load_dotenv
# MCP SDK
from mcp.server.models import InitializationOptions
import mcp.types as types
from mcp.server import NotificationOptions, Server
import mcp.server.stdio
# Logging: timestamped records, INFO level and above, for the whole process.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Pull variables from a .env file into the environment so that
# Config.from_env() (used by StockDataProvider) can see them.
load_dotenv()
# ============================================================================
# 数据结构定义(参考 Go 项目)
# ============================================================================
@dataclass
class IntradaySeriesData:
    """Trailing window of per-bar series (the last N points, oldest first).

    Every list may contain ``None`` where the underlying indicator is still in
    its warm-up period (NaN is mapped to ``None`` so the lists serialize to
    JSON ``null``), hence the ``Optional[float]`` element type.
    """
    mid_prices: List[Optional[float]]  # close prices (despite the field name)
    volume_values: List[Optional[float]]  # traded volume per bar
    ema20_values: List[Optional[float]]  # EMA(20)
    ema60_values: List[Optional[float]]  # EMA(60)
    macd_dif: List[Optional[float]]  # MACD DIF (the MACD line)
    macd_dea: List[Optional[float]]  # MACD DEA (the signal line)
    macd_hist: List[Optional[float]]  # MACD histogram
    rsi7_values: List[Optional[float]]  # RSI(7)
    rsi14_values: List[Optional[float]]  # RSI(14)
    atr14_values: List[Optional[float]]  # ATR(14)
    bb_upper: List[Optional[float]]  # Bollinger upper band
    bb_middle: List[Optional[float]]  # Bollinger middle band
    bb_lower: List[Optional[float]]  # Bollinger lower band
@dataclass
class MarketData:
    """Indicator snapshot for one symbol and timeframe.

    Modeled on the Go project's ``market.Data`` struct. The ``current_*``
    fields hold the indicator value at the most recent bar;
    ``get_market_data`` fills them with 0.0 when an indicator series could
    not be computed (empty sequence).
    """
    symbol: str  # security code, e.g. 'AAPL.US'
    timeframe: str  # timeframe key, e.g. '1h' (see PERIOD_MAPPING)
    current_price: float  # close of the most recent bar
    price_change_pct: float  # % change of the latest close vs. the first fetched bar's close
    current_ema20: float
    current_ema60: float
    current_macd_hist: float
    current_rsi7: float
    current_rsi14: float
    current_atr14: float
    intraday_series: IntradaySeriesData  # trailing window of each per-bar series
    timestamp: str  # ISO-8601 time (Asia/Shanghai) at which the snapshot was built
# ============================================================================
# 技术指标计算(参考 Go 项目的序列增量计算方式)
# ============================================================================
class IndicatorCalculator:
    """Technical-indicator calculator built on TA-Lib.

    Each ``calculate_*`` method returns full-length NumPy sequences aligned
    bar-for-bar with its input (TA-Lib fills warm-up positions with NaN), or
    empty arrays when there are too few bars for the indicator. Inputs are
    coerced to float64 first, because TA-Lib's Python wrapper expects double
    arrays; this also lets callers pass plain lists or integer arrays.
    """

    @staticmethod
    def _as_float_array(values) -> np.ndarray:
        """Return *values* as a float64 NumPy array (no copy if it already is one)."""
        return np.asarray(values, dtype=np.float64)

    @staticmethod
    def calculate_ema_sequence(close_prices: np.ndarray, period: int) -> np.ndarray:
        """Compute the EMA series (modeled on the Go project's calculateEMASequence).

        Args:
            close_prices: Close prices in chronological order (oldest first).
            period: EMA look-back length.

        Returns:
            The EMA series, or an empty array when fewer than ``period``
            prices are given.
        """
        prices = IndicatorCalculator._as_float_array(close_prices)
        if len(prices) < period:
            return np.array([])
        return talib.EMA(prices, timeperiod=period)

    @staticmethod
    def calculate_macd_sequence(
        close_prices: np.ndarray,
        fastperiod: int = 12,
        slowperiod: int = 26,
        signalperiod: int = 9,
        hist_multiplier: float = 2.0,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Compute the MACD series (modeled on the Go project's calculateMACDSequence).

        Args:
            close_prices: Close prices in chronological order (oldest first).
            fastperiod: Fast EMA length (default 12).
            slowperiod: Slow EMA length (default 26).
            signalperiod: Signal-line EMA length (default 9).
            hist_multiplier: Factor applied to the histogram. The default 2.0
                matches the exchange/Go-project display convention.

        Returns:
            ``(DIF, DEA, HIST)`` sequences, or three empty arrays when fewer
            than ``slowperiod`` prices are given. Inputs only slightly longer
            than ``slowperiod`` can still come back all-NaN, because the signal
            line needs its own warm-up.
        """
        prices = IndicatorCalculator._as_float_array(close_prices)
        if len(prices) < slowperiod:
            return np.array([]), np.array([]), np.array([])
        dif, dea, hist = talib.MACD(
            prices,
            fastperiod=fastperiod,
            slowperiod=slowperiod,
            signalperiod=signalperiod,
        )
        return dif, dea, hist * hist_multiplier

    @staticmethod
    def calculate_rsi_sequence(close_prices: np.ndarray, period: int) -> np.ndarray:
        """Compute the RSI series (modeled on the Go project's calculateRSISequence).

        Returns an empty array unless more than ``period`` prices are given,
        since RSI needs ``period`` price changes.
        """
        prices = IndicatorCalculator._as_float_array(close_prices)
        if len(prices) <= period:
            return np.array([])
        return talib.RSI(prices, timeperiod=period)

    @staticmethod
    def calculate_atr_sequence(high: np.ndarray, low: np.ndarray,
                               close: np.ndarray, period: int) -> np.ndarray:
        """Compute the Average True Range series.

        ``high``, ``low`` and ``close`` must be equally long and bar-aligned.
        Returns an empty array unless more than ``period`` bars are given.
        """
        close_arr = IndicatorCalculator._as_float_array(close)
        if len(close_arr) <= period:
            return np.array([])
        high_arr = IndicatorCalculator._as_float_array(high)
        low_arr = IndicatorCalculator._as_float_array(low)
        return talib.ATR(high_arr, low_arr, close_arr, timeperiod=period)

    @staticmethod
    def calculate_bollinger_bands(close_prices: np.ndarray,
                                  period: int = 20, nbdev: float = 2.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Compute Bollinger Bands using a simple moving average (``matype=0``).

        Returns ``(upper, middle, lower)``, or three empty arrays when fewer
        than ``period`` prices are given.
        """
        prices = IndicatorCalculator._as_float_array(close_prices)
        if len(prices) < period:
            return np.array([]), np.array([]), np.array([])
        upper, middle, lower = talib.BBANDS(
            prices,
            timeperiod=period,
            nbdevup=nbdev,
            nbdevdn=nbdev,
            matype=0,
        )
        return upper, middle, lower

    @staticmethod
    def safe_get_last_n(sequence: np.ndarray, n: int = 10) -> List[Optional[float]]:
        """Return the last ``n`` values as Python floats, with NaN mapped to ``None``.

        Modeled on the Go project's safeGetLastN. ``None`` keeps the result
        JSON-serializable. A sequence shorter than ``n`` is returned whole; a
        non-positive ``n`` yields ``[]``. Slicing with ``[-0:]`` would otherwise
        return everything.
        """
        if n <= 0 or len(sequence) == 0:
            return []
        # pd.isna covers NaN/None for every scalar type; unlike np.isnan it
        # never raises on non-float values.
        return [None if pd.isna(val) else float(val) for val in sequence[-n:]]
# ============================================================================
# 数据获取(参考 main.py 的 LongPort API)
# ============================================================================
# Maps a user-facing timeframe string to a LongPort candlestick Period
# (modeled on main.py). Keys are case-sensitive: '1m' is one minute, '1M' one month.
# Hour timeframes are expressed through the minute-based Period members.
PERIOD_MAPPING = {
    '1m': LongportPeriod.Min_1,
    '2m': LongportPeriod.Min_2,
    '3m': LongportPeriod.Min_3,
    '5m': LongportPeriod.Min_5,
    '10m': LongportPeriod.Min_10,
    '15m': LongportPeriod.Min_15,
    '20m': LongportPeriod.Min_20,
    '30m': LongportPeriod.Min_30,
    '45m': LongportPeriod.Min_45,
    '1h': LongportPeriod.Min_60,
    '2h': LongportPeriod.Min_120,
    '3h': LongportPeriod.Min_180,
    '4h': LongportPeriod.Min_240,
    '1d': LongportPeriod.Day,
    '1w': LongportPeriod.Week,
    '1M': LongportPeriod.Month, # monthly candles
    '1Q': LongportPeriod.Quarter, # quarterly candles
    '1Y': LongportPeriod.Year, # yearly candles
}
class StockDataProvider:
    """Market-data source backed by the LongPort quote API."""

    def __init__(self):
        """Open a LongPort quote context configured from the environment.

        If the connection cannot be created, the error is logged and
        ``self.ctx`` stays ``None``; ``get_market_data`` then returns ``None``.
        """
        self.ctx = None
        try:
            quote_config = Config.from_env()
            self.ctx = QuoteContext(quote_config)
        except Exception as exc:
            logger.error(f"❌ LongPort API 连接失败: {exc}")
        else:
            logger.info("✅ LongPort API 连接成功")

    def get_market_data(self, symbol: str, timeframe: str = '1h', limit: int = 200) -> Optional[MarketData]:
        """Fetch candlesticks for one symbol/timeframe and build an indicator snapshot.

        Args:
            symbol: Security code, e.g. ``'AAPL.US'``.
            timeframe: Any key of ``PERIOD_MAPPING``: minutes ``1m``-``45m``,
                hours ``1h``-``4h``, or ``1d``/``1w``/``1M``/``1Q``/``1Y``.
            limit: Number of candlesticks to request (default 200).

        Returns:
            A ``MarketData`` instance, or ``None`` in any of these cases:
            the API context is missing, the timeframe is unsupported, no bars
            or fewer than 20 bars come back, or an exception occurs while
            fetching or computing. Exceptions are logged with a traceback.
        """
        if not self.ctx:
            logger.error("LongPort API 未连接")
            return None

        if timeframe not in PERIOD_MAPPING:
            logger.error(f"不支持的时间周期: {timeframe}")
            return None

        try:
            candles = self.ctx.candlesticks(
                symbol, PERIOD_MAPPING[timeframe], limit, AdjustType.NoAdjust
            )
            if not candles:
                logger.warning(f"未获取到 {symbol} 的K线数据")
                return None

            frame = self._klines_to_dataframe(candles)
            # Refuse to compute indicators on fewer than 20 bars.
            if frame.empty or len(frame) < 20:
                logger.warning(f"{symbol} K线数据不足(需要至少20根)")
                return None

            calc = IndicatorCalculator()
            close = frame['close'].values
            high = frame['high'].values
            low = frame['low'].values
            volume = frame['volume'].values

            ema20 = calc.calculate_ema_sequence(close, 20)
            ema60 = calc.calculate_ema_sequence(close, 60)
            dif, dea, hist = calc.calculate_macd_sequence(close)
            rsi7 = calc.calculate_rsi_sequence(close, 7)
            rsi14 = calc.calculate_rsi_sequence(close, 14)
            atr14 = calc.calculate_atr_sequence(high, low, close, 14)
            upper, middle, lower = calc.calculate_bollinger_bands(close, 20, 2.0)

            def latest(seq) -> float:
                # Newest value of a series; 0.0 when the series could not be computed.
                return float(seq[-1]) if len(seq) > 0 else 0.0

            def tail(seq) -> List[Optional[float]]:
                # Last 10 points of a series, NaN mapped to None.
                return calc.safe_get_last_n(seq, 10)

            current_price = float(close[-1])
            base_price = close[0]
            # Percent change of the latest close relative to the first fetched bar.
            price_change_pct = (
                (current_price - base_price) / base_price * 100 if base_price > 0 else 0.0
            )

            series = IntradaySeriesData(
                mid_prices=tail(close),
                volume_values=tail(volume),
                ema20_values=tail(ema20),
                ema60_values=tail(ema60),
                macd_dif=tail(dif),
                macd_dea=tail(dea),
                macd_hist=tail(hist),
                rsi7_values=tail(rsi7),
                rsi14_values=tail(rsi14),
                atr14_values=tail(atr14),
                bb_upper=tail(upper),
                bb_middle=tail(middle),
                bb_lower=tail(lower),
            )

            snapshot = MarketData(
                symbol=symbol,
                timeframe=timeframe,
                current_price=current_price,
                price_change_pct=price_change_pct,
                current_ema20=latest(ema20),
                current_ema60=latest(ema60),
                current_macd_hist=latest(hist),
                current_rsi7=latest(rsi7),
                current_rsi14=latest(rsi14),
                current_atr14=latest(atr14),
                intraday_series=series,
                timestamp=datetime.now(ZoneInfo("Asia/Shanghai")).isoformat(),
            )
        except Exception as exc:
            logger.error(f"❌ 获取 {symbol} 市场数据失败: {exc}", exc_info=True)
            return None

        logger.info(f"✅ 成功获取 {symbol} ({timeframe}) 市场数据")
        return snapshot

    def _klines_to_dataframe(self, klines) -> pd.DataFrame:
        """Convert LongPort candlesticks into a DataFrame sorted by time.

        Columns: timestamp, open, high, low, close, volume, and ``date``. The
        ``date`` column is the timestamp localized as UTC and converted to
        Asia/Shanghai; it is used as the sort key.
        """
        rows = [
            {
                'timestamp': k.timestamp,
                'open': float(k.open),
                'high': float(k.high),
                'low': float(k.low),
                'close': float(k.close),
                'volume': k.volume,
            }
            for k in klines
        ]
        frame = pd.DataFrame(rows)
        local_tz = ZoneInfo("Asia/Shanghai")
        frame['date'] = (
            pd.to_datetime(frame['timestamp'], unit='s')
            .dt.tz_localize('UTC')
            .dt.tz_convert(local_tz)
        )
        return frame.sort_values('date').reset_index(drop=True)
# ============================================================================
# MCP 服务器
# ============================================================================
class StockIndicatorsMCPServer:
    """MCP server that exposes the stock-indicator tools over stdio."""

    # Default timeframes for get_multi_timeframe_analysis. The tool schema
    # advertises this value and the handler falls back to it, so the
    # advertised default and the applied default are always the same.
    DEFAULT_TIMEFRAMES = ("1M", "1w", "1d", "4h")

    def __init__(self):
        """Create the MCP server and the LongPort data provider, then register the tools."""
        self.server = Server("stock-indicators")
        self.data_provider = StockDataProvider()
        self._register_tools()
        logger.info("📊 股票指标 MCP 服务器已初始化")

    def _register_tools(self):
        """Register the list_tools and call_tool handlers on ``self.server``."""

        @self.server.list_tools()
        async def handle_list_tools() -> list[types.Tool]:
            """Describe the tools this server offers."""
            return [
                types.Tool(
                    name="get_stock_indicators",
                    description="获取指定股票和周期的技术指标数据(支持分钟/小时/日/周/月/季/年多周期)",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "symbol": {
                                "type": "string",
                                "description": "股票代码(例如:AAPL.US, NVDA.US, TSLA.US)",
                            },
                            "timeframe": {
                                "type": "string",
                                "description": "时间周期:分钟(1m/2m/3m/5m/10m/15m/20m/30m/45m)、小时(1h/2h/3h/4h)、日/周/月/季/年(1d/1w/1M/1Q/1Y),默认 1h",
                                "default": "1h",
                            },
                            "limit": {
                                "type": "integer",
                                "description": "获取K线数量(默认200,月/季/年线建议60-120)",
                                "default": 200,
                            },
                        },
                        "required": ["symbol"],
                    },
                ),
                types.Tool(
                    name="get_multi_timeframe_analysis",
                    description="获取指定股票的多时间框架分析(同时获取多个周期的数据)",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "symbol": {
                                "type": "string",
                                "description": "股票代码(例如:AAPL.US)",
                            },
                            "timeframes": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "时间周期列表(默认:['1M', '1w', '1d', '4h'],支持从1m到1Y的所有周期)",
                                "default": list(self.DEFAULT_TIMEFRAMES),
                            },
                        },
                        "required": ["symbol"],
                    },
                ),
            ]

        @self.server.call_tool()
        async def handle_call_tool(
            name: str, arguments: dict | None
        ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
            """Dispatch a tool call to its handler; unknown tool names raise ValueError."""
            if name == "get_stock_indicators":
                return await self._handle_get_stock_indicators(arguments or {})
            elif name == "get_multi_timeframe_analysis":
                return await self._handle_multi_timeframe_analysis(arguments or {})
            else:
                raise ValueError(f"未知工具: {name}")

    @staticmethod
    def _error_response(message: str) -> list[types.TextContent]:
        """Wrap *message* as a pretty-printed JSON ``{"error": ...}`` text payload."""
        return [types.TextContent(
            type="text",
            text=json.dumps({"error": message}, ensure_ascii=False, indent=2)
        )]

    async def _handle_get_stock_indicators(self, arguments: dict) -> list[types.TextContent]:
        """Handle ``get_stock_indicators`` for one symbol and one timeframe.

        Returns the Markdown report from ``_format_market_data``, or a JSON
        error payload when arguments are invalid or the data fetch fails.
        """
        symbol = arguments.get("symbol")
        timeframe = arguments.get("timeframe", "1h")
        limit = arguments.get("limit", 200)
        if not symbol:
            return self._error_response("缺少必需参数: symbol")
        # JSON clients may send the count as a string or float. Normalize it
        # here so a bad value produces a clear error instead of a failed fetch.
        try:
            limit = int(limit)
        except (TypeError, ValueError, OverflowError):
            return self._error_response(f"参数 limit 必须为正整数: {limit!r}")
        if limit <= 0:
            return self._error_response(f"参数 limit 必须为正整数: {limit!r}")
        market_data = self.data_provider.get_market_data(symbol, timeframe, limit)
        if not market_data:
            return self._error_response(f"获取 {symbol} ({timeframe}) 数据失败")
        return [types.TextContent(
            type="text",
            text=self._format_market_data(market_data)
        )]

    async def _handle_multi_timeframe_analysis(self, arguments: dict) -> list[types.TextContent]:
        """Handle ``get_multi_timeframe_analysis``: a short summary per timeframe.

        Each timeframe is fetched with 200 bars. A timeframe that fails is
        reported inline and does not abort the others.
        """
        symbol = arguments.get("symbol")
        timeframes = arguments.get("timeframes")
        if not symbol:
            return self._error_response("缺少必需参数: symbol")
        if timeframes is None:
            timeframes = list(self.DEFAULT_TIMEFRAMES)
        elif isinstance(timeframes, str):
            # A single timeframe sent as a bare string would otherwise be
            # iterated character by character.
            timeframes = [timeframes]

        # Fetch one timeframe at a time: get_market_data is a synchronous call.
        results = {}
        for tf in timeframes:
            results[tf] = self.data_provider.get_market_data(symbol, tf, 200)

        parts = [f"# {symbol} 多时间框架分析\n\n"]
        for tf, data in results.items():
            if data is None:
                parts.append(f"## {tf} 周期:获取 {tf} 数据失败\n\n")
                continue
            parts.append(
                f"## {tf} 周期\n"
                f"当前价格: {data.current_price:.2f}\n"
                f"EMA20: {data.current_ema20:.2f}\n"
                f"MACD柱: {data.current_macd_hist:.3f}\n"
                f"RSI(7): {data.current_rsi7:.2f}\n"
                f"ATR(14): {data.current_atr14:.2f}\n"
                f"价格变化: {data.price_change_pct:.2f}%\n\n"
            )
        return [types.TextContent(
            type="text",
            text="".join(parts)
        )]

    def _format_market_data(self, data: MarketData) -> str:
        """Render a MarketData snapshot as Markdown (modeled on the Go project's Format method)."""
        output = f"# {data.symbol} ({data.timeframe}) 市场数据\n\n"
        output += f"**当前价格**: {data.current_price:.2f}\n"
        output += f"**EMA20**: {data.current_ema20:.2f}\n"
        output += f"**EMA60**: {data.current_ema60:.2f}\n"
        output += f"**MACD柱**: {data.current_macd_hist:.3f}\n"
        output += f"**RSI(7)**: {data.current_rsi7:.2f}\n"
        output += f"**RSI(14)**: {data.current_rsi14:.2f}\n"
        output += f"**ATR(14)**: {data.current_atr14:.2f}\n"
        output += f"**价格变化**: {data.price_change_pct:.2f}%\n\n"
        output += "## 日内系列数据(最近10个数据点,从旧到新)\n\n"
        series = data.intraday_series

        def format_list(values: List[Optional[float]], precision: int = 2) -> str:
            """Format a numeric list as ``[a, b, ...]``, rendering None as ``NaN``."""
            formatted = [
                "NaN" if v is None else f"{v:.{precision}f}"
                for v in values
            ]
            return "[" + ", ".join(formatted) + "]"

        output += f"**价格**: {format_list(series.mid_prices, 2)}\n\n"
        output += f"**成交量**: {format_list(series.volume_values, 0)}\n\n"
        output += f"**EMA20**: {format_list(series.ema20_values, 2)}\n\n"
        output += f"**EMA60**: {format_list(series.ema60_values, 2)}\n\n"
        output += f"**MACD DIF**: {format_list(series.macd_dif, 3)}\n\n"
        output += f"**MACD DEA**: {format_list(series.macd_dea, 3)}\n\n"
        output += f"**MACD HIST**: {format_list(series.macd_hist, 3)}\n\n"
        output += f"**RSI(7)**: {format_list(series.rsi7_values, 2)}\n\n"
        output += f"**RSI(14)**: {format_list(series.rsi14_values, 2)}\n\n"
        output += f"**ATR(14)**: {format_list(series.atr14_values, 2)}\n\n"
        output += f"**布林带上轨**: {format_list(series.bb_upper, 2)}\n\n"
        output += f"**布林带中轨**: {format_list(series.bb_middle, 2)}\n\n"
        output += f"**布林带下轨**: {format_list(series.bb_lower, 2)}\n\n"
        output += f"*数据时间: {data.timestamp}*\n"
        return output

    async def run(self):
        """Serve MCP requests over stdin/stdout until the stream closes."""
        logger.info("🚀 启动股票指标 MCP 服务器...")
        async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
            await self.server.run(
                read_stream,
                write_stream,
                InitializationOptions(
                    server_name="stock-indicators",
                    server_version="1.0.0",
                    capabilities=self.server.get_capabilities(
                        notification_options=NotificationOptions(),
                        experimental_capabilities={},
                    ),
                ),
            )
# ============================================================================
# 主程序
# ============================================================================
async def main():
    """Program entry point: log a startup banner, then build and run the MCP server."""
    rule = "=" * 60
    for banner_line in (rule, "📊 股票市场数据和技术指标计算 MCP 服务", rule):
        logger.info(banner_line)
    await StockIndicatorsMCPServer().run()
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())