"""Tests for Tushare provider."""
import pytest
from unittest.mock import Mock, patch, MagicMock
import pandas as pd
import os
from stock_data_mcp.providers.tushare_provider import TushareProvider
from stock_data_mcp.util import ProviderError
from stock_data_mcp.cache import get_cache
@pytest.fixture
def mock_tushare():
    """Install a fake ``tushare`` module in ``sys.modules`` for one test.

    The fake module's ``set_token`` is a MagicMock, and its ``pro_api``
    always returns the same MagicMock client.

    Yields:
        The client returned by the fake ``tushare.pro_api``. Tests configure
        endpoint return values on it (e.g. ``fina_indicator``).
    """
    import sys

    fake_module = MagicMock()
    pro_client = MagicMock()
    fake_module.set_token = MagicMock()
    fake_module.pro_api = MagicMock(return_value=pro_client)
    with patch.dict('sys.modules', {'tushare': fake_module}):
        # Keep the patch active for the test; patch.dict restores
        # sys.modules when the fixture resumes after the yield.
        assert sys.modules['tushare'] is fake_module
        yield pro_client
@pytest.fixture
def tushare_provider(mock_tushare):
    """Build a TushareProvider whose ``pro`` client is the mocked Tushare API.

    Construction runs while the fake ``tushare`` module from ``mock_tushare``
    is installed, and while TUSHARE_TOKEN is set in the environment.
    """
    fake_env = {'TUSHARE_TOKEN': 'test_token'}
    with patch.dict('os.environ', fake_env):
        provider = TushareProvider(token='test_token')
    # Point the provider at the mock client so tests control every response.
    provider.pro = mock_tushare
    return provider
class TestSymbolConversion:
    """Yahoo-style symbols are mapped onto Tushare ``ts_code`` values."""

    def test_shanghai_conversion(self, tushare_provider):
        """Test Requirement 2: the Yahoo ``.SS`` suffix becomes Tushare's ``.SH``.

        Lower-case input must map to the same upper-case code.
        """
        for yahoo_symbol in ('600519.SS', '600519.ss'):
            assert tushare_provider.yahoo_to_ts_code(yahoo_symbol) == '600519.SH'

    def test_shenzhen_unchanged(self, tushare_provider):
        """Shenzhen ``.SZ`` codes pass through unchanged apart from upper-casing."""
        cases = [
            ('000001.SZ', '000001.SZ'),
            ('002595.SZ', '002595.SZ'),
            ('002595.sz', '002595.SZ'),
        ]
        for yahoo_symbol, expected in cases:
            assert tushare_provider.yahoo_to_ts_code(yahoo_symbol) == expected

    def test_invalid_symbol(self, tushare_provider):
        """A bare ticker such as 'AAPL' is rejected with INVALID_ARGUMENT."""
        with pytest.raises(ProviderError) as excinfo:
            tushare_provider.yahoo_to_ts_code('AAPL')
        error = excinfo.value
        assert error.code == 'INVALID_ARGUMENT'
        assert 'Invalid symbol format' in error.message
