"""
统一的股票代码处理工具
整合股票分类、标准化、转换等功能
"""
import re
from typing import Dict, Optional, List, Tuple
from .stock_market_classifier import get_stock_classifier, MarketType, ExchangeType
class StockSymbolProcessor:
"""股票代码处理器 - 统一处理股票代码的分类、标准化和转换"""
def __init__(self):
self.classifier = get_stock_classifier()
def process_symbol(self, symbol: str) -> Dict:
"""
全面处理股票代码,返回所有相关信息
Args:
symbol: 原始股票代码
Returns:
Dict: 包含分类、标准化后的各种格式
"""
# 基础分类
classification = self.classifier.classify_stock(symbol)
# 生成各种标准化格式
formats = self._generate_all_formats(symbol, classification)
# 数据源策略
data_sources = self._get_data_source_strategy(classification)
# 合并结果
result = {
**classification,
"formats": formats,
"market_simple_name": self.get_market_simple_name(symbol, classification),
"data_sources": data_sources,
"original": symbol,
}
return result
def _generate_all_formats(self, symbol: str, classification: Dict) -> Dict:
"""生成所有需要的代码格式"""
return {
"tushare": self.get_tushare_format(symbol, classification),
"akshare": self.get_akshare_format(symbol, classification),
"yfinance": self.get_yfinance_format(symbol, classification),
"news_api": self.get_news_api_format(symbol, classification),
"cache_key": self.get_cache_key(symbol, classification),
"display": self.get_display_format(symbol, classification),
}
def get_tushare_format(self, symbol: str, classification: Dict = None) -> str:
"""获取Tushare API格式的代码"""
if classification is None:
classification = self.classifier.classify_stock(symbol)
if classification["is_china"]:
# A股:确保有交易所后缀
clean_code = self._extract_base_code(symbol)
if "." in symbol and symbol.count(".") == 1:
return symbol
if clean_code.isdigit() and len(clean_code) == 6:
if clean_code.startswith(("60", "68")):
return f"{clean_code}.SH"
elif clean_code.startswith(("00", "30")):
return f"{clean_code}.SZ"
elif clean_code.startswith("8"):
return f"{clean_code}.BJ"
return symbol
elif classification["is_hk"]:
# 港股:Tushare港股格式
clean_code = self._extract_base_code(symbol)
if clean_code.isdigit():
return f"{clean_code.zfill(5)}.HK"
return symbol
else:
# 美股:Tushare不支持,返回原始代码
return self._extract_base_code(symbol).upper()
def get_akshare_format(self, symbol: str, classification: Dict = None) -> str:
"""获取AKShare API格式的代码"""
if classification is None:
classification = self.classifier.classify_stock(symbol)
if classification["is_china"]:
# A股:纯数字代码
return self._extract_base_code(symbol)
elif classification["is_hk"]:
# 港股:5位数字代码
clean_code = self._extract_base_code(symbol)
if clean_code.isdigit():
return clean_code.zfill(5)
return clean_code
else:
# 美股:去除后缀的大写代码
return self._extract_base_code(symbol).upper()
def get_yfinance_format(self, symbol: str, classification: Dict = None) -> str:
"""获取YFinance API格式的代码"""
if classification is None:
classification = self.classifier.classify_stock(symbol)
if classification["is_china"]:
# A股:添加Yahoo Finance后缀
clean_code = self._extract_base_code(symbol)
if clean_code.startswith(("60", "68")):
return f"{clean_code}.SS" # 上交所
else:
return f"{clean_code}.SZ" # 深交所
elif classification["is_hk"]:
# 港股:添加.HK后缀
clean_code = self._extract_base_code(symbol)
if clean_code.isdigit():
# yfinance 对港股代码格式的最终规律:
# 1. 先将代码转为整数,去除所有前导零 (e.g., '00700' -> 700)
numeric_code = int(clean_code)
# 2. 根据原始代码的位数决定格式
# - 如果原始代码不足4位 (e.g., 5, 700), 则补零到4位
# - 如果原始代码是4位或5位 (e.g., 9988, 89888), 则保持原样
if numeric_code < 1000: # 适用于 1, 5, 700 等
return f"{numeric_code:04d}.HK"
else: # 适用于 9988, 89888 等
return f"{numeric_code}.HK"
return f"{clean_code.upper()}.HK"
else:
# 美股:纯代码,去除后缀
return self._extract_base_code(symbol).upper()
def get_news_api_format(self, symbol: str, classification: Dict = None) -> str:
"""获取新闻API格式的代码"""
if classification is None:
classification = self.classifier.classify_stock(symbol)
if classification["is_china"]:
# A股新闻:纯数字代码
return self._extract_base_code(symbol)
elif classification["is_hk"]:
# 港股新闻:标准5位.HK格式
clean_code = self._extract_base_code(symbol)
if clean_code.isdigit():
return f"{clean_code.zfill(5)}.HK"
elif clean_code.upper().endswith(".HK"):
return symbol.upper()
else:
return f"{clean_code}.HK"
else:
# 美股新闻:纯代码,去除所有后缀
clean_code = self._extract_base_code(symbol).upper()
# 移除常见美股后缀
us_suffixes = [".US", ".NASDAQ", ".NYSE", ".NMS"]
for suffix in us_suffixes:
if clean_code.endswith(suffix):
clean_code = clean_code.replace(suffix, "")
break
return clean_code
def get_cache_key(self, symbol: str, classification: Dict = None) -> str:
"""获取缓存键格式的代码"""
if classification is None:
classification = self.classifier.classify_stock(symbol)
clean_code = self._extract_base_code(symbol)
if classification["is_china"]:
# A股缓存:6位数字
return clean_code
elif classification["is_hk"]:
# 港股缓存:5位数字
if clean_code.isdigit():
return clean_code.zfill(5)
return clean_code
else:
# 美股缓存:大写字母代码
return clean_code.upper()
def get_display_format(self, symbol: str, classification: Dict = None) -> str:
"""获取显示格式的代码"""
if classification is None:
classification = self.classifier.classify_stock(symbol)
if classification["is_china"]:
# A股显示:代码 + 交易所
clean_code = self._extract_base_code(symbol)
if clean_code.startswith(("60", "68")):
return f"{clean_code}(SH)"
elif clean_code.startswith(("00", "30")):
return f"{clean_code}(SZ)"
elif clean_code.startswith("8"):
return f"{clean_code}(BJ)"
return clean_code
elif classification["is_hk"]:
# 港股显示:5位代码.HK
clean_code = self._extract_base_code(symbol)
if clean_code.isdigit():
return f"{clean_code.zfill(5)}.HK"
return symbol
else:
# 美股显示:纯代码
return self._extract_base_code(symbol).upper()
def _extract_base_code(self, symbol: str) -> str:
"""提取基础股票代码,去除所有后缀"""
if not symbol:
return ""
# 去除常见后缀
suffixes = [
".SH",
".SZ",
".BJ",
".SS",
".XSHE",
".XSHG", # A股后缀
".HK",
".hk", # 港股后缀
".US",
".NASDAQ",
".NYSE",
".NMS", # 美股后缀
]
clean_symbol = symbol.strip().upper()
for suffix in suffixes:
if clean_symbol.endswith(suffix.upper()):
clean_symbol = clean_symbol[: -len(suffix)]
break
return clean_symbol
def _get_data_source_strategy(self, classification: Dict) -> Dict:
"""根据市场类型获取数据源策略"""
if classification["is_china"]:
return {
"fundamentals": ["tushare", "akshare"],
"market_data": ["tushare", "akshare"],
"news": ["akshare", "eastmoney", "sina"],
"priority": "tushare",
}
elif classification["is_hk"]:
return {
"fundamentals": ["tushare", "akshare", "yfinance"],
"market_data": ["tushare", "akshare", "yfinance"],
"news": ["akshare", "yfinance", "rss"],
"priority": "tushare",
}
else: # US market
return {
"fundamentals": ["yfinance", "akshare"],
"market_data": ["yfinance", "akshare"],
"news": ["yfinance", "finnhub", "alpha_vantage", "newsapi"],
"priority": "yfinance",
}
def get_market_simple_name(self, symbol: str, classification: Dict = None) -> str:
"""获取简化的市场名称"""
if classification is None:
classification = self.classifier.classify_stock(symbol)
if classification["is_china"]:
return "china"
elif classification["is_hk"]:
return "hk"
elif classification["is_us"]:
return "us"
else:
return "unknown"
def batch_process_symbols(self, symbols: List[str]) -> Dict[str, Dict]:
"""批量处理股票代码"""
results = {}
for symbol in symbols:
try:
results[symbol] = self.process_symbol(symbol)
except Exception as e:
results[symbol] = {
"error": str(e),
"original": symbol,
"formats": {},
"data_sources": {},
}
return results
def validate_symbol_format(self, symbol: str, expected_market: str = None) -> Dict:
"""验证股票代码格式"""
result = {"is_valid": False, "market": None, "errors": [], "suggestions": []}
if not symbol or not symbol.strip():
result["errors"].append("股票代码不能为空")
return result
classification = self.classifier.classify_stock(symbol)
if classification["market"] == "未知":
result["errors"].append("无法识别的股票代码格式")
result["suggestions"].append("请检查股票代码是否正确")
else:
result["is_valid"] = True
result["market"] = classification["market"]
# 如果指定了期望市场,进行验证
if expected_market:
expected_map = {"china": "is_china", "hk": "is_hk", "us": "is_us"}
if expected_market in expected_map and not classification.get(
expected_map[expected_market], False
):
result["is_valid"] = False
result["errors"].append(f"股票代码不属于{expected_market}市场")
return result
# 全局处理器实例
_processor = None
def get_symbol_processor() -> StockSymbolProcessor:
"""获取股票代码处理器实例(单例模式)"""
global _processor
if _processor is None:
_processor = StockSymbolProcessor()
return _processor
# 便利函数
def process_symbol(symbol: str) -> Dict:
"""处理股票代码的便利函数"""
return get_symbol_processor().process_symbol(symbol)
def get_tushare_format(symbol: str) -> str:
"""获取Tushare格式代码的便利函数"""
return get_symbol_processor().get_tushare_format(symbol)
def get_akshare_format(symbol: str) -> str:
"""获取AKShare格式代码的便利函数"""
return get_symbol_processor().get_akshare_format(symbol)
def get_yfinance_format(symbol: str) -> str:
"""获取YFinance格式代码的便利函数"""
return get_symbol_processor().get_yfinance_format(symbol)
def get_news_api_format(symbol: str) -> str:
"""获取新闻API格式代码的便利函数"""
return get_symbol_processor().get_news_api_format(symbol)
def get_cache_key(symbol: str) -> str:
"""获取缓存键的便利函数"""
return get_symbol_processor().get_cache_key(symbol)
def get_market_simple_name(symbol: str) -> str:
"""获取市场简化名称的便利函数"""
return get_symbol_processor().get_market_simple_name(symbol)
def get_data_source_strategy(symbol: str) -> Dict:
"""获取数据源策略的便利函数"""
processor = get_symbol_processor()
classification = processor.classifier.classify_stock(symbol)
return processor._get_data_source_strategy(classification)