"""BM25 keyword-based search for tool discovery."""
import re
import logging
from typing import TYPE_CHECKING
from rank_bm25 import BM25Okapi
from ..config import config
if TYPE_CHECKING:
from ..catalog import ToolCatalog, ToolDefinition
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class BM25Search:
"""BM25-based keyword search for tools.
Uses the BM25Okapi algorithm to rank tools based on term frequency
and inverse document frequency. Optimized for short queries typical
of tool search.
"""
def __init__(self, catalog: "ToolCatalog"):
"""Initialize BM25 search engine.
Args:
catalog: The tool catalog to search
"""
self.catalog = catalog
self._index: BM25Okapi | None = None
self._tool_names: list[str] = []
self._corpus: list[list[str]] = []
# Register for catalog updates
self.catalog.on_update(lambda _: self._mark_dirty())
self._dirty = True
def _mark_dirty(self) -> None:
"""Mark the index as needing rebuild."""
self._dirty = True
def _tokenize(self, text: str) -> list[str]:
"""Tokenize text for indexing and search.
Args:
text: Text to tokenize
Returns:
List of lowercase tokens
"""
# Convert to lowercase and split on non-alphanumeric characters
tokens = re.findall(r"\w+", text.lower())
return tokens
def rebuild_index(self) -> None:
"""Rebuild the BM25 index from the catalog."""
tools = self.catalog.list_tools()
if not tools:
self._index = None
self._tool_names = []
self._corpus = []
self._dirty = False
logger.info("BM25 index cleared (no tools)")
return
# Build corpus
self._tool_names = [tool.name for tool in tools]
self._corpus = [self._tokenize(tool.to_searchable_text()) for tool in tools]
# Create BM25 index with tuned parameters
self._index = BM25Okapi(
self._corpus,
k1=config.BM25_K1, # Term saturation parameter
b=config.BM25_B, # Document length normalization
)
self._dirty = False
logger.info(f"BM25 index rebuilt with {len(tools)} tools")
def _ensure_index(self) -> None:
"""Ensure the index is up-to-date."""
if self._dirty or self._index is None:
self.rebuild_index()
def search(self, query: str, top_k: int | None = None) -> list["ToolDefinition"]:
"""Search for tools matching the query.
Args:
query: Natural language search query
top_k: Maximum number of results (defaults to config)
Returns:
List of matching tools, ranked by relevance
"""
if top_k is None:
top_k = config.MAX_SEARCH_RESULTS
self._ensure_index()
if not self._index or not self._tool_names:
return []
# Tokenize query
query_tokens = self._tokenize(query)
if not query_tokens:
return []
# Get BM25 scores
scores = self._index.get_scores(query_tokens)
# Get top-k indices sorted by score (descending)
scored_indices = sorted(
enumerate(scores),
key=lambda x: x[1],
reverse=True
)[:top_k]
# Filter out zero-score results and get tool names
matching_names = [
self._tool_names[idx]
for idx, score in scored_indices
if score > 0
]
# Retrieve tool definitions
return self.catalog.get_tools_by_names(matching_names)
def search_names(self, query: str, top_k: int | None = None) -> list[str]:
"""Search and return only tool names.
Args:
query: Natural language search query
top_k: Maximum number of results
Returns:
List of matching tool names
"""
tools = self.search(query, top_k)
return [tool.name for tool in tools]