class TestGetFundamentals:
    """Tests for ``TushareProvider.get_fundamentals``."""

    @staticmethod
    def _indicator_frame(ts_code, **metrics):
        """Wrap a single ``fina_indicator`` row dated 2023-12-31 in a DataFrame.

        Columns appear in the order ts_code, end_date, then *metrics* in the
        order they were passed.
        """
        row = {'ts_code': ts_code, 'end_date': '20231231'}
        row.update(metrics)
        return pd.DataFrame([row])

    def test_get_fundamentals_002595_sz(self, tushare_provider, mock_tushare):
        """Test Requirement 1: fundamentals for 002595.SZ are built from Tushare data."""
        mock_tushare.fina_indicator.return_value = self._indicator_frame(
            '002595.SZ',
            roe=15.5,
            roe_dt=15.2,
            roa=8.2,
            grossprofit_margin=35.0,
            netprofit_margin=12.5,
            debt_to_assets=45.0,
            current_ratio=2.1,
            quick_ratio=1.8,
            eps=1.25,
        )

        result = tushare_provider.get_fundamentals('002595.SZ', period='ttm')

        assert (result.symbol, result.source, result.currency) == ('002595.SZ', 'tushare', 'CNY')
        # Percent values from Tushare come back as fractions; ROE prefers the
        # diluted figure (roe_dt=15.2) over roe.
        assert result.roe == pytest.approx(0.152)
        assert result.roa == pytest.approx(0.082)
        assert result.gross_margin == pytest.approx(0.35)
        assert result.net_margin == pytest.approx(0.125)
        # Liquidity ratios are passed through unscaled.
        assert result.current_ratio == 2.1
        assert result.quick_ratio == 1.8
        # Fields this provider does not populate stay None.
        for unavailable in ('market_cap', 'pe', 'pb'):
            assert getattr(result, unavailable) is None

        mock_tushare.fina_indicator.assert_called_once()
        _, call_kwargs = mock_tushare.fina_indicator.call_args
        assert call_kwargs['ts_code'] == '002595.SZ'

    def test_get_fundamentals_600519_ss_with_conversion(self, tushare_provider, mock_tushare):
        """Test Requirement 2: a ``.SS`` request is sent to Tushare as ``.SH``."""
        mock_tushare.fina_indicator.return_value = self._indicator_frame(
            '600519.SH',
            roe=30.2,
            roe_dt=None,  # forces the fallback from roe_dt to roe
            roa=18.5,
            grossprofit_margin=91.2,
            netprofit_margin=52.8,
            debt_to_assets=20.0,
            current_ratio=3.5,
            quick_ratio=3.2,
            eps=45.20,
        )

        result = tushare_provider.get_fundamentals('600519.SS', period='annual')

        # The caller's symbol is preserved even though Tushare saw 600519.SH.
        assert result.symbol == '600519.SS'
        assert result.roe == pytest.approx(0.302)

        mock_tushare.fina_indicator.assert_called_once()
        _, call_kwargs = mock_tushare.fina_indicator.call_args
        assert call_kwargs['ts_code'] == '600519.SH'

    def test_get_fundamentals_no_data(self, tushare_provider, mock_tushare):
        """An empty Tushare response surfaces as a NO_DATA ProviderError."""
        mock_tushare.fina_indicator.return_value = pd.DataFrame()

        with pytest.raises(ProviderError) as excinfo:
            tushare_provider.get_fundamentals('000001.SZ')

        error = excinfo.value
        assert error.code == 'NO_DATA'
        assert '000001.SZ' in error.message

    def test_get_fundamentals_handles_nan_values(self, tushare_provider, mock_tushare):
        """NaN cells in the Tushare response become None on the result."""
        nan = float('nan')
        mock_tushare.fina_indicator.return_value = self._indicator_frame(
            '002595.SZ',
            roe=nan,
            roe_dt=15.2,
            roa=nan,
            grossprofit_margin=35.0,
            netprofit_margin=nan,
            debt_to_assets=45.0,
            current_ratio=nan,
            quick_ratio=1.8,
            eps=1.25,
        )

        result = tushare_provider.get_fundamentals('002595.SZ')

        for nan_field in ('roa', 'net_margin', 'current_ratio'):
            assert getattr(result, nan_field) is None
        # Populated cells are still converted normally.
        assert result.roe == pytest.approx(0.152)
        assert result.gross_margin == pytest.approx(0.35)
