"""Tushare API provider for Chinese A-share market data."""
from typing import Optional
import pandas as pd
from ..util import get_env, now_iso, ProviderError
from ..schemas import Fundamentals, FinancialStatements, FinancialStatementItem
class TushareProvider:
    """Tushare API provider for Chinese A-share market data.

    Public methods accept Yahoo Finance style symbols (``600519.SS``,
    ``000001.SZ``) and convert them to Tushare ``ts_code`` format internally.
    Failures are reported as ``ProviderError`` with ``provider="tushare"``.
    """

    # Substrings that identify a Tushare permission / subscription-level error
    # in an exception message (English "permission", Chinese "no permission"
    # and "insufficient permission").
    _PERMISSION_MARKERS = ('permission', '没有权限', '权限不足')

    def __init__(self, token: Optional[str] = None, timeout: int = 10):
        """
        Initialize Tushare provider with token validation.

        Args:
            token: Tushare Pro API token. If None, reads from TUSHARE_TOKEN env var.
            timeout: Request timeout in seconds. Stored on the instance but not
                used yet (kept for future compatibility).

        Raises:
            ProviderError: If the token is missing, the tushare package is not
                installed, or the Pro API client cannot be created.
        """
        self.token = token or get_env("TUSHARE_TOKEN")
        if not self.token:
            raise ProviderError(
                code="MISSING_API_KEY",
                message="Tushare token is required. Set TUSHARE_TOKEN environment variable.",
                provider="tushare"
            )
        self.timeout = timeout

        try:
            import tushare as ts
        except ImportError as e:
            raise ProviderError(
                code="PROVIDER_ERROR",
                message="Tushare package not installed. Run: pip install tushare",
                provider="tushare"
            ) from e

        try:
            ts.set_token(self.token)
            self.pro = ts.pro_api()
        except Exception as e:
            raise ProviderError(
                code="PROVIDER_ERROR",
                message=f"Failed to initialize Tushare API: {str(e)}",
                provider="tushare"
            ) from e

    # ------------------------------------------------------------------
    # Small value helpers shared by the fetch methods
    # ------------------------------------------------------------------

    @staticmethod
    def _to_float(value) -> Optional[float]:
        """Return ``value`` as a float, or None when it is None/NaN."""
        if pd.isna(value):
            return None
        return float(value)

    @classmethod
    def _pct_to_decimal(cls, value) -> Optional[float]:
        """Convert a percentage (e.g. 12.5) to a decimal (0.125); None/NaN -> None."""
        as_float = cls._to_float(value)
        return None if as_float is None else as_float / 100.0

    @classmethod
    def _first_present(cls, row, *keys) -> Optional[float]:
        """Return the first of ``keys`` whose value in ``row`` is not None/NaN.

        ``row`` only needs a ``.get`` method (a pandas Series or a dict).
        NaN is truthy in Python, so a plain ``a or b`` would never fall back
        from a NaN. This helper also keeps a legitimate ``0.0``.
        """
        for key in keys:
            value = cls._to_float(row.get(key))
            if value is not None:
                return value
        return None

    @classmethod
    def _raise_api_error(cls, exc: Exception, message_prefix: str, details: dict) -> None:
        """Translate an unexpected Tushare exception into a ProviderError.

        Always raises. Permission errors get a dedicated message that includes
        the original error text in ``details``. Anything else is reported as
        ``"<message_prefix>: <error>"``.
        """
        error_msg = str(exc).lower()
        if any(marker in error_msg for marker in cls._PERMISSION_MARKERS):
            raise ProviderError(
                code="PROVIDER_ERROR",
                message="Tushare API permission denied. Check your token subscription level.",
                details={**details, "error": str(exc)},
                provider="tushare"
            ) from exc
        raise ProviderError(
            code="PROVIDER_ERROR",
            message=f"{message_prefix}: {exc}",
            details=details,
            provider="tushare"
        ) from exc

    # ------------------------------------------------------------------
    # Market data
    # ------------------------------------------------------------------

    def get_quote(self, symbol: str) -> 'Quote':
        """Get the real-time quote for an A-share symbol.

        Uses ``ts.get_realtime_quotes``, which is called with the bare numeric
        code (e.g. ``600519``) rather than a full ``ts_code``.

        Args:
            symbol: Yahoo Finance format symbol (``.SS`` or ``.SZ`` suffix).

        Returns:
            Quote built from the first row of the realtime-quote response.

        Raises:
            ProviderError: ``INVALID_ARGUMENT`` for an unsupported symbol,
                ``NO_DATA`` when Tushare returns nothing, ``PROVIDER_ERROR``
                for any other failure.
        """
        try:
            ts_code = self.yahoo_to_ts_code(symbol)
            code_only = ts_code.split(".")[0]

            import tushare as ts
            df = ts.get_realtime_quotes(code_only)
            if df is None or df.empty:
                raise ProviderError(
                    code="NO_DATA",
                    message=f"No quote found for {symbol}",
                    details={"symbol": symbol, "ts_code": ts_code},
                    provider="tushare"
                )

            row = df.iloc[0]
            # Columns used: name, open, pre_close, price, high, low, volume,
            # date, time. Numeric fields are parsed with float().
            price = float(row['price'])
            prev_close = float(row['pre_close'])
            change = price - prev_close
            # Avoid dividing by a zero/missing previous close.
            change_pct = (change / prev_close) * 100 if prev_close > 0 else 0.0

            from ..schemas import Quote
            return Quote(
                symbol=symbol,
                name=row['name'],
                market="CN",
                exchange="SH" if ts_code.endswith(".SH") else "SZ",
                currency="CNY",
                price=price,
                open=float(row['open']),
                high=float(row['high']),
                low=float(row['low']),
                prev_close=prev_close,
                volume=float(row['volume']),
                change=change,
                change_pct=change_pct,
                timestamp=f"{row['date']}T{row['time']}",
                source="tushare",
                source_detail={"ts_code": ts_code}
            )
        except ProviderError:
            # Keep specific codes (INVALID_ARGUMENT, NO_DATA) instead of
            # re-wrapping them as a generic PROVIDER_ERROR.
            raise
        except Exception as e:
            raise ProviderError(
                code="PROVIDER_ERROR",
                message=f"Tushare get_quote failed: {e}",
                provider="tushare"
            ) from e

    def get_history(self, symbol: str, start_date: str, end_date: str, interval: str = "1d", adjust: str = "qfq") -> list['HistoryRecord']:
        """
        Get historical OHLCV bars using ``ts.pro_bar`` (supports price adjustment).

        Args:
            symbol: Yahoo Finance format symbol.
            start_date: Start date as ``YYYY-MM-DD``. Dashes are stripped for Tushare.
            end_date: End date as ``YYYY-MM-DD``. Dashes are stripped for Tushare.
            interval: ``1d``, ``1wk`` or ``1mo``. Any other value falls back to daily.
            adjust: Passed to ``pro_bar`` as ``adj`` (e.g. ``"qfq"``).

        Returns:
            List of HistoryRecord sorted by date ascending. Empty if Tushare
            returns no rows.

        Raises:
            ProviderError: ``INVALID_ARGUMENT`` for an unsupported symbol,
                ``PROVIDER_ERROR`` if the fetch or parsing fails.
        """
        ts_code = self.yahoo_to_ts_code(symbol)

        freq_map = {"1d": "D", "1wk": "W", "1mo": "M"}
        freq = freq_map.get(interval, "D")

        # pro_bar expects dates as YYYYMMDD.
        start = start_date.replace("-", "")
        end = end_date.replace("-", "")

        try:
            import tushare as ts
            ts.set_token(self.token)  # Ensure global token is set for pro_bar
            df = ts.pro_bar(
                ts_code=ts_code,
                adj=adjust,
                start_date=start,
                end_date=end,
                freq=freq,
                api=self.pro
            )
            if df is None or df.empty:
                return []

            from ..schemas import HistoryRecord
            records = [
                HistoryRecord(
                    date=self._format_date(str(row['trade_date'])),
                    open=float(row['open']),
                    high=float(row['high']),
                    low=float(row['low']),
                    close=float(row['close']),
                    # NOTE(review): Tushare documents daily `vol` in lots (100
                    # shares). Confirm whether callers expect shares here.
                    volume=int(float(row['vol']))
                )
                for _, row in df.iterrows()
            ]
            # Rows may arrive newest-first; return them in ascending date order.
            records.sort(key=lambda rec: rec.date)
            return records
        except Exception as e:
            raise ProviderError(
                code="PROVIDER_ERROR",
                message=f"Tushare get_history failed: {e}",
                provider="tushare"
            ) from e

    def yahoo_to_ts_code(self, symbol: str) -> str:
        """
        Convert Yahoo symbol format to Tushare ts_code format.

        Examples:
            600519.SS -> 600519.SH (Shanghai)
            000001.SZ -> 000001.SZ (Shenzhen - unchanged)
            002595.sz -> 002595.SZ (case-insensitive)

        Args:
            symbol: Yahoo Finance symbol format. Surrounding whitespace is ignored.

        Returns:
            Tushare ts_code format (upper-case).

        Raises:
            ProviderError: If symbol format is invalid for CN market
        """
        normalized = symbol.strip().upper()

        if normalized.endswith('.SS'):
            # Shanghai: Yahoo uses .SS, Tushare uses .SH
            return normalized[:-len('.SS')] + '.SH'
        if normalized.endswith('.SZ'):
            # Shenzhen suffix is the same in both conventions.
            return normalized

        raise ProviderError(
            code="INVALID_ARGUMENT",
            message=f"Invalid symbol format for Tushare: {symbol}. Expected .SS or .SZ suffix.",
            details={"symbol": symbol},
            provider="tushare"
        )

    # ------------------------------------------------------------------
    # Fundamentals and financial statements
    # ------------------------------------------------------------------

    def get_fundamentals(self, symbol: str, period: str = "ttm") -> "Fundamentals":
        """
        Get fundamental metrics from Tushare fina_indicator API.

        Args:
            symbol: Yahoo Finance format symbol (will be converted to ts_code)
            period: "ttm" or "annual" (MVP: both use latest period data; the
                value is currently not used)

        Returns:
            Fundamentals object with mapped fields

        Raises:
            ProviderError: ``INVALID_ARGUMENT`` for an unsupported symbol,
                ``NO_DATA`` if no indicator rows exist, ``PROVIDER_ERROR``
                for permission or other API failures.

        Note:
            - TTM data in Tushare is approximated using latest annual/quarterly data
            - Some fields may be null if not available in Tushare
            - Percentage fields (ROE, ROA, margins) are converted from % to decimal
        """
        try:
            ts_code = self.yahoo_to_ts_code(symbol)

            df = self.pro.fina_indicator(
                ts_code=ts_code,
                fields='ts_code,end_date,roe,roe_dt,roa,grossprofit_margin,netprofit_margin,'
                       'debt_to_assets,current_ratio,quick_ratio,eps'
            )
            if df is None or df.empty:
                raise ProviderError(
                    code="NO_DATA",
                    message=f"No fundamental data found for {symbol} ({ts_code})",
                    details={"symbol": symbol, "ts_code": ts_code},
                    provider="tushare"
                )

            # Sort explicitly rather than trusting the API's row order.
            # end_date is YYYYMMDD, so a string sort is chronological. The
            # stable sort keeps Tushare's order among rows sharing an end_date.
            latest = df.sort_values('end_date', ascending=False, kind='mergesort').iloc[0]

            return Fundamentals(
                symbol=symbol,  # Keep original Yahoo symbol
                currency="CNY",  # Chinese A-shares are in CNY
                # Valuation - not available in fina_indicator
                market_cap=None,
                pe=None,
                pb=None,
                ps=None,
                dividend_yield=None,
                # Profitability - Tushare reports percentages; schema uses decimals.
                # Prefer roe_dt (per Tushare docs: ROE excluding non-recurring
                # items) and fall back to roe when roe_dt is missing or NaN.
                roe=self._pct_to_decimal(self._first_present(latest, 'roe_dt', 'roe')),
                roa=self._pct_to_decimal(latest.get('roa')),
                gross_margin=self._pct_to_decimal(latest.get('grossprofit_margin')),
                operating_margin=None,  # Not in the requested fina_indicator fields
                net_margin=self._pct_to_decimal(latest.get('netprofit_margin')),
                # Financial health. debt_to_equity is left unset: deriving it from
                # debt_to_assets would require choosing a unit convention.
                debt_to_equity=None,
                current_ratio=self._to_float(latest.get('current_ratio')),
                quick_ratio=self._to_float(latest.get('quick_ratio')),
                # Performance (TTM) - not available in fina_indicator
                revenue_ttm=None,
                net_income_ttm=None,
                free_cash_flow_ttm=None,
                updated_at=now_iso(),
                source="tushare",
                raw=latest.to_dict()
            )
        except ProviderError:
            raise
        except Exception as e:
            self._raise_api_error(e, "Fundamentals fetch failed", {"symbol": symbol})

    def get_financial_statements(
        self,
        symbol: str,
        statement: str,
        period: str = "annual"
    ) -> "FinancialStatements":
        """
        Get financial statements from Tushare.

        Uses Tushare APIs:
            - income: ``income`` (income statement)
            - balance: ``balancesheet`` (balance sheet)
            - cashflow: ``cashflow`` (cash-flow statement)

        Args:
            symbol: Yahoo Finance format symbol
            statement: "income", "balance", or "cashflow"
            period: "annual" (year-end reports only) or "quarterly" (all reports)

        Returns:
            FinancialStatements object with unified schema

        Raises:
            ProviderError: ``INVALID_ARGUMENT`` for an unsupported symbol or
                statement type, ``NO_DATA`` if Tushare returns no rows,
                ``PROVIDER_ERROR`` for permission or other API failures.

        Note:
            - Returns up to 8 most recent distinct reporting periods
            - Annual reports are filtered by year-end date (1231)
            - All amounts are in original currency (CNY for A-shares)
        """
        try:
            ts_code = self.yahoo_to_ts_code(symbol)

            if statement == "income":
                df = self.pro.income(
                    ts_code=ts_code,
                    fields='ts_code,end_date,revenue,operate_profit,total_profit,n_income,basic_eps'
                )
            elif statement == "balance":
                df = self.pro.balancesheet(
                    ts_code=ts_code,
                    fields='ts_code,end_date,total_assets,total_liab,total_hldr_eqy_exc_min_int,money_cap'
                )
            elif statement == "cashflow":
                df = self.pro.cashflow(
                    ts_code=ts_code,
                    fields='ts_code,end_date,n_cashflow_act,n_cashflow_inv_act,n_cash_flows_fnc_act'
                )
            else:
                raise ProviderError(
                    code="INVALID_ARGUMENT",
                    message=f"Invalid statement type: {statement}. Expected 'income', 'balance', or 'cashflow'.",
                    details={"statement": statement},
                    provider="tushare"
                )

            if df is None or df.empty:
                raise ProviderError(
                    code="NO_DATA",
                    message=f"No {statement} statement data found for {symbol}",
                    details={"symbol": symbol, "ts_code": ts_code, "statement": statement},
                    provider="tushare"
                )

            # end_date is YYYYMMDD. Year-end (1231) rows are the annual reports;
            # quarterly mode keeps every report date.
            if period == "annual":
                df = df[df['end_date'].astype(str).str.endswith('1231')]

            # Tushare can return several rows for the same reporting period
            # (e.g. original and corrected filings). Order newest period first
            # and keep one row per end_date so the 8-period cap counts distinct
            # periods. NOTE(review): which duplicate survives depends on
            # Tushare's row order; requesting `update_flag` would allow
            # choosing explicitly.
            df = (
                df.sort_values('end_date', ascending=False, kind='mergesort')
                .drop_duplicates(subset='end_date', keep='first')
                .head(8)
            )

            items = [self._build_statement_item(statement, row) for _, row in df.iterrows()]

            return FinancialStatements(
                symbol=symbol,  # Keep original Yahoo symbol
                statement=statement,  # type: ignore
                period=period,  # type: ignore
                currency="CNY",
                items=items,
                source="tushare"
            )
        except ProviderError:
            raise
        except Exception as e:
            self._raise_api_error(
                e,
                "Financial statements fetch failed",
                {"symbol": symbol, "statement": statement},
            )

    def _build_statement_item(self, statement: str, row) -> "FinancialStatementItem":
        """Map one Tushare statement row to a FinancialStatementItem.

        ``statement`` must already be validated ("income", "balance", or
        anything else for cashflow).
        """
        period_end = self._format_date(row['end_date'])
        if statement == "income":
            return FinancialStatementItem(
                period_end=period_end,
                revenue=self._to_float(row.get('revenue')),
                gross_profit=None,  # Not among the requested income fields
                operating_income=self._to_float(row.get('operate_profit')),
                net_income=self._to_float(row.get('n_income')),
                eps=self._to_float(row.get('basic_eps')),
                raw=row.to_dict()
            )
        if statement == "balance":
            return FinancialStatementItem(
                period_end=period_end,
                total_assets=self._to_float(row.get('total_assets')),
                total_liabilities=self._to_float(row.get('total_liab')),
                total_equity=self._to_float(row.get('total_hldr_eqy_exc_min_int')),
                cash=self._to_float(row.get('money_cap')),
                raw=row.to_dict()
            )
        return FinancialStatementItem(
            period_end=period_end,
            operating_cash_flow=self._to_float(row.get('n_cashflow_act')),
            investing_cash_flow=self._to_float(row.get('n_cashflow_inv_act')),
            financing_cash_flow=self._to_float(row.get('n_cash_flows_fnc_act')),
            free_cash_flow=None,  # Not among the requested cashflow fields
            raw=row.to_dict()
        )

    def _format_date(self, tushare_date: str) -> str:
        """
        Convert Tushare date format (YYYYMMDD) to YYYY-MM-DD.

        Args:
            tushare_date: Date in YYYYMMDD format

        Returns:
            Date in YYYY-MM-DD format. Input of any other length is returned
            unchanged.
        """
        if len(tushare_date) == 8:
            return f"{tushare_date[:4]}-{tushare_date[4:6]}-{tushare_date[6:]}"
        return tushare_date