"""
Tool registry with context-aware policy enforcement.

This module provides a wrapper around FastMCP's tool registration system
that attaches a required context policy to every tool and exposes helpers
for checking which tools may run in a given UI context.
"""
from __future__ import annotations

import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Set, Type

from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel

from src.core.tool_context import PageContext, ToolContextPolicy
from src.observability import get_logger
# Module-level logger. Call sites in this module pass structured keyword
# fields (e.g. tool_name=...), so get_logger presumably returns a
# structlog-style logger rather than a stdlib logging.Logger — confirm.
logger = get_logger(__name__)
class ContextAwareToolRegistry:
    """
    Registry for tools with context policy metadata.

    This wraps FastMCP's tool registration to add:
    - A required ToolContextPolicy for every registered tool
    - Policy lookup helpers; ``can_run_tool`` fails closed (returns False)
      for tools that were never registered
    - Tool filtering by context (``get_allowed_tools``)
    - A startup lock: registrations after ``complete_startup`` raise

    Note: the wrapper installed around each tool does not re-check the
    policy when the tool runs. Context validation is expected to happen in
    the orchestrator before a tool is invoked, which lets the MCP server
    run in a separate process.
    """

    def __init__(self, mcp: FastMCP):
        self.mcp = mcp
        # Tool name -> policy; the source of truth for every lookup below.
        self._policies: Dict[str, ToolContextPolicy] = {}
        # Set by complete_startup(); afterwards registration raises.
        self._startup_complete = False

    def register_tool(
        self,
        name: str,
        policy: ToolContextPolicy,
        description: Optional[str] = None,
        input_schema: Optional[Type[BaseModel]] = None,
        entity_param: Optional[str] = None,
    ) -> Callable:
        """
        Register a tool with context policy metadata.

        Args:
            name: Tool name; must be unique within this registry.
            policy: Tool's context policy (required).
            description: Tool description, forwarded to FastMCP.
            input_schema: Accepted for API compatibility; not currently
                used by this method.
            entity_param: Accepted for API compatibility; not currently
                used by this method.

        Returns:
            Decorator that registers the function with FastMCP and returns
            whatever FastMCP's ``tool`` decorator returns for the async
            wrapper.

        Raises (when the returned decorator is applied):
            RuntimeError: If the registry was already locked by
                ``complete_startup``.
            ValueError: If ``policy`` is missing or ``name`` is already
                registered.
        """
        def decorator(func: Callable) -> Callable:
            if self._startup_complete:
                raise RuntimeError(
                    f"Tool {name} registered after startup. "
                    "All tools must be registered during initialization."
                )
            if not policy:
                raise ValueError(f"Tool {name} must have a ToolContextPolicy")
            # A second registration under the same name would overwrite the
            # stored policy, while the function FastMCP dispatches to may
            # still be the first one -> policy/function mismatch. Refuse it.
            if name in self._policies:
                raise ValueError(f"Tool {name} is already registered")

            @functools.wraps(func)
            async def wrapped(*args: Any, **kwargs: Any) -> Any:
                # Context validation and filtering happens in the
                # orchestrator before tools are called, so it is not repeated
                # here. This allows the MCP server to run in a separate process.
                logger.debug(
                    "tool_executing",
                    tool_name=name,
                    policy=policy.model_dump(),
                )
                result = func(*args, **kwargs)
                # Accept plain (sync) functions as well as coroutine functions;
                # awaiting a non-awaitable would raise TypeError.
                if inspect.isawaitable(result):
                    result = await result
                return result

            # Register with FastMCP first (only supported parameters are
            # passed) so a failure there leaves no orphan policy behind.
            registered = self.mcp.tool(name=name, description=description)(wrapped)
            self._policies[name] = policy
            return registered

        return decorator

    def get_allowed_tools(self, context: PageContext) -> List[str]:
        """
        Get tools allowed in the given context.

        Args:
            context: Current UI context

        Returns:
            Names of registered tools whose policy the context matches,
            in registration order.
        """
        return [
            name
            for name, policy in self._policies.items()
            if context.matches_policy(policy)
        ]

    def can_run_tool(self, name: str, context: PageContext) -> bool:
        """
        Check if a specific tool can run in the given context.

        Args:
            name: Tool name
            context: Current UI context

        Returns:
            Whether the tool is allowed to run. Unregistered tools are
            logged and rejected (fail-closed).
        """
        policy = self._policies.get(name)
        if policy is None:
            logger.warning("tool_not_registered", tool_name=name)
            return False
        return context.matches_policy(policy)

    @property
    def registered_tools(self) -> Set[str]:
        """Get names of all registered tools (a fresh set each call)."""
        return set(self._policies)

    def complete_startup(self) -> None:
        """Mark registry as complete, preventing further registrations."""
        self._startup_complete = True
        logger.info(
            "tool_registry_locked",
            tool_count=len(self._policies),
        )

    def get_tool_policy(self, name: str) -> Optional[ToolContextPolicy]:
        """Get the policy for a specific tool, or None if not registered."""
        return self._policies.get(name)