class TestGetFinancialStatements:
    """Tests for ``TushareProvider.get_financial_statements``."""

    @staticmethod
    def _assert_fields(item, expected):
        """Assert that each ``attribute -> value`` pair in *expected* holds on *item*."""
        for attr, value in expected.items():
            assert getattr(item, attr) == value, attr

    def test_get_income_statement(self, tushare_provider, mock_tushare):
        """Income rows map onto statement items, newest first."""
        rows = [
            {
                'ts_code': '002595.SZ',
                'end_date': '20231231',
                'revenue': 1000000000.0,
                'operate_profit': 150000000.0,
                'total_profit': 160000000.0,
                'n_income': 120000000.0,
                'basic_eps': 2.5,
            },
            {
                'ts_code': '002595.SZ',
                'end_date': '20221231',
                'revenue': 900000000.0,
                'operate_profit': 130000000.0,
                'total_profit': 140000000.0,
                'n_income': 100000000.0,
                'basic_eps': 2.1,
            },
        ]
        mock_tushare.income.return_value = pd.DataFrame(rows)

        result = tushare_provider.get_financial_statements('002595.SZ', 'income', 'annual')

        self._assert_fields(result, {
            'symbol': '002595.SZ',
            'statement': 'income',
            'period': 'annual',
            'currency': 'CNY',
            'source': 'tushare',
        })
        assert len(result.items) == 2

        latest = result.items[0]
        self._assert_fields(latest, {
            'period_end': '2023-12-31',
            'revenue': 1000000000.0,
            'operating_income': 150000000.0,
            'net_income': 120000000.0,
            'eps': 2.5,
        })
        # The income endpoint data used here carries no gross-profit column.
        assert latest.gross_profit is None

    def test_get_balance_sheet(self, tushare_provider, mock_tushare):
        """Balance-sheet rows map onto assets, liabilities, equity and cash."""
        mock_tushare.balancesheet.return_value = pd.DataFrame([
            {
                'ts_code': '600519.SH',
                'end_date': '20231231',
                'total_assets': 500000000000.0,
                'total_liab': 200000000000.0,
                'total_hldr_eqy_exc_min_int': 300000000000.0,
                'money_cap': 50000000000.0,
            }
        ])

        result = tushare_provider.get_financial_statements('600519.SS', 'balance', 'annual')

        assert result.statement == 'balance'
        assert len(result.items) == 1
        self._assert_fields(result.items[0], {
            'total_assets': 500000000000.0,
            'total_liabilities': 200000000000.0,
            'total_equity': 300000000000.0,
            'cash': 50000000000.0,
        })

    def test_get_cashflow_statement(self, tushare_provider, mock_tushare):
        """Cash-flow rows map onto operating, investing and financing flows."""
        mock_tushare.cashflow.return_value = pd.DataFrame([
            {
                'ts_code': '002595.SZ',
                'end_date': '20231231',
                'n_cashflow_act': 80000000.0,
                'n_cashflow_inv_act': -30000000.0,
                'n_cash_flows_fnc_act': -20000000.0,
            }
        ])

        result = tushare_provider.get_financial_statements('002595.SZ', 'cashflow', 'annual')

        assert result.statement == 'cashflow'
        assert len(result.items) == 1
        item = result.items[0]
        self._assert_fields(item, {
            'operating_cash_flow': 80000000.0,
            'investing_cash_flow': -30000000.0,
            'financing_cash_flow': -20000000.0,
        })
        # Not derived by the provider (original note: would need calculation).
        assert item.free_cash_flow is None

    def test_invalid_statement_type(self, tushare_provider):
        """An unknown statement type is rejected with INVALID_ARGUMENT."""
        with pytest.raises(ProviderError) as excinfo:
            tushare_provider.get_financial_statements('002595.SZ', 'invalid', 'annual')
        error = excinfo.value
        assert error.code == 'INVALID_ARGUMENT'
        assert 'Invalid statement type' in error.message

    def test_annual_period_filtering(self, tushare_provider, mock_tushare):
        """With period='annual', only Dec-31 period ends are kept."""
        columns = ('end_date', 'revenue', 'operate_profit', 'total_profit', 'n_income', 'basic_eps')
        raw_rows = [
            ('20231231', 1000.0, 100.0, 110.0, 90.0, 2.0),
            ('20230930', 750.0, 75.0, 80.0, 65.0, 1.5),  # quarterly: filtered out
            ('20221231', 900.0, 90.0, 95.0, 80.0, 1.8),
        ]
        mock_tushare.income.return_value = pd.DataFrame(
            [{'ts_code': '002595.SZ', **dict(zip(columns, values))} for values in raw_rows]
        )

        result = tushare_provider.get_financial_statements('002595.SZ', 'income', 'annual')

        assert [item.period_end for item in result.items] == ['2023-12-31', '2022-12-31']
