"""Tool catalog for storing and managing tool definitions."""
from typing import Any, Optional
from dataclasses import dataclass, field
from datetime import datetime
import threading
import logging
from pydantic import BaseModel, Field
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class InputSchema(BaseModel):
    """Describe a tool's input parameters in JSON Schema form.

    Every field has a default, so ``InputSchema()`` yields a schema of type
    ``"object"`` with no properties and no required names.
    """

    # JSON Schema "type" keyword; defaults to "object".
    type: str = "object"
    # Per-parameter schema entries, keyed by parameter name.
    properties: dict[str, Any] = Field(default_factory=dict)
    # Names listed as required (presumably a subset of ``properties`` keys;
    # nothing here enforces that).
    required: list[str] = Field(default_factory=list)
def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Yields the same kind of value as ``datetime.utcnow()`` (UTC wall clock
    with ``tzinfo`` unset) without calling that method, which is deprecated
    since Python 3.12.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


class ToolDefinition(BaseModel):
    """Definition of a tool that can be discovered and used.

    Matches the Anthropic API tool format for compatibility.
    """

    name: str = Field(..., description="Unique name of the tool")
    description: str = Field(..., description="Description of what the tool does")
    input_schema: InputSchema = Field(..., description="JSON Schema for tool inputs")
    defer_loading: bool = Field(default=True, description="Whether to defer loading this tool")

    # Metadata
    # Naive UTC timestamp taken when the model instance is created.
    created_at: datetime = Field(default_factory=_utcnow_naive)
    tags: list[str] = Field(default_factory=list, description="Tags for categorization")

    def to_searchable_text(self) -> str:
        """Render the tool as newline-separated text suitable for indexing.

        The text contains the tool name, its description, one line per
        input parameter (name, type and description) and, when any tags are
        set, a final comma-separated tag line.

        Returns:
            The assembled text, with parts joined by newlines.
        """
        text_parts = [
            f"Tool: {self.name}",
            f"Description: {self.description}",
        ]

        for param_name, param_info in self.input_schema.properties.items():
            # Property values are typed ``Any``; JSON Schema also permits
            # non-mapping sub-schemas (e.g. ``true``). Fall back to the
            # defaults for those instead of raising AttributeError.
            details = param_info if isinstance(param_info, dict) else {}
            param_desc = details.get("description", "")
            param_type = details.get("type", "any")
            text_parts.append(f"Parameter {param_name} ({param_type}): {param_desc}")

        if self.tags:
            text_parts.append(f"Tags: {', '.join(self.tags)}")

        return "\n".join(text_parts)

    def to_api_format(self) -> dict[str, Any]:
        """Convert to Anthropic API tool format.

        Returns:
            A dict with ``name``, ``description``, the dumped
            ``input_schema`` and ``defer_loading``. ``created_at`` and
            ``tags`` are not included.
        """
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": self.input_schema.model_dump(),
            "defer_loading": self.defer_loading,
        }
class ToolReference(BaseModel):
    """Lightweight pointer to a tool by name, returned by search operations."""

    # Marker value identifying this payload; defaults to "tool_reference".
    type: str = "tool_reference"
    # Name of the tool being referenced.
    tool_name: str

    def to_dict(self) -> dict[str, str]:
        """Return the reference as a plain dict with ``type`` and ``tool_name`` keys."""
        return dict(type=self.type, tool_name=self.tool_name)
