"""
参数验证工具
提供统一的参数验证功能。
支持 MCP 客户端将参数序列化为字符串的情况。
"""
from datetime import datetime
from typing import List, Optional, Union
import os
import json
import yaml
import ast
from .errors import InvalidParameterError
from .date_parser import DateParser
# ==================== 辅助函数:处理字符串序列化 ====================
def _parse_string_to_list(value: str) -> List[str]:
"""
将字符串解析为列表
支持格式:
- JSON 数组: '["zhihu", "weibo"]'
- Python 列表字符串: "['zhihu', 'weibo']"
- 逗号分隔: "zhihu, weibo" 或 "zhihu,weibo"
Args:
value: 字符串值
Returns:
解析后的列表
Raises:
InvalidParameterError: 解析失败
"""
value = value.strip()
if not value:
return []
# 尝试 JSON 解析: '["zhihu", "weibo"]'
try:
parsed = json.loads(value)
if isinstance(parsed, list):
return [str(item) for item in parsed]
# 如果解析结果不是列表,继续尝试其他方式
except json.JSONDecodeError:
pass
# 尝试 Python 字面量解析: "['zhihu', 'weibo']"
try:
parsed = ast.literal_eval(value)
if isinstance(parsed, list):
return [str(item) for item in parsed]
if isinstance(parsed, str):
# 单个字符串,包装成列表
return [parsed]
except (ValueError, SyntaxError):
pass
# 尝试逗号分隔: "zhihu, weibo" 或 "zhihu,weibo"
if ',' in value:
items = [item.strip() for item in value.split(',')]
return [item for item in items if item]
# 单个值
return [value]
def _parse_string_to_int(value: str, param_name: str = "参数") -> int:
"""
将字符串解析为整数
Args:
value: 字符串值
param_name: 参数名(用于错误消息)
Returns:
解析后的整数
Raises:
InvalidParameterError: 解析失败
"""
value = value.strip()
try:
# 尝试直接转换
return int(value)
except ValueError:
pass
# 尝试解析浮点数后取整
try:
return int(float(value))
except ValueError:
raise InvalidParameterError(
f"{param_name} 必须是整数,无法解析: {value}",
suggestion=f"请提供有效的整数值,如: 10, 50, 100"
)
def _parse_string_to_float(value: str, param_name: str = "参数") -> float:
"""
将字符串解析为浮点数
Args:
value: 字符串值
param_name: 参数名(用于错误消息)
Returns:
解析后的浮点数
Raises:
InvalidParameterError: 解析失败
"""
value = value.strip()
try:
return float(value)
except ValueError:
raise InvalidParameterError(
f"{param_name} 必须是数字,无法解析: {value}",
suggestion=f"请提供有效的数字值,如: 0.6, 3.0"
)
def _parse_string_to_bool(value: str) -> bool:
"""
将字符串解析为布尔值
Args:
value: 字符串值
Returns:
解析后的布尔值
"""
value = value.strip().lower()
if value in ('true', '1', 'yes', 'on'):
return True
elif value in ('false', '0', 'no', 'off', ''):
return False
else:
# 默认非空字符串为 True
return bool(value)
def get_supported_platforms() -> List[str]:
"""
从 config.yaml 动态获取支持的平台列表
Returns:
平台ID列表
Note:
- 读取失败时返回空列表,允许所有平台通过(降级策略)
- 平台列表来自 config/config.yaml 中的 platforms 配置
"""
try:
# 获取 config.yaml 路径(相对于当前文件)
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "..", "..", "config", "config.yaml")
config_path = os.path.normpath(config_path)
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
platforms_config = config.get('platforms', {})
# 处理嵌套结构:{enabled: bool, sources: [...]}
sources = platforms_config.get('sources', [])
return [p['id'] for p in sources if 'id' in p]
except Exception as e:
# 降级方案:返回空列表,允许所有平台
print(f"警告:无法加载平台配置 ({config_path}): {e}")
return []
def validate_platforms(platforms: Optional[Union[List[str], str]]) -> List[str]:
"""
验证平台列表
Args:
platforms: 平台ID列表或字符串,None表示使用 config.yaml 中配置的所有平台
支持多种格式:
- None: 使用默认平台
- ["zhihu", "weibo"]: JSON 数组
- '["zhihu", "weibo"]': JSON 数组字符串
- "['zhihu', 'weibo']": Python 列表字符串
- "zhihu, weibo": 逗号分隔字符串
- "zhihu": 单个平台字符串
Returns:
验证后的平台列表
Raises:
InvalidParameterError: 平台不支持
Note:
- platforms=None 时,返回 config.yaml 中配置的平台列表
- 会验证平台ID是否在 config.yaml 的 platforms 配置中
- 配置加载失败时,允许所有平台通过(降级策略)
"""
supported_platforms = get_supported_platforms()
if platforms is None:
# 返回配置文件中的平台列表(用户的默认配置)
return supported_platforms if supported_platforms else []
# 支持字符串形式的列表输入(某些 MCP 客户端会将 JSON 数组序列化为字符串)
if isinstance(platforms, str):
platforms = _parse_string_to_list(platforms)
if not platforms:
# 空字符串或解析后为空,使用默认平台
return supported_platforms if supported_platforms else []
if not isinstance(platforms, list):
raise InvalidParameterError("platforms 参数必须是列表类型")
if not platforms:
# 空列表时,返回配置文件中的平台列表
return supported_platforms if supported_platforms else []
# 如果配置加载失败(supported_platforms为空),允许所有平台通过
if not supported_platforms:
print("警告:平台配置未加载,跳过平台验证")
return platforms
# 验证每个平台是否在配置中
invalid_platforms = [p for p in platforms if p not in supported_platforms]
if invalid_platforms:
raise InvalidParameterError(
f"不支持的平台: {', '.join(invalid_platforms)}",
suggestion=f"支持的平台(来自config.yaml): {', '.join(supported_platforms)}"
)
return platforms
def validate_limit(limit: Optional[Union[int, str]], default: int = 20, max_limit: int = 1000) -> int:
"""
验证数量限制参数
Args:
limit: 限制数量(整数或字符串)
default: 默认值
max_limit: 最大限制
Returns:
验证后的限制值
Raises:
InvalidParameterError: 参数无效
"""
if limit is None:
return default
# 支持字符串形式的整数(某些 MCP 客户端会将数字序列化为字符串)
if isinstance(limit, str):
limit = _parse_string_to_int(limit, "limit")
if not isinstance(limit, int):
raise InvalidParameterError("limit 参数必须是整数类型")
if limit <= 0:
raise InvalidParameterError("limit 必须大于0")
if limit > max_limit:
raise InvalidParameterError(
f"limit 不能超过 {max_limit}",
suggestion=f"请使用分页或降低limit值"
)
return limit
def validate_date(date_str: str) -> datetime:
"""
验证日期格式
Args:
date_str: 日期字符串 (YYYY-MM-DD)
Returns:
datetime对象
Raises:
InvalidParameterError: 日期格式错误
"""
try:
return datetime.strptime(date_str, "%Y-%m-%d")
except ValueError:
raise InvalidParameterError(
f"日期格式错误: {date_str}",
suggestion="请使用 YYYY-MM-DD 格式,例如: 2025-10-11"
)
def normalize_date_range(date_range: Optional[Union[dict, str]]) -> Optional[Union[dict, str]]:
"""
规范化 date_range 参数
某些 MCP 客户端(特别是 HTTP 方式)会将 JSON 对象序列化为字符串传入。
此函数尝试将 JSON 字符串解析为 dict,如果不是 JSON 格式则保持原样。
Args:
date_range: 日期范围,可能是:
- dict: {"start": "2025-01-01", "end": "2025-01-07"}
- JSON 字符串: '{"start": "2025-01-01", "end": "2025-01-07"}'
- 普通字符串: "今天", "昨天", "2025-01-01"
- None
Returns:
规范化后的 date_range(dict 或普通字符串)
Examples:
>>> normalize_date_range('{"start":"2025-01-01","end":"2025-01-07"}')
{"start": "2025-01-01", "end": "2025-01-07"}
>>> normalize_date_range("今天")
"今天"
>>> normalize_date_range({"start": "2025-01-01", "end": "2025-01-07"})
{"start": "2025-01-01", "end": "2025-01-07"}
"""
if date_range is None:
return None
# 如果已经是 dict,直接返回
if isinstance(date_range, dict):
return date_range
# 如果是字符串,尝试解析为 JSON
if isinstance(date_range, str):
# 检查是否看起来像 JSON 对象
stripped = date_range.strip()
if stripped.startswith('{') and stripped.endswith('}'):
try:
parsed = json.loads(stripped)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass # 解析失败,当作普通字符串处理
return date_range
def validate_date_range(date_range: Optional[Union[dict, str]]) -> Optional[tuple]:
"""
验证日期范围
Args:
date_range: 日期范围字典或JSON字符串 {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
Returns:
(start_date, end_date) 元组,或 None
Raises:
InvalidParameterError: 日期范围无效
"""
if date_range is None:
return None
# 支持字符串形式的JSON输入(某些MCP客户端会将JSON对象序列化为字符串)
if isinstance(date_range, str):
try:
date_range = json.loads(date_range)
except json.JSONDecodeError as e:
raise InvalidParameterError(
f"date_range JSON 解析失败: {e}",
suggestion='请使用正确的JSON格式: {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}'
)
if not isinstance(date_range, dict):
raise InvalidParameterError(
"date_range 必须是字典类型或有效的JSON字符串",
suggestion='例如: {"start": "2025-10-01", "end": "2025-10-11"}'
)
start_str = date_range.get("start")
end_str = date_range.get("end")
if not start_str or not end_str:
raise InvalidParameterError(
"date_range 必须包含 start 和 end 字段",
suggestion='例如: {"start": "2025-10-01", "end": "2025-10-11"}'
)
start_date = validate_date(start_str)
end_date = validate_date(end_str)
if start_date > end_date:
raise InvalidParameterError(
"开始日期不能晚于结束日期",
suggestion=f"start: {start_str}, end: {end_str}"
)
# 检查日期是否在未来
today = datetime.now().date()
if start_date.date() > today or end_date.date() > today:
# 获取可用日期范围提示
try:
from ..services.data_service import DataService
data_service = DataService()
earliest, latest = data_service.get_available_date_range()
if earliest and latest:
available_range = f"{earliest.strftime('%Y-%m-%d')} 至 {latest.strftime('%Y-%m-%d')}"
else:
available_range = "无可用数据"
except Exception:
available_range = "未知(请检查 output 目录)"
future_dates = []
if start_date.date() > today:
future_dates.append(start_str)
if end_date.date() > today and end_str != start_str:
future_dates.append(end_str)
raise InvalidParameterError(
f"不允许查询未来日期: {', '.join(future_dates)}(当前日期: {today.strftime('%Y-%m-%d')})",
suggestion=f"当前可用数据范围: {available_range}"
)
return (start_date, end_date)
def validate_keyword(keyword: str) -> str:
"""
验证关键词
Args:
keyword: 搜索关键词
Returns:
处理后的关键词
Raises:
InvalidParameterError: 关键词无效
"""
if not keyword:
raise InvalidParameterError("keyword 不能为空")
if not isinstance(keyword, str):
raise InvalidParameterError("keyword 必须是字符串类型")
keyword = keyword.strip()
if not keyword:
raise InvalidParameterError("keyword 不能为空白字符")
if len(keyword) > 100:
raise InvalidParameterError(
"keyword 长度不能超过100个字符",
suggestion="请使用更简洁的关键词"
)
return keyword
def validate_top_n(top_n: Optional[Union[int, str]], default: int = 10) -> int:
"""
验证TOP N参数
Args:
top_n: TOP N数量(整数或字符串)
default: 默认值
Returns:
验证后的值
Raises:
InvalidParameterError: 参数无效
"""
return validate_limit(top_n, default=default, max_limit=100)
def validate_mode(mode: Optional[str], valid_modes: List[str], default: str) -> str:
"""
验证模式参数
Args:
mode: 模式字符串
valid_modes: 有效模式列表
default: 默认模式
Returns:
验证后的模式
Raises:
InvalidParameterError: 模式无效
"""
if mode is None:
return default
if not isinstance(mode, str):
raise InvalidParameterError("mode 必须是字符串类型")
if mode not in valid_modes:
raise InvalidParameterError(
f"无效的模式: {mode}",
suggestion=f"支持的模式: {', '.join(valid_modes)}"
)
return mode
def validate_config_section(section: Optional[str]) -> str:
"""
验证配置节参数
Args:
section: 配置节名称
Returns:
验证后的配置节
Raises:
InvalidParameterError: 配置节无效
"""
valid_sections = ["all", "crawler", "push", "keywords", "weights"]
return validate_mode(section, valid_sections, "all")
def validate_threshold(
threshold: Optional[Union[float, int, str]],
default: float = 0.6,
min_value: float = 0.0,
max_value: float = 1.0,
param_name: str = "threshold"
) -> float:
"""
验证阈值参数(浮点数)
Args:
threshold: 阈值(浮点数、整数或字符串)
default: 默认值
min_value: 最小值
max_value: 最大值
param_name: 参数名(用于错误消息)
Returns:
验证后的阈值
Raises:
InvalidParameterError: 参数无效
"""
if threshold is None:
return default
# 支持字符串形式的数字(某些 MCP 客户端会将数字序列化为字符串)
if isinstance(threshold, str):
threshold = _parse_string_to_float(threshold, param_name)
# 整数转浮点数
if isinstance(threshold, int):
threshold = float(threshold)
if not isinstance(threshold, float):
raise InvalidParameterError(
f"{param_name} 必须是数字类型",
suggestion=f"请提供 {min_value} 到 {max_value} 之间的数字"
)
if threshold < min_value or threshold > max_value:
raise InvalidParameterError(
f"{param_name} 必须在 {min_value} 到 {max_value} 之间,当前值: {threshold}",
suggestion=f"推荐值: {default}"
)
return threshold
def validate_date_query(
date_query: str,
allow_future: bool = False,
max_days_ago: int = 365
) -> datetime:
"""
验证并解析日期查询字符串
Args:
date_query: 日期查询字符串
allow_future: 是否允许未来日期
max_days_ago: 允许查询的最大天数
Returns:
解析后的datetime对象
Raises:
InvalidParameterError: 日期查询无效
Examples:
>>> validate_date_query("昨天")
datetime(2025, 10, 10)
>>> validate_date_query("2025-10-10")
datetime(2025, 10, 10)
"""
if not date_query:
raise InvalidParameterError(
"日期查询字符串不能为空",
suggestion="请提供日期查询,如:今天、昨天、2025-10-10"
)
# 使用DateParser解析日期
parsed_date = DateParser.parse_date_query(date_query)
# 验证日期不在未来
if not allow_future:
DateParser.validate_date_not_future(parsed_date)
# 验证日期不太久远
DateParser.validate_date_not_too_old(parsed_date, max_days=max_days_ago)
return parsed_date