class TestTokenValidation:
    """Token validation and translation of Tushare API errors."""

    @staticmethod
    def _assert_permission_error(provider, pro_api, message):
        """Make ``fina_indicator`` raise *message* and check that it surfaces as a permission error."""
        pro_api.fina_indicator.side_effect = Exception(message)
        with pytest.raises(ProviderError) as excinfo:
            provider.get_fundamentals('002595.SZ')
        error = excinfo.value
        assert error.code == 'PROVIDER_ERROR'
        assert 'permission' in error.message.lower()

    def test_missing_token_error(self):
        """Test Requirement 5: with no token argument and no env token, construction fails."""
        empty_env = patch.dict('os.environ', {}, clear=True)
        fake_tushare = patch.dict('sys.modules', {'tushare': MagicMock()})
        with empty_env, fake_tushare:
            with pytest.raises(ProviderError) as excinfo:
                TushareProvider(token=None)
        error = excinfo.value
        assert error.code == 'MISSING_API_KEY'
        assert 'TUSHARE_TOKEN' in error.message

    def test_permission_denied_error(self, tushare_provider, mock_tushare):
        """A permission-denied message in Chinese is recognised."""
        # The message means roughly "no permission to access this interface".
        self._assert_permission_error(tushare_provider, mock_tushare, '没有权限访问该接口')

    def test_permission_denied_error_english(self, tushare_provider, mock_tushare):
        """A permission-denied message in English is recognised."""
        self._assert_permission_error(tushare_provider, mock_tushare, 'Permission denied')
class TestCacheIntegration:
    """Test cache integration."""

    def test_cache_hit_verification(self, tushare_provider, mock_tushare):
        """Test Requirement 6: a fundamentals result stored in the cache reads back intact.

        ``get_cache()`` takes no arguments, so it presumably returns a shared
        instance. The cache is therefore cleared both before and after the
        test, so this test neither sees entries left by others nor leaks its
        own entry into later tests.
        """
        cache = get_cache()
        cache.clear()
        try:
            mock_tushare.fina_indicator.return_value = pd.DataFrame([{
                'ts_code': '002595.SZ',
                'end_date': '20231231',
                'roe': 15.5,
                'roe_dt': 15.2,
                'roa': 8.2,
                'grossprofit_margin': 35.0,
                'netprofit_margin': 12.5,
                'debt_to_assets': 45.0,
                'current_ratio': 2.1,
                'quick_ratio': 1.8,
                'eps': 1.25,
            }])

            # Fetch once through the provider; this must hit the mocked API.
            result1 = tushare_provider.get_fundamentals('002595.SZ')
            assert mock_tushare.fina_indicator.call_count == 1

            # Store the result, then read it back under the same (symbol, period) key.
            cache.set_fundamentals('002595.SZ', 'ttm', result1)
            cached_result = cache.get_fundamentals('002595.SZ', 'ttm')
            assert cached_result is not None
            assert cached_result.symbol == '002595.SZ'
            assert cached_result.source == 'tushare'

            # Cache reads and writes must not trigger another Tushare request.
            assert mock_tushare.fina_indicator.call_count == 1
        finally:
            # Leave the shared cache empty for subsequent tests.
            cache.clear()
class TestIntegrationWithYahoo:
    """A-share price history is served by the Yahoo provider, not Tushare."""

    @patch('stock_data_mcp.providers.yahoo_provider.yf')
    def test_ashare_history_uses_yahoo(self, mock_yf):
        """Test Requirement 3: A-share history still uses Yahoo."""
        from stock_data_mcp.providers.yahoo_provider import YahooProvider

        trading_days = pd.DatetimeIndex(['2024-01-02', '2024-01-03'])
        ohlcv = {
            'Open': [100.0, 101.0],
            'High': [102.0, 103.0],
            'Low': [99.0, 100.0],
            'Close': [101.0, 102.0],
            'Volume': [1000000, 1100000],
        }
        mock_yf.download.return_value = pd.DataFrame(ohlcv, index=trading_days)

        # Query the Yahoo provider directly; no Tushare object is involved.
        records = YahooProvider().get_history('002595.SZ', '2024-01-02', '2024-01-03')

        mock_yf.download.assert_called_once()
        assert len(records) == 2
        first = records[0]
        assert first.date == '2024-01-02'
        assert first.close == 101.0
class TestDateFormatting:
    """Tests for the ``_format_date`` helper."""

    def test_format_date(self, tushare_provider):
        """Compact YYYYMMDD strings are rewritten as ISO YYYY-MM-DD."""
        expectations = {
            '20231231': '2023-12-31',
            '20240615': '2024-06-15',
            '20200101': '2020-01-01',
        }
        for compact, iso in expectations.items():
            assert tushare_provider._format_date(compact) == iso

    def test_format_date_passthrough(self, tushare_provider):
        """Strings that are not in YYYYMMDD form are returned unchanged."""
        for untouched in ('2023-12-31', 'invalid'):
            assert tushare_provider._format_date(untouched) == untouched