"""Semantic embeddings-based search for tool discovery."""
import logging
from typing import TYPE_CHECKING
import numpy as np
from ..config import config
if TYPE_CHECKING:
from ..catalog import ToolCatalog, ToolDefinition
logger = logging.getLogger(__name__)

# Lazy load sentence-transformers to reduce startup time
_model = None


def _get_model():
    """Lazily load the sentence transformer model.

    The model is constructed on first call and cached in the module-level
    ``_model`` global; subsequent calls return the cached instance.
    """
    global _model
    if _model is not None:
        return _model
    # Import deferred so merely importing this module stays cheap.
    from sentence_transformers import SentenceTransformer
    logger.info(f"Loading embedding model: {config.EMBEDDING_MODEL}")
    _model = SentenceTransformer(config.EMBEDDING_MODEL)
    logger.info("Embedding model loaded")
    return _model
class EmbeddingsSearch:
    """Semantic search using sentence embeddings.

    Uses sentence-transformers to encode tools and queries into
    dense vectors, then ranks by cosine similarity.
    """

    def __init__(self, catalog: "ToolCatalog"):
        """Initialize embeddings search engine.

        Args:
            catalog: The tool catalog to search
        """
        self.catalog = catalog
        # (n_tools, dim) matrix of L2-normalized tool embeddings, or None
        # when the index is empty / not yet built.
        self._embeddings: np.ndarray | None = None
        # Row i of self._embeddings corresponds to self._tool_names[i].
        self._tool_names: list[str] = []
        self._dirty = True
        # Register for catalog updates so the index is rebuilt lazily on
        # the next search after any catalog change.
        self.catalog.on_update(lambda _: self._mark_dirty())

    def _mark_dirty(self) -> None:
        """Mark the embeddings as needing recomputation."""
        self._dirty = True

    def rebuild_index(self) -> None:
        """Rebuild embeddings for all tools in the catalog."""
        tools = self.catalog.list_tools()
        if not tools:
            self._embeddings = None
            self._tool_names = []
            self._dirty = False
            logger.info("Embeddings index cleared (no tools)")
            return
        model = _get_model()
        # Get searchable text for each tool
        self._tool_names = [tool.name for tool in tools]
        texts = [tool.to_searchable_text() for tool in tools]
        # Pre-normalizing lets cosine similarity reduce to a dot product
        # at query time.
        self._embeddings = model.encode(
            texts,
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        self._dirty = False
        logger.info("Embeddings index rebuilt with %d tools", len(tools))

    def _ensure_index(self) -> None:
        """Ensure embeddings are up-to-date."""
        if self._dirty or self._embeddings is None:
            self.rebuild_index()

    def _rank(self, query: str, top_k: int) -> list[tuple[str, float]]:
        """Rank indexed tools against a query (shared core of all searches).

        Args:
            query: Natural language search query
            top_k: Maximum number of results; values <= 0 yield no results

        Returns:
            (tool_name, similarity) pairs, sorted by similarity descending,
            limited to top_k, and filtered to strictly positive similarity.
        """
        self._ensure_index()
        if self._embeddings is None or not self._tool_names:
            return []
        # Guard top_k <= 0 explicitly: with top_k == 0 the slice below
        # would be [-0:], which selects the WHOLE array, not an empty one.
        if top_k <= 0:
            return []
        # Encode query (normalized, like the index rows).
        query_embedding = _get_model().encode(
            query,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        # Cosine similarity is a dot product of normalized vectors.
        similarities = np.dot(self._embeddings, query_embedding)
        # Top-k indices sorted by similarity (descending).
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        return [
            (self._tool_names[idx], float(similarities[idx]))
            for idx in top_indices
            if similarities[idx] > 0
        ]

    def search(self, query: str, top_k: int | None = None) -> list["ToolDefinition"]:
        """Search for tools semantically similar to the query.

        Args:
            query: Natural language search query
            top_k: Maximum number of results (defaults to config)

        Returns:
            List of matching tools, ranked by semantic similarity
        """
        if top_k is None:
            top_k = config.MAX_SEARCH_RESULTS
        matching_names = [name for name, _ in self._rank(query, top_k)]
        # Retrieve tool definitions in ranked order.
        return self.catalog.get_tools_by_names(matching_names)

    def search_names(self, query: str, top_k: int | None = None) -> list[str]:
        """Search and return only tool names.

        Args:
            query: Natural language search query
            top_k: Maximum number of results

        Returns:
            List of matching tool names
        """
        tools = self.search(query, top_k)
        return [tool.name for tool in tools]

    def search_with_scores(
        self,
        query: str,
        top_k: int | None = None,
    ) -> list[tuple["ToolDefinition", float]]:
        """Search and return tools with similarity scores.

        Args:
            query: Natural language search query
            top_k: Maximum number of results

        Returns:
            List of (tool, similarity_score) tuples
        """
        if top_k is None:
            top_k = config.MAX_SEARCH_RESULTS
        results = []
        for name, score in self._rank(query, top_k):
            tool = self.catalog.get_tool(name)
            # Skip names the catalog can no longer resolve.
            if tool:
                results.append((tool, score))
        return results