"""Reranking utilities for the Crawl4AI MCP server."""
from typing import Any
from sentence_transformers import CrossEncoder
from src.core.logging import logger
def rerank_results(
    model: CrossEncoder,
    query: str,
    results: list[dict[str, Any]],
    content_key: str = "content",
) -> list[dict[str, Any]]:
    """
    Rerank search results using a cross-encoder model.

    Each result's text (``result[content_key]``) is paired with ``query`` and
    scored in one ``model.predict`` call. On success, every result dict gets a
    ``"rerank_score"`` float (the dicts are mutated in place) and a new list
    sorted by that score, highest first, is returned. Results with equal
    scores keep their input order because ``sorted`` is stable.

    Reranking is best-effort: if anything goes wrong, the error is logged and
    ``results`` is returned unchanged. No result dict is modified unless a
    score was obtained for every result.

    Args:
        model: The cross-encoder model to use for reranking. If falsy,
            reranking is skipped.
        query: The search query.
        results: List of search results. Returned as-is when empty.
        content_key: The key in each result dict that contains the text
            content. A missing key or a ``None`` value is scored as ``""``.

    Returns:
        The results sorted by descending ``"rerank_score"``, or the original
        ``results`` list if reranking was skipped or failed.
    """
    if not model or not results:
        return results

    try:
        # A None content value would otherwise be handed to the model and
        # could fail the whole batch; score it as empty text instead.
        texts = [result.get(content_key) for result in results]
        pairs = [[query, "" if text is None else text] for text in texts]

        # Materialise and validate every score BEFORE touching the caller's
        # dicts, so a short/invalid model output cannot leave them
        # half-annotated.
        scores = [float(score) for score in model.predict(pairs)]
        if len(scores) != len(results):
            raise ValueError(
                f"expected {len(results)} scores from the model, got {len(scores)}"
            )

        for result, score in zip(results, scores):
            result["rerank_score"] = score
        return sorted(results, key=lambda r: r["rerank_score"], reverse=True)
    except Exception as e:
        # Deliberate best-effort: fall back to the original ordering.
        logger.error(f"Error during reranking: {e}")
        return results