# stock_service.py (31.9 kB)
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
import twstock
from twstock import BestFourPoint, Stock
from tw_stock_agent.services.cache_service import CacheConfig, CacheService
from tw_stock_agent.utils.connection_pool import HTTPConnectionPool, get_global_pool
from tw_stock_agent.utils.performance_monitor import get_global_monitor
from tw_stock_agent.exceptions import (
StockNotFoundError,
StockDataUnavailableError,
InvalidStockCodeError,
ExternalAPIError,
CacheError,
ErrorCode,
ErrorSeverity,
TwStockAgentError
)
from tw_stock_agent.utils.validation import StockCodeValidator
from tw_stock_agent.utils.error_handler import (
with_async_error_handling,
with_retry,
CircuitBreaker,
ErrorEnricher
)
logger = logging.getLogger("tw-stock-agent.stock_service")
# Refresh the twstock stock-code list at import time.
# NOTE(review): the refresh call inside the try block is commented out, so the
# block currently only emits the "update succeeded" log line and the except
# branch has nothing network-related left to guard. Confirm whether the refresh
# should be re-enabled or the success message removed.
try:
    #twstock.__update_codes()
    logger.info("成功更新股票代碼列表")
except Exception as e:
    logger.error(f"更新股票代碼列表時出錯: {e!s}")
# Module-level circuit breaker for calls into twstock; StockService.__init__
# stores this same object on every instance.
# NOTE(review): Exception is a base class of ConnectionError and TimeoutError,
# so for isinstance/except purposes the tuple is equivalent to (Exception,).
# Confirm how CircuitBreaker interprets expected_exception.
twstock_circuit_breaker = CircuitBreaker(
    failure_threshold=3,
    timeout_seconds=30.0,
    expected_exception=(ConnectionError, TimeoutError, Exception)
)
class StockService:
    """Stock data service: pulls data from external APIs and caches the results."""

    def __init__(self, cache_config: Optional[CacheConfig] = None, http_pool: Optional[HTTPConnectionPool] = None):
        """Wire up the cache, TTL table, circuit breaker and monitor.

        Args:
            cache_config: Cache configuration handed to ``CacheService`` as-is
                (``None`` is documented as selecting its default configuration).
            http_pool: HTTP connection pool to use. ``None`` leaves the pool
                unset so it can be resolved later.
        """
        self._http_pool = http_pool
        # Every instance shares the module-level twstock circuit breaker.
        self.circuit_breaker = twstock_circuit_breaker
        # Cache lifetimes in seconds, keyed by data category.
        self.cache_ttl = dict(
            stock_data=86400,        # 24 hours
            price_data=1800,         # 30 minutes
            realtime=60,             # 1 minute
            best_four_points=3600,   # 1 hour
        )
        self.cache = CacheService(config=cache_config)
        self._performance_monitor = get_global_monitor()
async def _get_http_pool(self) -> HTTPConnectionPool:
    """Return the HTTP connection pool, resolving the global pool on first use.

    Returns:
        The pool stored on the instance; when none is stored yet, the result of
        ``get_global_pool()`` is awaited, stored and returned.
    """
    if self._http_pool is not None:
        return self._http_pool
    self._http_pool = await get_global_pool()
    return self._http_pool
def _get_cached_data(self, key: str) -> Optional[Dict[str, Any]]:
    """Read one entry from the cache without ever raising.

    Args:
        key: Cache key to look up.

    Returns:
        Whatever ``self.cache.get`` returns for ``key``, or ``None`` when the
        cache lookup raises (the failure is logged as a warning).
    """
    try:
        cached = self.cache.get(key)
    except Exception as exc:
        logger.warning(f"Cache get failed for key {key}: {exc}")
        return None
    return cached
def _set_cached_data(self, key: str, data: Dict[str, Any], ttl: int, tags: Optional[List[str]] = None) -> None:
    """Store one entry in the cache on a best-effort basis.

    Args:
        key: Cache key.
        data: Payload to store.
        ttl: Expiry in seconds, forwarded as ``expire``.
        tags: Optional cache tags forwarded to the cache.
    """
    options = {"expire": ttl, "tags": tags}
    try:
        self.cache.set(key, data, **options)
    except Exception as exc:
        # A cache failure is logged and swallowed: callers carry on uncached.
        logger.warning(f"Cache set failed for key {key}: {exc}")
