from __future__ import annotations
import logging
from functools import partial
from typing import Any, Callable, Dict, Iterable, Mapping, Optional
from fastmcp import FastMCP
from src.settings import (
DEFAULT_MAX_FINDINGS,
DEFAULT_RISK_VECTOR_FILTER,
RuntimeSettings,
)
from src.logging import BoundLogger
# Stdlib logger used by _schedule_tool_disablement for registry diagnostics.
_tool_logger = logging.getLogger(name="birre.tools")
from .apis import (
call_v1_openapi_tool,
call_v2_openapi_tool,
create_v1_api_server,
create_v2_api_server,
)
from .business import (
register_company_rating_tool,
register_company_search_tool,
)
# Server instructions keyed by persona/context name. create_birre_server
# falls back to the "standard" entry for any context not listed here.
INSTRUCTIONS_MAP: Dict[str, str] = dict(
    standard=(
        "BitSight rating retriever. Use `company_search` to locate a company, "
        "then call `get_company_rating` with the chosen GUID."
    ),
    risk_manager=(
        "Risk manager persona. Start with `company_search_interactive` to review matches, "
        "call `manage_subscriptions` to adjust coverage, and use `request_company` when an entity is missing."
    ),
)
def _require_api_key(settings: RuntimeSettings) -> str:
    """Return ``settings.api_key`` converted to ``str``.

    Raises:
        ValueError: If the key is falsy (e.g. ``None`` or ``""``).
    """
    api_key = settings.api_key
    if api_key:
        return str(api_key)
    raise ValueError("Resolved settings must include a non-empty 'api_key'")
def _resolve_active_context(settings: RuntimeSettings) -> str:
    """Return the configured context name as ``str``, or ``"standard"`` when unset/empty."""
    context = settings.context
    return str(context) if context else "standard"
def _resolve_risk_vector_filter(settings: RuntimeSettings) -> str:
    """Return the configured risk-vector filter as ``str``.

    Falls back to ``DEFAULT_RISK_VECTOR_FILTER`` when the setting is unset or empty.
    """
    configured = settings.risk_vector_filter
    if not configured:
        configured = DEFAULT_RISK_VECTOR_FILTER
    return str(configured)
def _resolve_max_findings(settings: RuntimeSettings) -> int:
    """Return the configured findings limit, or ``DEFAULT_MAX_FINDINGS``.

    Only a positive ``int`` is accepted. ``bool`` is rejected explicitly:
    it is a subclass of ``int``, so ``True`` would otherwise be taken as a
    limit of ``1``.
    """
    value = settings.max_findings
    if isinstance(value, int) and not isinstance(value, bool) and value > 0:
        return value
    return DEFAULT_MAX_FINDINGS
def _resolve_tls_verification(
    settings: RuntimeSettings, logger: BoundLogger
) -> bool | str:
    """Derive the ``verify`` option passed to the API servers.

    Returns:
        ``False`` when ``allow_insecure_tls`` is truthy (a warning is logged,
        and this takes precedence over any CA bundle). Otherwise, the CA
        bundle path as ``str`` when ``ca_bundle_path`` is set (logged at info
        level). Otherwise ``True``.
    """
    insecure = bool(settings.allow_insecure_tls)
    bundle = settings.ca_bundle_path

    if insecure:
        logger.warning("tls.verify.disabled", reason="allow_insecure_tls flag set")
        return False

    if not bundle:
        return True

    bundle_path = str(bundle)
    logger.info("tls.verify.custom_ca_bundle", ca_bundle=bundle_path)
    return bundle_path
def _maybe_create_v2_api_server(
    active_context: str,
    api_key: str,
    verify_option: bool | str,
    *,
    base_url: Optional[str] = None,
) -> Optional[FastMCP]:
    """Create the v2 API server, but only for the ``risk_manager`` context.

    Returns ``None`` for every other context. ``base_url`` is forwarded to
    ``create_v2_api_server`` only when it is not ``None``.
    """
    if active_context != "risk_manager":
        return None

    options: dict[str, Any] = {"verify": verify_option}
    if base_url is not None:
        options["base_url"] = base_url
    return create_v2_api_server(api_key, **options)
def _schedule_tool_disablement(api_server: FastMCP, keep: Iterable[str]) -> None:
    """Disable every generated FastMCP tool on ``api_server`` not named in ``keep``.

    The work happens synchronously. FastMCP exposes no synchronous API for
    pruning tools, so this reads the manager's private ``_tool_manager._tools``
    registry and calls ``disable()`` on each tool that is not kept.

    If the manager is missing, or the registry is not a ``dict``, the function
    emits a debug message and returns without changes. It does not resort to
    ad-hoc asyncio usage. A tool whose ``disable()`` raises is logged at debug
    level and skipped.
    """
    manager = getattr(api_server, "_tool_manager", None)
    if manager is None:
        _tool_logger.debug("tool_manager.missing server=%r", api_server)
        return

    registry = getattr(manager, "_tools", None)
    if not isinstance(registry, dict):
        _tool_logger.debug(
            "tool_registry.unexpected_shape registry_type=%s",
            type(registry).__name__,
        )
        return

    retained = set(keep)
    for tool_name, tool in registry.items():
        if tool_name not in retained:
            try:
                tool.disable()
            except Exception as exc:  # pragma: no cover - defensive
                _tool_logger.debug(
                    "tool.disable_failed tool=%s error=%s",
                    tool_name,
                    exc,
                )
