# siliconflow.py
import os
from typing import List, Tuple
import openai
from memos.configs.reranker import SiliconFlowRerankerConfig
from memos.rerankers.base import BaseReranker
from memos.log import get_logger
# Logger obtained from the project's logging helper, keyed by this module's
# name; SiliconFlowReranker uses it for info, warning, debug and error messages.
logger = get_logger(__name__)
class SiliconFlowReranker(BaseReranker):
    """Reranker backed by the SiliconFlow ``/rerank`` HTTP endpoint.

    Both public methods are best-effort: any failure while talking to the
    service is logged and replaced by a neutral fallback (original order /
    zero scores) instead of propagating to the caller.
    """

    # Upper bound, in seconds, for one HTTP call. Without a timeout `requests`
    # waits indefinitely, which would block callers instead of degrading.
    _REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self, config: SiliconFlowRerankerConfig):
        """Initialize the SiliconFlow reranker with the given configuration.

        Args:
            config: Supplies ``api_key``, ``api_base``, ``model_name``,
                ``top_k`` and ``max_chunks_per_query``.
        """
        self.config = config
        # OpenAI-compatible client pointed at the SiliconFlow API. The rerank
        # calls below use `requests` directly because the OpenAI client does
        # not expose a rerank API; the attribute is kept for existing users.
        self.client = openai.OpenAI(
            api_key=config.api_key,
            base_url=config.api_base,
        )
        # Fall back to a default model when the configuration names none.
        self.model_name = config.model_name or "Qwen/Qwen3-Reranker-0.6B"
        self.top_k = config.top_k
        self.max_chunks_per_query = config.max_chunks_per_query
        logger.info(f"Initialized SiliconFlow reranker with model: {self.model_name}")

    def _truncate(self, documents: List[str]) -> List[str]:
        """Cap ``documents`` at ``max_chunks_per_query`` to stay within API limits."""
        if len(documents) > self.max_chunks_per_query:
            logger.warning(f"Too many documents ({len(documents)}), truncating to {self.max_chunks_per_query}")
            return documents[:self.max_chunks_per_query]
        return documents

    def _post_rerank(self, payload: dict) -> List[dict]:
        """POST ``payload`` to the rerank endpoint and return its ``results`` list.

        Raises:
            Exception: Whatever `requests` raises for transport errors, timeouts,
                non-2xx statuses or an undecodable JSON body. Callers catch it.
        """
        # Imported lazily, as the original code did, so that importing this
        # module does not itself require `requests`.
        import requests

        # Strip a trailing slash so a configured base of ".../v1/" does not
        # produce ".../v1//rerank".
        endpoint = f"{str(self.config.api_base).rstrip('/')}/rerank"
        response = requests.post(
            endpoint,
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=self._REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        return response.json().get("results") or []

    @staticmethod
    def _document_text(result: dict, documents: List[str]) -> str:
        """Return the text for one result entry, tolerating absent document payloads.

        Prefers the text echoed by the service; otherwise falls back to the
        submitted document addressed by the entry's ``index``; otherwise "".
        """
        document = result.get("document")
        if isinstance(document, dict):
            text = document.get("text")
            if text is not None:
                return text
        elif isinstance(document, str):
            return document
        index = result.get("index")
        if isinstance(index, int) and 0 <= index < len(documents):
            return documents[index]
        return ""

    def rerank(self, query: str, documents: List[str], top_k: int = None) -> List[Tuple[str, float]]:
        """Rerank documents by their relevance to the query using the SiliconFlow API.

        Args:
            query: The search query.
            documents: Document texts to rerank. Only the first
                ``max_chunks_per_query`` entries are submitted.
            top_k: Number of top results to return; defaults to the configured
                ``top_k`` when ``None``.

        Returns:
            ``(document, score)`` tuples sorted by score in descending order.
            If the request fails, the first ``top_k`` (possibly truncated)
            documents are returned in their original order with score ``0.0``.
        """
        if not documents:
            return []
        if top_k is None:
            top_k = self.top_k
        try:
            documents = self._truncate(documents)
            payload = {
                "model": self.model_name,
                "query": query,
                "documents": documents,
                "return_documents": True,
            }
            if top_k is not None:
                # The rerank endpoint names its result limit ``top_n``.
                payload["top_n"] = top_k

            reranked_docs = [
                (self._document_text(result, documents), result.get("relevance_score", 0.0))
                for result in self._post_rerank(payload)
            ]
            # Enforce the documented contract locally as well, in case the
            # service ignores the limit or returns entries unsorted. The sort
            # is stable, so an already ordered response is left untouched.
            reranked_docs.sort(key=lambda pair: pair[1], reverse=True)
            reranked_docs = reranked_docs[:top_k]
            logger.debug(f"Reranked {len(documents)} documents, returned top {len(reranked_docs)}")
            return reranked_docs
        except Exception as e:
            logger.error(f"Failed to rerank documents: {e}")
            # Degrade gracefully: original order with zero scores.
            return [(doc, 0.0) for doc in documents[:top_k]]

    def get_scores(self, query: str, documents: List[str]) -> List[float]:
        """Get relevance scores for documents without reordering them.

        Args:
            query: The search query.
            documents: Document texts to score. Only the first
                ``max_chunks_per_query`` entries are submitted and scored.

        Returns:
            One score per (possibly truncated) input document, aligned with the
            input order. Documents the service did not score, and every
            document when the request fails, get ``0.0``.
        """
        if not documents:
            return []
        try:
            documents = self._truncate(documents)
            results = self._post_rerank(
                {
                    "model": self.model_name,
                    "query": query,
                    "documents": documents,
                    "return_documents": False,  # scores only
                }
            )
            # The service orders results by relevance and tags each with the
            # ``index`` of the document it refers to, so scores must be placed
            # by that index rather than by response position.
            scores = [0.0] * len(documents)
            for position, result in enumerate(results):
                # Fall back to response position if an entry carries no index.
                index = result.get("index", position)
                if isinstance(index, int) and 0 <= index < len(scores):
                    scores[index] = result.get("relevance_score", 0.0)
            logger.debug(f"Generated scores for {len(documents)} documents")
            return scores
        except Exception as e:
            logger.error(f"Failed to get document scores: {e}")
            # Degrade gracefully: all-zero scores.
            return [0.0] * len(documents)