@with_async_error_handling(operation="fetch_stock_data")
@with_retry(max_retries=2, base_delay=1.0)
async def fetch_stock_data(self, stock_code: str) -> dict[str, Any]:
    """Fetch the basic reference information of a single stock.

    Args:
        stock_code: Stock code to look up.

    Returns:
        Dict with ``stock_code``, ``name``, ``type``, ``isin``, ``start_date``,
        ``market``, ``industry`` and ``updated_at``. A truthy cache entry is
        returned as-is without a new lookup.

    Raises:
        InvalidStockCodeError: Malformed code (expected to come from
            ``StockCodeValidator.validate_stock_code`` - not visible here).
        StockNotFoundError: ``twstock.codes`` has no entry for the code.
        StockDataUnavailableError: Any other failure during the lookup.
    """
    monitor = self._performance_monitor
    monitor.record_request()

    code = StockCodeValidator.validate_stock_code(stock_code)
    key = f"stock_data_{code}"

    hit = self._get_cached_data(key)
    if hit:
        return hit

    try:
        # The lookup is routed through the circuit breaker on a worker thread.
        info = await self.circuit_breaker.acall(
            asyncio.to_thread, twstock.codes.get, code
        )
        if not info:
            raise StockNotFoundError(
                stock_code=code,
                message=f"Stock code '{code}' not found in Taiwan stock exchange"
            )

        payload = {
            "stock_code": code,
            "name": info.name,
            "type": info.type,
            "isin": info.ISIN,
            "start_date": info.start,
            "market": info.market,
            "industry": info.group,
            "updated_at": datetime.now().isoformat()
        }
        # Tags allow invalidation by data category, market and security type.
        self._set_cached_data(
            key,
            payload,
            self.cache_ttl['stock_data'],
            ["stock_data", info.market, info.type],
        )
        return payload
    except StockNotFoundError:
        # Count the failure, then let the specific error through untouched.
        monitor.record_error()
        raise
    except Exception as e:
        # Everything else is counted and reported as "data unavailable".
        monitor.record_error()
        detail = str(e)
        lowered = detail.lower()
        if "connection" in lowered or "timeout" in lowered:
            message = f"Unable to fetch stock data due to network issues: {detail}"
        else:
            message = f"Failed to fetch stock data: {detail}"
        raise StockDataUnavailableError(
            stock_code=code,
            data_type="stock information",
            message=message
        )