def _configure_risk_manager_tools(
    business_server: FastMCP,
    settings: RuntimeSettings,
    call_v1_tool: Callable[..., Any],
    logger: BoundLogger,
    resolved_api_key: str,
    verify_option: bool | str,
    max_findings: int,
    *,
    v2_base_url: Optional[str] = None,
) -> None:
    """Register the risk-manager persona's tools on ``business_server``.

    Registers company search, interactive company search, subscription
    management and company-request tools. The subscription folder/type
    defaults come from ``settings``.

    The v2 caller is reused from ``business_server.call_v2_tool`` when it is
    already attached. Otherwise a v2 API server is created here and pruned to
    the same v2 operations that ``create_birre_server`` keeps on its own v2
    server. It is then wrapped with ``call_v2_openapi_tool`` and attached as
    ``business_server.call_v2_tool``.

    Args:
        business_server: Server the tools are registered on.
        settings: Source of ``subscription_folder`` / ``subscription_type``.
        call_v1_tool: Callable used by the tools to invoke v1 API operations.
        logger: Logger passed to every registrar and API caller.
        resolved_api_key: API key used if a v2 server must be created here.
        verify_option: TLS ``verify`` value used if a v2 server is created here.
        max_findings: Limit passed to the interactive search tool.
        v2_base_url: Optional base URL forwarded to ``create_v2_api_server``
            when the v2 server has to be created here; ignored otherwise.
    """
    # Function-local import: these registrars are only imported when the
    # risk-manager persona is actually configured.
    from src.business.risk_manager import (
        register_company_search_interactive_tool,
        register_manage_subscriptions_tool,
        register_request_company_tool,
    )

    register_company_search_tool(business_server, call_v1_tool, logger=logger)

    call_v2_tool = getattr(business_server, "call_v2_tool", None)
    if call_v2_tool is None:
        v2_kwargs: dict[str, Any] = {"verify": verify_option}
        if v2_base_url is not None:
            v2_kwargs["base_url"] = v2_base_url
        v2_api_server = create_v2_api_server(resolved_api_key, **v2_kwargs)
        # Without this, the fallback server would expose every generated v2
        # tool, unlike the v2 server built in create_birre_server.
        _schedule_tool_disablement(
            v2_api_server,
            {
                "getCompanyRequests",
                "createCompanyRequest",
                "createCompanyRequestBulk",
            },
        )
        call_v2_tool = partial(
            call_v2_openapi_tool,
            v2_api_server,
            logger=logger,
        )
        setattr(business_server, "call_v2_tool", call_v2_tool)

    default_folder = settings.subscription_folder
    default_type = settings.subscription_type

    register_company_search_interactive_tool(
        business_server,
        call_v1_tool,
        logger=logger,
        default_folder=default_folder,
        default_type=default_type,
        max_findings=max_findings,
    )
    register_manage_subscriptions_tool(
        business_server,
        call_v1_tool,
        logger=logger,
        default_folder=default_folder,
        default_type=default_type,
    )
    register_request_company_tool(
        business_server,
        call_v1_tool,
        call_v2_tool,
        logger=logger,
        default_folder=default_folder,
        default_type=default_type,
    )
def _configure_standard_tools(
    business_server: FastMCP,
    call_v1_tool: Callable[..., Any],
    logger: BoundLogger,
) -> None:
    """Register the company-search tool for the ``standard`` persona.

    The rating tool is registered for every persona separately by
    ``create_birre_server``; this helper adds only the v1-backed search.
    """
    register_company_search_tool(
        business_server,
        call_v1_tool,
        logger=logger,
    )
