"""Unit tests for company search tool."""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from fastmcp import Context
from birre.tools.company_search import register_company_search_tool, company_search_impl
from birre.server import create_server
class TestCompanySearchTool:
    """Test cases for company search functionality.

    Every test calls ``company_search_impl`` directly with a mocked FastMCP
    ``Context``. Tests that reach the BitSight API request the
    ``mock_companies`` fixture, which patches ``bitsight`` inside
    ``birre.tools.company_search`` for the whole duration of the test.
    """

    @pytest.fixture
    def server(self):
        """Create a test server instance.

        NOTE(review): no test below reads this value; requesting it only
        checks that ``create_server()`` succeeds. Drop it from the test
        signatures if that check is not intended.
        """
        return create_server()

    @pytest.fixture
    def mock_context(self):
        """Create a mock context whose logging methods can be awaited and asserted."""
        context = Mock(spec=Context)
        # Explicit AsyncMocks so the implementation can ``await`` the logging
        # calls and the tests can assert on their arguments.
        context.info = AsyncMock()
        context.error = AsyncMock()
        context.debug = AsyncMock()
        context.warning = AsyncMock()
        return context

    @pytest.fixture
    def mock_companies(self):
        """Patch the BitSight module used by the tool and yield the companies API mock.

        ``bitsight.Companies()`` inside ``birre.tools.company_search`` returns
        the yielded mock. Tests configure ``get_company_search`` on it (via
        ``return_value`` or ``side_effect``) and assert on its calls. The patch
        stays active until the test finishes, so assertions may run after the
        call under test.
        """
        companies = Mock()
        with patch("birre.tools.company_search.bitsight") as mock_bitsight:
            mock_bitsight.Companies.return_value = companies
            yield companies

    @pytest.mark.asyncio
    async def test_company_search_with_domain_success(
        self, server, mock_context, mock_companies
    ):
        """Test successful company search with domain parameter."""
        mock_companies.get_company_search.return_value = [
            {
                "guid": "test-guid-1",
                "name": "Example Corp",
                "domain": "example.com",
                "industry": "Technology",
            },
            {
                "guid": "test-guid-2",
                "name": "Example Inc",
                "primary_domain": "example.org",
                "industry": "Finance",
            },
        ]

        result = await company_search_impl(mock_context, domain="example.com")

        assert result["count"] == 2
        assert len(result["companies"]) == 2
        assert result["search_term"] == "example.com"

        first_company = result["companies"][0]
        assert first_company["guid"] == "test-guid-1"
        assert first_company["name"] == "Example Corp"
        assert first_company["domain"] == "example.com"
        assert first_company["industry"] == "Technology"

        # The second record has no "domain" key; the tool should fall back to
        # "primary_domain".
        second_company = result["companies"][1]
        assert second_company["guid"] == "test-guid-2"
        assert second_company["domain"] == "example.org"

        mock_companies.get_company_search.assert_called_once_with("example.com")
        mock_context.info.assert_any_call("Searching BitSight for: example.com")
        mock_context.info.assert_any_call("Found 2 companies for: example.com")

    @pytest.mark.asyncio
    async def test_company_search_with_name_success(
        self, server, mock_context, mock_companies
    ):
        """Test successful company search with name parameter."""
        mock_companies.get_company_search.return_value = [
            {
                "guid": "test-guid-1",
                "name": "Test Company",
                "domain": "test.com",
                "industry": "Healthcare",
            }
        ]

        result = await company_search_impl(mock_context, name="Test Company")

        assert result["count"] == 1
        assert result["search_term"] == "Test Company"
        assert result["companies"][0]["name"] == "Test Company"
        mock_companies.get_company_search.assert_called_once_with("Test Company")

    @pytest.mark.asyncio
    async def test_company_search_domain_precedence(
        self, server, mock_context, mock_companies
    ):
        """Test that domain parameter takes precedence over name."""
        mock_companies.get_company_search.return_value = []

        result = await company_search_impl(
            mock_context, name="Company Name", domain="priority.com"
        )

        # Only the domain should be used as the search term.
        assert result["search_term"] == "priority.com"
        mock_companies.get_company_search.assert_called_once_with("priority.com")

    @pytest.mark.asyncio
    async def test_company_search_no_parameters(
        self, server, mock_context, mock_companies
    ):
        """Test error when neither name nor domain is provided."""
        result = await company_search_impl(mock_context)

        assert "error" in result
        assert (
            "At least one of 'name' or 'domain' must be provided"
            in result["error"]
        )
        assert result["count"] == 0
        assert result["companies"] == []
        # Validation failure must short-circuit before any API request.
        mock_companies.get_company_search.assert_not_called()

    @pytest.mark.asyncio
    async def test_company_search_no_results(
        self, server, mock_context, mock_companies
    ):
        """Test handling when no companies are found."""
        mock_companies.get_company_search.return_value = []

        result = await company_search_impl(mock_context, domain="nonexistent.com")

        assert result["count"] == 0
        assert result["companies"] == []
        assert result["search_term"] == "nonexistent.com"
        mock_context.info.assert_any_call(
            "No companies found for search term: nonexistent.com"
        )

    @pytest.mark.asyncio
    async def test_company_search_filter_invalid_guids(
        self, server, mock_context, mock_companies
    ):
        """Test that companies without valid GUIDs are filtered out."""
        mock_companies.get_company_search.return_value = [
            {
                "guid": "valid-guid",
                "name": "Valid Company",
                "domain": "valid.com",
                "industry": "Technology",
            },
            {
                "guid": None,  # Invalid GUID
                "name": "Invalid Company",
                "domain": "invalid.com",
                "industry": "Technology",
            },
            {
                # Missing GUID entirely
                "name": "No GUID Company",
                "domain": "noguid.com",
                "industry": "Technology",
            },
        ]

        result = await company_search_impl(mock_context, domain="test.com")

        # Only the company with a valid GUID should be included.
        assert result["count"] == 1
        assert result["companies"][0]["guid"] == "valid-guid"

    @pytest.mark.asyncio
    async def test_company_search_api_exception(
        self, server, mock_context, mock_companies
    ):
        """Test handling of BitSight API exceptions."""
        mock_companies.get_company_search.side_effect = Exception("API Error")

        result = await company_search_impl(mock_context, domain="error.com")

        assert "error" in result
        assert "Company search failed: API Error" in result["error"]
        assert result["count"] == 0
        assert result["companies"] == []
        # The failure must also be reported through the MCP context.
        mock_context.error.assert_called_once_with(
            "Company search failed: API Error"
        )