@with_async_error_handling(operation="fetch_price_data")
@with_retry(max_retries=2, base_delay=1.0)
async def fetch_price_data(self, stock_code: str, days: int = 30) -> list[dict[str, Any]]:
    """Fetch the most recent daily price records of a stock.

    Args:
        stock_code: Stock code to query.
        days: Number of most recent records to return (1 to 3650).

    Returns:
        List of dicts with ``date``, ``open``, ``high``, ``low``, ``close``,
        ``volume`` and ``change``, in the order twstock delivers them. Fewer
        than ``days`` records are returned when fewer are available.

    Raises:
        InvalidStockCodeError: Malformed code (expected to come from
            ``StockCodeValidator.validate_stock_code`` - not visible here).
        RangeValidationError: ``days`` is outside 1..3650.
        StockNotFoundError: ``twstock.codes`` has no entry for the code.
        StockDataUnavailableError: No price data came back or the fetch failed.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timedelta

    self._performance_monitor.record_request()

    validated_code = StockCodeValidator.validate_stock_code(stock_code)
    if days <= 0 or days > 3650:  # Max 10 years
        from tw_stock_agent.exceptions import RangeValidationError
        raise RangeValidationError(
            field_name="days",
            value=days,
            min_value=1,
            max_value=3650
        )

    cache_key = f"price_data_{validated_code}_{days}"
    cached_data = self._get_cached_data(cache_key)
    if cached_data:
        # Entries are written as {"data": [...], "metadata": {...}}; anything
        # else is treated as the older bare-list format.
        if isinstance(cached_data, dict) and "data" in cached_data:
            return cached_data["data"]
        return cached_data

    try:
        # Same registry lookup fetch_stock_data relies on. Without it an
        # unknown code only surfaces as a generic failure instead of the
        # documented StockNotFoundError.
        if not twstock.codes.get(validated_code):
            raise StockNotFoundError(
                stock_code=validated_code,
                message=f"Stock code '{validated_code}' not found in Taiwan stock exchange"
            )

        # NOTE(review): Stock() may already perform an initial fetch of its
        # own - confirm whether that first request can be skipped.
        stock = await self.circuit_breaker.acall(
            asyncio.to_thread, Stock, validated_code
        )

        # fetch_from() loads every month from the given one up to today.
        # Starting at the current month (the previous behaviour) could never
        # provide more than a few weeks of data and returned nothing right
        # after a month change. Go back far enough to cover ``days`` trading
        # records: about 1.6 calendar days per record plus two weeks of slack.
        window_start = datetime.now() - timedelta(days=days * 8 // 5 + 14)
        await self.circuit_breaker.acall(
            asyncio.to_thread,
            stock.fetch_from,
            window_start.year,
            window_start.month
        )

        if not stock.data:
            raise StockDataUnavailableError(
                stock_code=validated_code,
                data_type="price data",
                message="No price data available for this stock"
            )

        # A negative slice already yields the whole list when it is shorter.
        records = [
            {
                "date": row.date.isoformat(),
                "open": row.open,
                "high": row.high,
                "low": row.low,
                "close": row.close,
                "volume": row.capacity,
                "change": row.change
            }
            for row in stock.data[-days:]
        ]

        # Tags allow invalidation by category, stock and window size.
        tags = ["price_data", f"stock_{validated_code}", f"days_{days}"]
        cache_data = {
            "data": records,
            "metadata": {
                "stock_code": validated_code,
                "days": days,
                "actual_records": len(records)
            }
        }
        self._set_cached_data(cache_key, cache_data, self.cache_ttl['price_data'], tags)
        return records
    except (StockDataUnavailableError, StockNotFoundError):
        # Count the failure, then let the specific error through untouched.
        self._performance_monitor.record_error()
        raise
    except Exception as e:
        # Everything else is counted and reported as "data unavailable",
        # keeping the original exception as the explicit cause.
        self._performance_monitor.record_error()
        detail = str(e)
        lowered = detail.lower()
        if "connection" in lowered or "timeout" in lowered:
            message = f"Unable to fetch price data due to network issues: {detail}"
        else:
            message = f"Failed to fetch price data: {detail}"
        raise StockDataUnavailableError(
            stock_code=validated_code,
            data_type="price data",
            message=message
        ) from e
@with_async_error_handling(operation="get_best_four_points")
@with_retry(max_retries=2, base_delay=1.0)
async def get_best_four_points(self, stock_code: str) -> dict[str, Any]:
    """Run twstock's Best Four Point buy/sell analysis for a stock.

    Args:
        stock_code: Stock code to analyse.

    Returns:
        Dict with ``stock_code``, ``buy_points``, ``sell_points``,
        ``analysis``, ``updated_at`` and ``data_points`` (number of price
        records the analysis was based on). A truthy cache entry is returned
        as-is.

    Raises:
        InvalidStockCodeError: Malformed code (expected to come from
            ``StockCodeValidator.validate_stock_code`` - not visible here).
        StockNotFoundError: ``twstock.codes`` has no entry for the code.
        StockDataUnavailableError: Fewer than 20 price records are available
            or the analysis failed.
    """
    self._performance_monitor.record_request()

    validated_code = StockCodeValidator.validate_stock_code(stock_code)
    cache_key = f"best_four_points_{validated_code}"
    cached_data = self._get_cached_data(cache_key)
    if cached_data:
        return cached_data

    try:
        # Same registry lookup fetch_stock_data relies on. Without it an
        # unknown code only surfaces as a generic analysis failure instead of
        # the documented StockNotFoundError.
        if not twstock.codes.get(validated_code):
            raise StockNotFoundError(
                stock_code=validated_code,
                message=f"Stock code '{validated_code}' not found in Taiwan stock exchange"
            )

        stock = await self.circuit_breaker.acall(
            asyncio.to_thread, Stock, validated_code
        )
        bfp = await self.circuit_breaker.acall(
            asyncio.to_thread, BestFourPoint, stock
        )

        # At least 20 records are required before the analysis is attempted.
        if not stock.data or len(stock.data) < 20:
            raise StockDataUnavailableError(
                stock_code=validated_code,
                data_type="analysis data",
                message="Insufficient historical data for Best Four Points analysis (minimum 20 days required)"
            )

        data = {
            "stock_code": validated_code,
            "buy_points": bfp.best_four_point_to_buy(),
            "sell_points": bfp.best_four_point_to_sell(),
            "analysis": bfp.best_four_point(),
            "updated_at": datetime.now().isoformat(),
            "data_points": len(stock.data)
        }

        # Tags allow invalidation by category and by stock.
        tags = ["best_four_points", f"stock_{validated_code}", "analysis"]
        self._set_cached_data(cache_key, data, self.cache_ttl['best_four_points'], tags)
        return data
    except (StockDataUnavailableError, StockNotFoundError):
        # Count the failure, then let the specific error through untouched.
        self._performance_monitor.record_error()
        raise
    except Exception as e:
        # Everything else is counted and reported as "data unavailable",
        # keeping the original exception as the explicit cause.
        self._performance_monitor.record_error()
        detail = str(e)
        lowered = detail.lower()
        if "connection" in lowered or "timeout" in lowered:
            message = f"Unable to perform analysis due to network issues: {detail}"
        else:
            message = f"Failed to perform Best Four Points analysis: {detail}"
        raise StockDataUnavailableError(
            stock_code=validated_code,
            data_type="analysis data",
            message=message
        ) from e
@with_async_error_handling(operation="get_realtime_data")
@with_retry(max_retries=1, base_delay=0.5)  # Shorter retry for real-time data
async def get_realtime_data(self, stock_code: str) -> dict[str, Any]:
    """Fetch the current real-time quote of a stock.

    Args:
        stock_code: Stock code to query.

    Returns:
        Dict with ``stock_code``, ``name``, ``current_price``, ``open``,
        ``high``, ``low``, ``volume``, ``updated_at`` and ``market_status``.
        A truthy cache entry is returned as-is.

    Raises:
        InvalidStockCodeError: Malformed code (expected to come from
            ``StockCodeValidator.validate_stock_code`` - not visible here).
        StockNotFoundError: twstock returned nothing for the code.
        StockMarketClosedError: The response is unsuccessful and the Taipei
            clock is outside the regular session (Mon-Fri 09:00-13:30).
        StockDataUnavailableError: The response is unsuccessful during the
            session, or the fetch failed for any other reason.
    """
    # Local imports: the module header only imports ``datetime`` itself, and
    # StockMarketClosedError is needed both to raise and to re-raise below.
    from datetime import time as dt_time, timedelta, timezone
    from tw_stock_agent.exceptions import StockMarketClosedError

    self._performance_monitor.record_request()

    validated_code = StockCodeValidator.validate_stock_code(stock_code)
    cache_key = f"realtime_{validated_code}"
    cached_data = self._get_cached_data(cache_key)
    if cached_data:
        return cached_data

    try:
        realtime_data = await self.circuit_breaker.acall(
            asyncio.to_thread, twstock.realtime.get, validated_code
        )

        if not realtime_data:
            raise StockNotFoundError(
                stock_code=validated_code,
                message=f"No real-time data available for stock '{validated_code}'"
            )

        if realtime_data.get('success', False) is False:
            # Decide "market closed" on the Taipei clock (fixed UTC+8) rather
            # than the host's local time, and only treat Mon-Fri 09:00-13:30
            # as the trading session. Public holidays are not detected here.
            taipei_now = datetime.now(timezone(timedelta(hours=8)))
            in_session = (
                taipei_now.weekday() < 5
                and dt_time(9, 0) <= taipei_now.time() <= dt_time(13, 30)
            )
            if not in_session:
                raise StockMarketClosedError(
                    stock_code=validated_code,
                    message="Taiwan stock market is closed. Trading hours: 9:00 AM - 1:30 PM (GMT+8)"
                )
            raise StockDataUnavailableError(
                stock_code=validated_code,
                data_type="real-time data",
                message="Real-time data temporarily unavailable"
            )

        realtime_info = realtime_data.get('realtime', {})
        stock_info = realtime_data.get('info', {})

        data = {
            "stock_code": validated_code,
            "name": stock_info.get('name'),
            "current_price": realtime_info.get('latest_trade_price'),
            "open": realtime_info.get('open'),
            "high": realtime_info.get('high'),
            "low": realtime_info.get('low'),
            "volume": realtime_info.get('accumulate_trade_volume'),
            "updated_at": datetime.now().isoformat(),
            "market_status": "open" if realtime_data.get('success') else "closed"
        }

        # Tags allow invalidation by category and by stock.
        tags = ["realtime", f"stock_{validated_code}", "live_data"]
        self._set_cached_data(cache_key, data, self.cache_ttl['realtime'], tags)
        return data
    except (StockDataUnavailableError, StockNotFoundError, StockMarketClosedError):
        # Count the failure, then let the specific error through untouched.
        # StockMarketClosedError is listed explicitly so it cannot fall into
        # the generic handler below and lose its type.
        self._performance_monitor.record_error()
        raise
    except Exception as e:
        # Everything else is counted and reported as "data unavailable",
        # keeping the original exception as the explicit cause.
        self._performance_monitor.record_error()
        detail = str(e)
        lowered = detail.lower()
        if "connection" in lowered or "timeout" in lowered:
            message = f"Unable to fetch real-time data due to network issues: {detail}"
        else:
            message = f"Failed to fetch real-time data: {detail}"
        raise StockDataUnavailableError(
            stock_code=validated_code,
            data_type="real-time data",
            message=message
        ) from e
@with_async_error_handling(operation="fetch_multiple_stocks_data")
async def fetch_multiple_stocks_data(self, stock_codes: List[str]) -> Dict[str, Dict[str, Any]]:
    """Fetch basic information for several stocks in one call.

    Cached entries are read in bulk; the remaining codes are fetched
    concurrently through ``fetch_stock_data``. A failure for one code does not
    fail the batch - that code maps to an error dict instead.

    Args:
        stock_codes: Stock codes to look up.

    Returns:
        Dict keyed by stock code. Each value is either the stock data or an
        error dict carrying at least ``stock_code``.

    Raises:
        ParameterValidationError: ``stock_codes`` is empty.
    """
    if not stock_codes:
        from tw_stock_agent.exceptions import ParameterValidationError
        raise ParameterValidationError(
            parameter_name="stock_codes",
            parameter_value=stock_codes,
            expected_format="Non-empty list of stock codes"
        )

    validated_codes = StockCodeValidator.validate_multiple_codes(stock_codes, strict=False)
    cache_keys = [f"stock_data_{code}" for code in validated_codes]

    # A failing bulk read is treated as "nothing cached".
    try:
        cached_results = self.cache.get_bulk(cache_keys)
    except Exception as e:
        logger.warning(f"Bulk cache get failed: {e}")
        cached_results = {}

    # Split the request into cache hits and codes that still need fetching.
    results: Dict[str, Dict[str, Any]] = {}
    missing_codes: List[str] = []
    for code, cache_key in zip(validated_codes, cache_keys):
        cached = cached_results.get(cache_key)
        if cached:
            results[code] = cached
        else:
            missing_codes.append(code)

    if not missing_codes:
        return results

    fetched = await asyncio.gather(
        *(self.fetch_stock_data(code) for code in missing_codes),
        return_exceptions=True
    )
    for code, outcome in zip(missing_codes, fetched):
        # BaseException rather than Exception: gather() can also hand back
        # asyncio.CancelledError, which must not be stored as stock data.
        if not isinstance(outcome, BaseException):
            results[code] = outcome
            continue

        logger.error(f"獲取股票 {code} 資料時出錯: {outcome}")
        if isinstance(outcome, TwStockAgentError):
            error_dict = outcome.to_dict()
            error_dict["stock_code"] = code
            results[code] = error_dict
        else:
            results[code] = {
                "stock_code": code,
                "error": f"獲取資料失敗: {outcome!s}",
                "error_code": "FETCH_FAILED"
            }

    return results
@with_async_error_handling(operation="fetch_multiple_realtime_data")
async def fetch_multiple_realtime_data(self, stock_codes: List[str]) -> Dict[str, Dict[str, Any]]:
    """Fetch real-time quotes for several stocks in one call.

    Cached entries are read in bulk; the remaining codes are fetched
    concurrently through ``get_realtime_data``. A failure for one code does
    not fail the batch - that code maps to an error dict instead.

    Args:
        stock_codes: Stock codes to query.

    Returns:
        Dict keyed by stock code. Each value is either the real-time data or
        an error dict carrying at least ``stock_code``.

    Raises:
        ParameterValidationError: ``stock_codes`` is empty.
    """
    if not stock_codes:
        from tw_stock_agent.exceptions import ParameterValidationError
        raise ParameterValidationError(
            parameter_name="stock_codes",
            parameter_value=stock_codes,
            expected_format="Non-empty list of stock codes"
        )

    validated_codes = StockCodeValidator.validate_multiple_codes(stock_codes, strict=False)
    cache_keys = [f"realtime_{code}" for code in validated_codes]

    # A failing bulk read is treated as "nothing cached".
    try:
        cached_results = self.cache.get_bulk(cache_keys)
    except Exception as e:
        logger.warning(f"Bulk cache get failed: {e}")
        cached_results = {}

    # Split the request into cache hits and codes that still need fetching.
    results: Dict[str, Dict[str, Any]] = {}
    missing_codes: List[str] = []
    for code, cache_key in zip(validated_codes, cache_keys):
        cached = cached_results.get(cache_key)
        if cached:
            results[code] = cached
        else:
            missing_codes.append(code)

    if not missing_codes:
        return results

    fetched = await asyncio.gather(
        *(self.get_realtime_data(code) for code in missing_codes),
        return_exceptions=True
    )
    for code, outcome in zip(missing_codes, fetched):
        # BaseException rather than Exception: gather() can also hand back
        # asyncio.CancelledError, which must not be stored as quote data.
        if not isinstance(outcome, BaseException):
            results[code] = outcome
            continue

        logger.error(f"獲取股票 {code} 即時資料時出錯: {outcome}")
        if isinstance(outcome, TwStockAgentError):
            error_dict = outcome.to_dict()
            error_dict["stock_code"] = code
            results[code] = error_dict
        else:
            results[code] = {
                "stock_code": code,
                "error": f"獲取即時資料失敗: {outcome!s}",
                "error_code": "FETCH_FAILED"
            }

    return results
def invalidate_stock_cache(self, stock_code: str) -> int:
    """Drop every cache entry whose key contains the given stock code.

    The code is wrapped in ``%`` wildcards and handed to
    ``cache.delete_by_pattern``.
    NOTE(review): a substring match can also hit keys of other stocks or
    ``days`` suffixes that merely contain the same digits, and the code is not
    validated here - confirm this is acceptable.

    Args:
        stock_code: Stock code whose entries should be removed.

    Returns:
        Whatever ``delete_by_pattern`` reports (documented as the number of
        removed entries).
    """
    like_pattern = "%{}%".format(stock_code)
    removed = self.cache.delete_by_pattern(like_pattern)
    return removed
def invalidate_cache_by_type(self, cache_type: str) -> int:
    """Drop all cache entries tagged with one data category.

    Args:
        cache_type: Category tag, e.g. ``stock_data``, ``price_data``,
            ``realtime`` or ``best_four_points``.

    Returns:
        Whatever ``cache.delete_by_tags`` reports (documented as the number of
        removed entries).
    """
    tag_filter = [cache_type]
    removed = self.cache.delete_by_tags(tag_filter)
    return removed
def invalidate_market_cache(self, market: str) -> int:
    """Drop all cache entries tagged with one market name.

    Args:
        market: Market name used as the cache tag.

    Returns:
        Whatever ``cache.delete_by_tags`` reports (documented as the number of
        removed entries).
    """
    tag_filter = [market]
    removed = self.cache.delete_by_tags(tag_filter)
    return removed
def warm_popular_stocks_cache(self, stock_codes: List[str]) -> Dict[str, int]:
    """Pre-populate the cache with basic data for a list of popular stocks.

    Codes that ``twstock.codes`` does not know are skipped silently; a code
    whose preparation raises is logged and skipped.

    Args:
        stock_codes: Stock codes to warm.

    Returns:
        Dict with ``total_requested``, ``successfully_warmed`` (the value
        reported by ``cache.warm_cache``) and ``cache_hit_improvement`` (warmed
        count divided by requested count, ``0`` for an empty request).
    """
    prepared = {}
    for code in stock_codes:
        try:
            info = twstock.codes.get(code)
            if not info:
                continue
            prepared[f"stock_data_{code}"] = {
                "stock_code": code,
                "name": info.name,
                "type": info.type,
                "isin": info.ISIN,
                "start_date": info.start,
                "market": info.market,
                "industry": info.group,
                "updated_at": datetime.now().isoformat()
            }
        except Exception as exc:
            logger.error(f"準備股票 {code} 預熱資料時出錯: {exc}")

    # Write all prepared entries with the stock_data lifetime.
    warmed = self.cache.warm_cache(prepared, self.cache_ttl['stock_data'])

    if stock_codes:
        improvement = warmed / len(stock_codes)
    else:
        improvement = 0
    return {
        "total_requested": len(stock_codes),
        "successfully_warmed": warmed,
        "cache_hit_improvement": improvement
    }
def get_cache_statistics(self) -> Dict[str, Any]:
    """Collect cache counters plus a per-category entry count.

    Returns:
        Dict with the counters read from ``cache.get_stats()`` (``hit_rate``,
        ``total_hits``, ``total_misses``, ``total_sets``, ``total_deletes``,
        ``total_cleanups``, ``total_size_bytes``) and ``cache_breakdown``,
        which maps each category tag to the number of keys carrying it.
    """
    stats = self.cache.get_stats()

    # One tag lookup per category, in a fixed order.
    categories = ("stock_data", "price_data", "realtime", "best_four_points")
    breakdown = {
        category: len(self.cache.get_keys_by_tags([category]))
        for category in categories
    }

    report = {
        "hit_rate": stats.hit_rate,
        "total_hits": stats.hits,
        "total_misses": stats.misses,
        "total_sets": stats.sets,
        "total_deletes": stats.deletes,
        "total_cleanups": stats.cleanups,
        "total_size_bytes": stats.total_size,
    }
    report["cache_breakdown"] = breakdown
    return report
def cleanup_old_cache(self, older_than_hours: int = 24) -> Dict[str, int]:
    """Remove expired cache entries.

    Args:
        older_than_hours: Intended age threshold in hours. The current
            implementation does not use it; only expired entries are removed.

    Returns:
        Dict with ``expired_cleaned`` (the value reported by
        ``cache.cleanup_expired``) and ``lru_evicted``, which is always ``0``.
    """
    expired = self.cache.cleanup_expired()
    # More aggressive cleanup could be added here if the business needs it.
    return {"expired_cleaned": expired, "lru_evicted": 0}
@with_async_error_handling(operation="refresh_stock_cache")
async def refresh_stock_cache(self, stock_code: str) -> Dict[str, Any]:
    """Invalidate and re-fetch every cached data category of one stock.

    Args:
        stock_code: Stock code to refresh.

    Returns:
        Dict with ``stock_code``, ``invalidated_count`` (result of
        ``invalidate_stock_cache``), ``refreshed_data`` (one success flag per
        category) and ``errors`` (one dict per failed fetch).

    Raises:
        InvalidStockCodeError: Malformed code (expected to come from
            ``StockCodeValidator.validate_stock_code`` - not visible here).
    """
    validated_code = StockCodeValidator.validate_stock_code(stock_code)

    # Drop whatever is cached for this stock before re-fetching.
    invalidated = self.invalidate_stock_cache(validated_code)

    # Category labels paired with their fetches, run concurrently.
    categories = ("stock_data", "realtime", "best_four_points", "price_data")
    outcomes = await asyncio.gather(
        self.fetch_stock_data(validated_code),
        self.get_realtime_data(validated_code),
        self.get_best_four_points(validated_code),
        self.fetch_price_data(validated_code, days=30),
        return_exceptions=True
    )

    # BaseException rather than Exception: gather() can also hand back
    # asyncio.CancelledError, which must count as a failed refresh.
    failures = [o for o in outcomes if isinstance(o, BaseException)]

    return {
        "stock_code": validated_code,
        "invalidated_count": invalidated,
        "refreshed_data": {
            category: not isinstance(outcome, BaseException)
            for category, outcome in zip(categories, outcomes)
        },
        "errors": [
            f.to_dict() if isinstance(f, TwStockAgentError) else {"message": str(f)}
            for f in failures
        ]
    }
def backup_cache(self, backup_path: str) -> bool:
    """Back up the cache database by delegating to ``cache.backup_cache``.

    Args:
        backup_path: Path of the backup file.

    Returns:
        The value reported by the cache (documented as whether the backup
        succeeded).
    """
    succeeded = self.cache.backup_cache(backup_path)
    return succeeded
def restore_cache(self, backup_path: str) -> bool:
    """Restore the cache database by delegating to ``cache.restore_cache``.

    Args:
        backup_path: Path of the backup file to restore from.

    Returns:
        The value reported by the cache (documented as whether the restore
        succeeded).
    """
    succeeded = self.cache.restore_cache(backup_path)
    return succeeded
def get_performance_metrics(self) -> Dict[str, Any]:
    """Gather cache, monitor and (when present) HTTP pool metrics.

    Returns:
        Dict with ``cache_stats`` (from ``get_cache_statistics``) and
        ``performance_summary`` (``None`` when no monitor is set). ``http_pool``
        is added only when a pool is currently attached to the service.
    """
    cache_stats = self.get_cache_statistics()

    monitor = self._performance_monitor
    summary = monitor.get_performance_summary() if monitor else None

    metrics = {"cache_stats": cache_stats, "performance_summary": summary}

    # The pool is created on demand elsewhere, so it may not exist yet.
    pool = self._http_pool
    if pool:
        metrics["http_pool"] = pool.get_metrics()
    return metrics
async def close(self) -> None:
    """Close the HTTP pool (when one is attached) and the cache.

    The cache is closed even when closing the pool raises, so a pool failure
    cannot leak the cache connection; the pool's exception still propagates.
    """
    try:
        # NOTE(review): the original comment says the pool is closed "if we
        # own it", but ownership is not tracked - any attached pool is closed,
        # including one passed in by the caller or the shared global pool.
        if self._http_pool:
            await self._http_pool.close()
    finally:
        self.cache.close()
    logger.info("股票服務已關閉")
def __enter__(self):
    """Enter a synchronous ``with`` block, handing back this service."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave a synchronous ``with`` block by running the async ``close()``.

    * Inside a running event loop, ``close()`` is scheduled as a task and a
      reference is kept on ``self._close_task`` so the task cannot be
      garbage-collected before it finishes.
    * Without a running loop, ``close()`` is driven to completion on a
      private, temporary event loop.
    * If that raises ``RuntimeError``, only the cache is closed, as before.

    Nothing truthy is returned, so exceptions from the ``with`` body propagate.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is not None:
        # Cannot block here: schedule the shutdown and keep a strong reference.
        self._close_task = running_loop.create_task(self.close())
        return

    try:
        # A private loop avoids asyncio.get_event_loop(), which is deprecated
        # when no loop is running, and leaves the global loop state alone.
        private_loop = asyncio.new_event_loop()
        try:
            private_loop.run_until_complete(self.close())
        finally:
            private_loop.close()
    except RuntimeError:
        # Fallback kept from the original: at least release the cache.
        self.cache.close()
async def __aenter__(self):
    """Enter an ``async with`` block, handing back this service instance."""
    return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave an ``async with`` block by awaiting ``close()``.

    The exception arguments are ignored and ``None`` is returned, so an
    exception raised inside the block keeps propagating.
    """
    await self.close()
    return None