def _coerce_runtime_settings(
    settings: RuntimeSettings | Mapping[str, Any]
) -> RuntimeSettings:
    """Return ``settings`` as a ``RuntimeSettings`` instance.

    ``RuntimeSettings`` instances are returned unchanged. Any other mapping is
    copied and normalised field by field:

    * ``max_findings`` falls back to ``DEFAULT_MAX_FINDINGS`` unless it is a
      positive ``int``. ``bool`` is rejected even though it subclasses ``int``.
    * ``risk_vector_filter`` falls back to ``DEFAULT_RISK_VECTOR_FILTER`` when
      missing or falsy.
    * ``warnings`` / ``overrides`` accept a single string, an iterable, or
      ``None``/missing (treated as empty), and are stored as tuples.
    * ``api_key`` becomes ``""`` when missing or ``None``; any other value is
      converted with ``str()``.
    * ``skip_startup_checks``, ``debug`` and ``allow_insecure_tls`` are
      coerced with ``bool()`` and default to ``False``.
    """
    if isinstance(settings, RuntimeSettings):
        return settings

    data = dict(settings)

    max_findings = data.get("max_findings")
    if (
        isinstance(max_findings, bool)
        or not isinstance(max_findings, int)
        or max_findings <= 0
    ):
        max_findings = DEFAULT_MAX_FINDINGS

    raw_api_key = data.get("api_key")
    api_key = "" if raw_api_key is None else str(raw_api_key)

    return RuntimeSettings(
        api_key=api_key,
        subscription_folder=data.get("subscription_folder"),
        subscription_type=data.get("subscription_type"),
        context=data.get("context"),
        risk_vector_filter=data.get("risk_vector_filter") or DEFAULT_RISK_VECTOR_FILTER,
        max_findings=max_findings,
        skip_startup_checks=bool(data.get("skip_startup_checks", False)),
        debug=bool(data.get("debug", False)),
        allow_insecure_tls=bool(data.get("allow_insecure_tls", False)),
        ca_bundle_path=data.get("ca_bundle_path"),
        warnings=_normalize_str_tuple(data.get("warnings")),
        overrides=_normalize_str_tuple(data.get("overrides")),
    )


def _normalize_str_tuple(value: Any) -> tuple[Any, ...]:
    """Turn ``None`` into ``()``, a lone string into a 1-tuple, and any other iterable into a tuple."""
    if value is None:
        return ()
    if isinstance(value, str):
        # Without this check, tuple("abc") would split the string into characters.
        return (value,)
    return tuple(value)
def create_birre_server(
    settings: RuntimeSettings | Mapping[str, Any],
    logger: BoundLogger,
    *,
    v1_base_url: Optional[str] = None,
    v2_base_url: Optional[str] = None,
) -> FastMCP:
    """Create and configure the BiRRe FastMCP business server using resolved settings.

    Args:
        settings: A ``RuntimeSettings`` instance, or a mapping coerced into one.
        logger: Logger used for TLS resolution, the API callers and every
            tool registrar.
        v1_base_url: Optional base URL for the v1 API server.
        v2_base_url: Optional base URL for the v2 API server, which is only
            created for the ``risk_manager`` context.

    Returns:
        The business server named ``io.github.boecht.birre``. It has
        ``call_v1_tool`` attached, plus ``call_v2_tool`` when a v2 server was
        created.

    Raises:
        ValueError: If the resolved settings carry no API key.
    """
    config = _coerce_runtime_settings(settings)
    api_key = _require_api_key(config)
    context = _resolve_active_context(config)
    risk_filter = _resolve_risk_vector_filter(config)
    findings_limit = _resolve_max_findings(config)
    verify = _resolve_tls_verification(config, logger)

    v1_options: dict[str, Any] = {"verify": verify}
    if v1_base_url is not None:
        v1_options["base_url"] = v1_base_url
    v1_server = create_v1_api_server(api_key, **v1_options)
    v2_server = _maybe_create_v2_api_server(
        context,
        api_key,
        verify,
        base_url=v2_base_url,
    )

    instructions = INSTRUCTIONS_MAP.get(context, INSTRUCTIONS_MAP["standard"])
    business_server = FastMCP(name="io.github.boecht.birre", instructions=instructions)

    call_v1_tool = partial(call_v1_openapi_tool, v1_server, logger=logger)
    setattr(business_server, "call_v1_tool", call_v1_tool)
    # Keep only the generated v1 operations that BiRRe's tools call.
    _schedule_tool_disablement(
        v1_server,
        {
            "companySearch",
            "manageSubscriptionsBulk",
            "getCompany",
            "getCompaniesFindings",
            "getFolders",
            "getCompanySubscriptions",
        },
    )

    if v2_server is not None:
        setattr(
            business_server,
            "call_v2_tool",
            partial(call_v2_openapi_tool, v2_server, logger=logger),
        )
        # Same pruning for the v2 server: keep only the company-request operations.
        _schedule_tool_disablement(
            v2_server,
            {
                "getCompanyRequests",
                "createCompanyRequest",
                "createCompanyRequestBulk",
            },
        )

    register_company_rating_tool(
        business_server,
        call_v1_tool,
        logger=logger,
        risk_vector_filter=risk_filter,
        max_findings=findings_limit,
        default_folder=config.subscription_folder,
        default_type=config.subscription_type,
        debug_enabled=bool(config.debug),
    )

    if context != "risk_manager":
        _configure_standard_tools(business_server, call_v1_tool, logger)
    else:
        _configure_risk_manager_tools(
            business_server,
            config,
            call_v1_tool,
            logger,
            api_key,
            verify,
            findings_limit,
        )
    return business_server
# Public API of this module: only the server factory.
__all__ = ["create_birre_server"]