"""Tests for search engines."""
import pytest
from src.catalog import ToolCatalog, ToolDefinition, InputSchema
from src.search.bm25 import BM25Search
from src.search.regex import RegexSearch, RegexSearchError
from src.search.embeddings import EmbeddingsSearch
@pytest.fixture
def catalog_with_tools():
    """Build a ToolCatalog pre-populated with four sample tools.

    The tools span weather, forecast, email and database domains. Two of them
    (get_weather, get_forecast) share the "weather" vocabulary, so the search
    engines have both overlapping and distinct text to rank against.
    """
    catalog = ToolCatalog()

    def make(name, description, properties, required, tags):
        # Thin factory so each sample tool fits on a few lines below.
        return ToolDefinition(
            name=name,
            description=description,
            input_schema=InputSchema(properties=properties, required=required),
            tags=tags,
        )

    catalog.register_tools([
        make(
            "get_weather",
            "Get the current weather conditions for a specific location",
            {"location": {"type": "string"}},
            ["location"],
            ["weather"],
        ),
        make(
            "get_forecast",
            "Get weather forecast for the next several days",
            {"location": {"type": "string"}, "days": {"type": "integer"}},
            ["location"],
            ["weather", "forecast"],
        ),
        make(
            "send_email",
            "Send an email message to recipients",
            {"to": {"type": "array"}, "subject": {"type": "string"}},
            ["to", "subject"],
            ["email", "communication"],
        ),
        make(
            "query_database",
            "Execute a SQL query against the database",
            {"query": {"type": "string"}},
            ["query"],
            ["database", "sql"],
        ),
    ])
    return catalog
class TestBM25Search:
    """Tests for the BM25 keyword search engine."""

    def test_search_finds_matching_tools(self, catalog_with_tools):
        """A single-keyword query returns the tool whose text contains it."""
        search = BM25Search(catalog_with_tools)
        results = search.search("weather")
        assert len(results) >= 1
        assert any(t.name == "get_weather" for t in results)

    def test_search_ranks_by_relevance(self, catalog_with_tools):
        """Test that more relevant tools are ranked higher."""
        search = BM25Search(catalog_with_tools)
        results = search.search("weather forecast")
        names = [t.name for t in results]
        # In the fixture, get_forecast is the only tool mentioning both
        # "weather" and "forecast". "forecast" is also the rarer term, so
        # BM25 should put it first. Checking mere membership would not test
        # ranking at all.
        assert names, "expected at least one result for 'weather forecast'"
        assert names[0] == "get_forecast"

    def test_search_returns_empty_for_no_match(self, catalog_with_tools):
        """A query sharing no terms with any tool returns no results."""
        search = BM25Search(catalog_with_tools)
        results = search.search("xyznonexistent123")
        assert len(results) == 0

    def test_search_respects_top_k(self, catalog_with_tools):
        """top_k truncates the result list even when more tools match."""
        search = BM25Search(catalog_with_tools)
        # The query hits two tools (send_email and query_database). With
        # top_k=1 the limit has something to cut, so the exact-length check
        # actually exercises truncation.
        results = search.search("send email database query", top_k=1)
        assert len(results) == 1
class TestRegexSearch:
    """Tests for the regex search engine."""

    def test_search_matches_name(self, catalog_with_tools):
        """A pattern is matched against tool names."""
        search = RegexSearch(catalog_with_tools)
        results = search.search("get_.*")
        assert len(results) == 2
        names = [t.name for t in results]
        assert "get_weather" in names
        assert "get_forecast" in names

    def test_search_matches_description(self, catalog_with_tools):
        """A pattern is matched against tool descriptions."""
        search = RegexSearch(catalog_with_tools)
        results = search.search("SQL")
        assert len(results) == 1
        assert results[0].name == "query_database"

    def test_search_case_insensitive(self, catalog_with_tools):
        """The inline ``(?i)`` flag makes matching case-insensitive."""
        search = RegexSearch(catalog_with_tools)
        # "EXECUTE" appears only as "Execute" in query_database's description,
        # so a match is possible only if case folding is applied. A lowercase
        # pattern such as "(?i)sql" could be satisfied by the lowercase "sql"
        # tag (if tags are searched), so it would not prove the flag works.
        results = search.search("(?i)EXECUTE")
        assert len(results) == 1
        assert results[0].name == "query_database"

    def test_search_invalid_pattern_raises(self, catalog_with_tools):
        """A syntactically invalid regex raises RegexSearchError."""
        search = RegexSearch(catalog_with_tools)
        with pytest.raises(RegexSearchError) as exc_info:
            search.search("[invalid")
        assert exc_info.value.error_code == "invalid_pattern"

    def test_search_pattern_too_long_raises(self, catalog_with_tools):
        """An overly long pattern raises RegexSearchError."""
        search = RegexSearch(catalog_with_tools)
        # NOTE(review): assumes RegexSearch caps patterns at 200 characters;
        # confirm against the engine's limit.
        long_pattern = "a" * 201
        with pytest.raises(RegexSearchError) as exc_info:
            search.search(long_pattern)
        assert exc_info.value.error_code == "pattern_too_long"
class TestEmbeddingsSearch:
    """Tests for the embeddings-based (semantic) search engine."""

    def test_search_finds_semantic_matches(self, catalog_with_tools):
        """A query that never says "weather" still surfaces a weather tool."""
        engine = EmbeddingsSearch(catalog_with_tools)
        found = engine.search("check the temperature outside")
        # At least one weather-related tool should be among the hits.
        assert any("weather" in tool.name for tool in found)

    def test_search_with_scores_returns_scores(self, catalog_with_tools):
        """search_with_scores yields (tool, score) pairs with float scores in [0, 1]."""
        engine = EmbeddingsSearch(catalog_with_tools)
        scored = engine.search_with_scores("send a message")
        assert len(scored) > 0
        for _tool, similarity in scored:
            assert isinstance(similarity, float)
            assert 0 <= similarity <= 1

    def test_search_respects_top_k(self, catalog_with_tools):
        """No more than top_k tools are returned."""
        engine = EmbeddingsSearch(catalog_with_tools)
        limited = engine.search("tools for data", top_k=2)
        assert len(limited) <= 2