# reranker.py (1.84 kB)
from typing import Any, ClassVar

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SerializeAsAny,
    model_validator,
)
class BaseRerankerConfig(BaseModel):
    """Settings shared by every reranker backend configuration.

    Keyword arguments that are not declared as fields are kept on the
    instance (``extra="allow"``), and input values are validated in
    pydantic's lax mode (``strict=False``).
    """

    model_config = ConfigDict(strict=False, extra="allow")

    # Required: no default is supplied.
    model_name_or_path: str = Field(
        description="Model name or path for the reranker",
    )
    top_k: int = Field(
        10,
        description="Number of top results to return after reranking",
    )
class SiliconFlowRerankerConfig(BaseRerankerConfig):
    """Configuration class for SiliconFlow API reranker.

    Inherits ``model_name_or_path`` (required) and ``top_k`` (default 10)
    from ``BaseRerankerConfig``; ``api_key`` is the only additional required
    field.
    """

    # repr=False keeps the secret out of repr()/str() of this model (and of
    # any model that embeds it), so logging or printing the config does not
    # leak the key. The value is still validated and included in model_dump().
    api_key: str = Field(..., description="SiliconFlow API key", repr=False)
    api_base: str = Field(
        default="https://api.siliconflow.cn/v1",
        description="Base URL for SiliconFlow API",
    )
    # NOTE(review): this overlaps with the inherited, still-required
    # ``model_name_or_path``; confirm which of the two the SiliconFlow client
    # actually sends, and consider deriving one from the other.
    model_name: str = Field(
        default="Qwen/Qwen3-Reranker-0.6B",
        description="Reranker model name",
    )
    max_chunks_per_query: int = Field(
        default=100,
        description="Maximum number of chunks to rerank per query",
    )
class RerankerConfigFactory(BaseModel):
    """Factory for creating reranker configurations.

    ``backend`` selects a config class from ``backend_to_class``. A ``config``
    given as a dict is validated against that class; any other value (e.g. a
    ready-made config instance) is passed through unchanged.

    The resolution runs in a ``mode="before"`` model validator instead of an
    overridden ``__init__``. This matters because ``model_validate`` and
    ``model_validate_json`` do not call ``__init__``. With the override, they
    skipped the backend check and built a bare ``BaseRerankerConfig``.

    Raises:
        pydantic.ValidationError: a ``ValueError`` subclass, so existing
            ``except ValueError`` handlers still work. Raised when ``backend``
            is not registered or the backend config fails validation.
    """

    backend: str = Field(..., description="Reranker backend type")
    # SerializeAsAny makes model_dump()/model_dump_json() emit the fields of
    # the concrete subclass (api_key, api_base, ...) rather than only the
    # fields declared on BaseRerankerConfig, which is pydantic v2's default.
    config: SerializeAsAny[BaseRerankerConfig] = Field(
        ..., description="Reranker configuration"
    )

    # Maps a backend name to the config model used to validate its settings.
    backend_to_class: ClassVar[dict[str, Any]] = {
        "siliconflow": SiliconFlowRerankerConfig,
    }

    @model_validator(mode="before")
    @classmethod
    def _resolve_backend_config(cls, data: Any) -> Any:
        """Check ``backend`` and build the backend-specific config object."""
        if not isinstance(data, dict):
            # Not raw keyword input (e.g. an existing instance); let pydantic
            # handle it with its normal rules.
            return data
        backend = data.get("backend")
        config_data = data.get("config", {})
        # The isinstance guard turns an unhashable backend (which would raise
        # TypeError on the dict lookup) into the same invalid-backend error.
        if not isinstance(backend, str) or backend not in cls.backend_to_class:
            raise ValueError(f"Invalid backend: {backend}")
        config_class = cls.backend_to_class[backend]
        if isinstance(config_data, dict):
            config_instance = config_class(**config_data)
        else:
            config_instance = config_data
        return {"backend": backend, "config": config_instance}