"""Company search tool for Birre."""
import asyncio
import logging
from typing import Dict, List, Optional, Any

from fastmcp import FastMCP, Context
import bitsight

try:
    from ..config import get_config
except ImportError:
    from birre.config import get_config
logger = logging.getLogger(__name__)
def _summarize_company(result: Dict[str, Any]) -> Dict[str, Any]:
    """Project one raw BitSight search result onto the fields this tool returns.

    Assumes ``result`` is a mapping (only ``.get`` is used on it).
    """
    return {
        "guid": result.get("guid"),
        "name": result.get("name"),
        # Fall back to primary_domain when "domain" is missing or empty.
        "domain": result.get("domain") or result.get("primary_domain"),
        "industry": result.get("industry"),
    }


async def company_search_impl(
    ctx: Context, name: Optional[str] = None, domain: Optional[str] = None
) -> Dict[str, Any]:
    """
    Search for companies in BitSight by name or domain.

    Args:
        ctx: FastMCP request context; used to send info/error messages back
            to the client.
        name: Company name to search for (optional)
        domain: Company domain to search for (optional)

    Returns:
        Dictionary containing:
        - companies: List of matching companies with guid, name, domain, industry
        - count: Number of companies returned (results without a GUID are dropped)
        - search_term: The term used for searching (None if none was given)
        - error: Present only when validation or the search failed

    Note:
        At least one of name or domain must be provided.
        If both are provided, domain takes precedence.
        Surrounding whitespace is ignored; a whitespace-only value counts as
        not provided.
    """
    # Normalise inputs before the try block so every return path (including
    # the error path) reports the same search term with the same precedence.
    name = name.strip() if name else None
    domain = domain.strip() if domain else None
    # Domain takes precedence over name when both are given.
    search_term = domain or name

    if not search_term:
        return {
            "error": "At least one of 'name' or 'domain' must be provided",
            "companies": [],
            "count": 0,
            "search_term": None,
        }

    try:
        # Initialize BitSight API client
        companies_api = bitsight.Companies()
        await ctx.info(f"Searching BitSight for: {search_term}")

        # get_company_search is a synchronous call (it is not awaited), and it
        # presumably performs HTTP I/O. Run it in the default thread-pool
        # executor so it does not block the event loop for other requests.
        # Note: the bitsight library primarily uses domain-based search; a
        # name-only query is passed through the same call.
        loop = asyncio.get_running_loop()
        search_results = await loop.run_in_executor(
            None, companies_api.get_company_search, search_term
        )

        if not search_results:
            await ctx.info(f"No companies found for search term: {search_term}")
            return {"companies": [], "count": 0, "search_term": search_term}

        # Only include companies with a valid GUID.
        companies = [
            company
            for company in map(_summarize_company, search_results)
            if company["guid"]
        ]

        await ctx.info(f"Found {len(companies)} companies for: {search_term}")
        return {
            "companies": companies,
            "count": len(companies),
            "search_term": search_term,
        }
    except Exception as e:
        error_msg = f"Company search failed: {str(e)}"
        # Record the traceback locally first, so it is not lost if notifying
        # the client below also fails.
        logger.exception(error_msg)
        await ctx.error(error_msg)
        return {
            "error": error_msg,
            "companies": [],
            "count": 0,
            "search_term": search_term,
        }
def register_company_search_tool(server: FastMCP) -> None:
    """Register the ``company_search`` tool with the FastMCP server.

    The registered tool is a thin async wrapper that delegates to
    ``company_search_impl``. FastMCP derives the tool description from the
    wrapper's docstring, so that docstring spells out the argument rules
    for clients.

    Args:
        server: The FastMCP server instance to attach the tool to.
    """

    @server.tool
    async def company_search(
        ctx: Context, name: Optional[str] = None, domain: Optional[str] = None
    ) -> Dict[str, Any]:
        """Search for companies in BitSight by name or domain.

        At least one of ``name`` or ``domain`` must be provided; if both are
        given, ``domain`` takes precedence. Returns a dict with ``companies``
        (each with ``guid``, ``name``, ``domain``, ``industry``), ``count``
        and ``search_term``, plus an ``error`` message if the search could
        not be performed.
        """
        return await company_search_impl(ctx, name, domain)