@dataclass
class ToolCatalog:
    """In-memory catalog for storing and looking up tool definitions.

    Every method that reads or writes catalog state does so while holding a
    re-entrant lock. Mutating methods set a "dirty" flag and then invoke the
    registered update callbacks while still holding that lock; because the
    lock is re-entrant, a callback may call back into the catalog from the
    same thread.
    """

    # Registered tools, keyed by tool name.
    tools: dict[str, ToolDefinition] = field(default_factory=dict)
    # Re-entrant lock guarding all fields of this catalog.
    _lock: threading.RLock = field(default_factory=threading.RLock)
    # Set by every mutation; cleared only by mark_index_clean().
    _index_dirty: bool = field(default=False)
    # Callbacks invoked with the catalog after each mutation.
    _on_update_callbacks: list = field(default_factory=list)

    def register_tool(self, tool: ToolDefinition) -> None:
        """Register a new tool or replace an existing one with the same name.

        Args:
            tool: The tool definition to register
        """
        with self._lock:
            is_update = tool.name in self.tools
            self.tools[tool.name] = tool
            self._index_dirty = True
            logger.info(
                "%s tool: %s",
                "Updated" if is_update else "Registered",
                tool.name,
            )
            self._notify_update()

    def register_tools(self, tools: list[ToolDefinition]) -> None:
        """Register multiple tools at once, notifying callbacks a single time.

        Tools whose name is already present replace the existing entry.

        Args:
            tools: List of tool definitions to register
        """
        with self._lock:
            self.tools.update((tool.name, tool) for tool in tools)
            self._index_dirty = True
            logger.info("Registered %d tools", len(tools))
            self._notify_update()

    def remove_tool(self, name: str) -> bool:
        """Remove a tool from the catalog.

        Args:
            name: Name of the tool to remove

        Returns:
            True if the tool was removed, False if not found
        """
        with self._lock:
            if name not in self.tools:
                # Nothing changed: leave the dirty flag and callbacks alone.
                return False
            del self.tools[name]
            self._index_dirty = True
            logger.info("Removed tool: %s", name)
            self._notify_update()
            return True

    def get_tool(self, name: str) -> Optional[ToolDefinition]:
        """Get a tool by name.

        Args:
            name: Name of the tool to retrieve

        Returns:
            The tool definition or None if not found
        """
        with self._lock:
            return self.tools.get(name)

    def list_tools(self) -> list[ToolDefinition]:
        """List all tools in the catalog.

        Returns:
            A new list holding every registered tool definition
        """
        with self._lock:
            return list(self.tools.values())

    def get_tool_names(self) -> list[str]:
        """Get all tool names.

        Returns:
            A new list of the registered tool names
        """
        with self._lock:
            return list(self.tools)

    def get_tools_by_names(self, names: list[str]) -> list[ToolDefinition]:
        """Get multiple tools by their names.

        Args:
            names: List of tool names to retrieve

        Returns:
            Found tool definitions in the order of ``names``
            (missing tools are skipped)
        """
        with self._lock:
            return [self.tools[name] for name in names if name in self.tools]

    def count(self) -> int:
        """Get the number of tools in the catalog."""
        with self._lock:
            return len(self.tools)

    def clear(self) -> None:
        """Remove all tools from the catalog and notify callbacks."""
        with self._lock:
            self.tools.clear()
            self._index_dirty = True
            logger.info("Cleared all tools from catalog")
            self._notify_update()

    def on_update(self, callback) -> None:
        """Register a callback to be called when the catalog is updated.

        Args:
            callback: Function to call on updates (receives the catalog)
        """
        # Guarded like every other state change so registration cannot race
        # with a notification in another thread.
        with self._lock:
            self._on_update_callbacks.append(callback)

    def _notify_update(self) -> None:
        """Invoke every registered callback with this catalog.

        An exception raised by one callback is logged with its traceback
        and does not prevent the remaining callbacks from running.
        """
        # Iterate over a snapshot so a callback that registers another
        # callback does not mutate the list being iterated.
        with self._lock:
            callbacks = list(self._on_update_callbacks)
        for callback in callbacks:
            try:
                callback(self)
            except Exception as exc:
                # Best-effort notification: record the failure and continue.
                logger.exception("Error in catalog update callback: %s", exc)

    def is_index_dirty(self) -> bool:
        """Check if the index needs to be rebuilt."""
        with self._lock:
            return self._index_dirty

    def mark_index_clean(self) -> None:
        """Mark the index as up-to-date.

        NOTE(review): a mutation that lands between a caller's rebuild and
        this call is still marked clean; callers needing stronger guarantees
        must coordinate externally.
        """
        with self._lock:
            self._index_dirty = False
# Shared module-level catalog instance, created empty when this module is
# first imported.
catalog = ToolCatalog()