import inspect
import json
import logging
import time
from functools import wraps
from typing import Any, Callable, Mapping, Optional, Type, TypeVar, cast
import yaml
from fastmcp import Context
from fastmcp.exceptions import ToolError
from fastmcp.server import middleware as fmw
from fastmcp.server.middleware import CallNext, MiddlewareContext
from fastmcp.utilities.types import find_kwarg_by_type
from mcp import types as mt
from pydantic import BaseModel, ValidationError
from pydantic_core import ErrorDetails
from keboola_mcp_server.clients.client import KeboolaClient
from keboola_mcp_server.clients.storage import StorageEventType
from keboola_mcp_server.mcp import CONVERSATION_ID, ServerState, get_http_request_or_none
LOG = logging.getLogger(__name__)
F = TypeVar('F', bound=Callable[..., Any])
_USER_AGENT_TO_COMPONENT_ID: Mapping[str, str] = {
'read-only-chat': 'keboola.ai-chat',
'in-platform-chat': 'keboola.kai-assistant',
}
class _JsonWrapper(BaseModel):
"""
Utility class for safely encoding arbitrary Python objects to JSON strings.
Uses Pydantic's serialization to handle complex objects as well as simple types like int, float, bool, str, etc.
Primary use case is serializing tool function parameters for Keboola Storage API events.
"""
data: Any # The arbitrary object to be JSON serialized
@classmethod
def encode(cls, obj: Any) -> str:
return json.dumps(cls(data=obj).model_dump()['data'], ensure_ascii=False)
async def _trigger_event(
func: Callable, args: tuple, kwargs: dict, exception: Exception | None, execution_time: float
) -> None:
# TODO: This is not always correct. In general tool functions can be registered
# in the MCP server under different names.
tool_name = func.__name__
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
ctx_param_name = find_kwarg_by_type(func, Context)
assert ctx_param_name, f'The tool function {tool_name} must have a "Context" parameter.'
ctx = bound_args.arguments.get(ctx_param_name)
assert isinstance(ctx, Context), (
f'The tool function {tool_name} has invalid "{ctx_param_name}" parameter. '
f'Expecting instance of "Context", got {type(ctx)}.'
)
runtime_info = ServerState.from_context(ctx).runtime_info
user_agent: str | None = None
if client_params := ctx.session.client_params:
user_agent = f'{client_params.clientInfo.name}/{client_params.clientInfo.version}'
if not user_agent:
user_agent = ctx.client_id
if not user_agent:
if http_rq := get_http_request_or_none():
user_agent = http_rq.headers.get('User-Agent')
if not user_agent:
user_agent = ''
# See # https://github.com/keboola/event-schema/blob/main/schema/ext.keboola.mcp-server-tool.json
# for the JSON schema describing the 'keboola.mcp-server-tool' component's event params.
event_params: dict[str, Any] = {
'mcpServerContext': {
'appEnv': runtime_info.app_version,
'version': runtime_info.server_version,
'userAgent': user_agent,
# For the HTTP-based transports use the HTTP session ID. For other transports use the server ID.
'sessionId': ctx.session_id or runtime_info.server_id,
'serverTransport': runtime_info.transport.split('/')[-1],
'conversationId': ctx.session.state.get(CONVERSATION_ID) or '',
},
'tool': {
'name': tool_name,
'arguments': [
{'key': param_name, 'value': _JsonWrapper.encode(param_value)}
for param_name, param_value in bound_args.arguments.items()
if param_name not in [ctx_param_name, 'self', 'cls']
],
},
}
if exception:
message = f'MCP tool "{tool_name}" call failed. {type(exception).__name__}: {exception}'
event_type: StorageEventType = 'error'
else:
message = f'MCP tool "{tool_name}" call succeeded.'
event_type: StorageEventType = 'success'
client = KeboolaClient.from_state(ctx.session.state)
resp = await client.storage_client.trigger_event(
message=message,
component_id=_USER_AGENT_TO_COMPONENT_ID.get(user_agent.split(sep='/')[0]) or 'keboola.mcp-server-tool',
event_type=event_type,
params=event_params,
duration=execution_time,
)
LOG.debug(f'Tool call SAPI event triggered: {resp}')
def tool_errors(
default_recovery: Optional[str] = None,
recovery_instructions: Optional[dict[Type[Exception], str]] = None,
) -> Callable[[F], F]:
"""
The MCP tool function decorator that logs exceptions and adds recovery instructions for LLMs.
:param default_recovery: A fallback recovery instruction to use when no specific instruction
is found for the exception.
:param recovery_instructions: A dictionary mapping exception types to recovery instructions.
:return: The decorated function with error-handling logic applied.
"""
def decorator(func: Callable):
@wraps(func)
async def wrapped(*args, **kwargs):
exception: Exception | None = None
start = time.perf_counter()
try:
return await func(*args, **kwargs)
except Exception as e:
recovery_msg = default_recovery
if recovery_instructions:
for exc_type, msg in recovery_instructions.items():
if isinstance(e, exc_type):
recovery_msg = msg
break
error_msg: str | None = None
if isinstance(e, ValidationError):
error_msg = prettify_validation_error(e)
if recovery_msg:
error_msg += f'\nRecovery: {recovery_msg}'
elif recovery_msg:
error_msg = f'{e}\nRecovery: {recovery_msg}'
try:
if error_msg:
raise ToolError(error_msg) from e
else:
raise e
except Exception as e:
LOG.exception(f'MCP tool "{func.__name__}" call failed. {type(e).__name__}: {e}')
exception = e
raise
finally:
try:
await _trigger_event(func, args, kwargs, exception, time.perf_counter() - start)
except Exception as e:
LOG.exception(f'Failed to trigger tool event for "{func.__name__}" tool: {e}')
raise
return cast(F, wrapped)
return decorator
def _format_validation_errors(errors: list[ErrorDetails]) -> dict[str, Any]:
"""
Formats Pydantic validation errors into a structured dictionary.
:param errors: List of error dictionaries from ValidationError.errors()
:return: Dictionary with formatted errors including field, message, and extra fields
"""
formatted_errors: list[dict[str, Any]] = []
for error in errors:
error_dict: dict[str, Any] = {
'field': '.'.join(str(i) for i in error.get('loc', [])),
'message': error.get('msg', 'Validation error'),
'extra': {str(key): str(value) for key, value in error.items() if key not in {'loc', 'msg'}},
}
formatted_errors.append(error_dict)
return {'errors': formatted_errors}
def prettify_validation_error(error: ValidationError) -> str:
"""
Formats a Pydantic ValidationError into a human and LLM-readable YAML string.
:param error: The Pydantic ValidationError to format
:return: A formatted YAML string with error details
"""
error_count = len(error.errors())
model_name = getattr(error, 'title', 'unknown')
header = f'Found {error_count} validation error(s) for {model_name}'
formatted = _format_validation_errors(error.errors())
try:
yaml_str = yaml.dump(formatted, default_flow_style=False, sort_keys=False, allow_unicode=True)
except Exception:
yaml_str = str(formatted)
return f'{header}\n{yaml_str}'
class ValidationErrorMiddleware(fmw.Middleware):
"""
Middleware that catches Pydantic ValidationError and formats it with explicit field locations.
This middleware intercepts tool calls and catches any Pydantic ValidationError that occurs
during argument validation. It then formats the error message to clearly show which fields
are missing or invalid, making it easier for both humans and LLMs to understand the issue.
"""
async def on_call_tool(
self,
context: MiddlewareContext[mt.CallToolRequestParams],
call_next: CallNext[mt.CallToolRequestParams, mt.CallToolResult],
) -> mt.CallToolResult:
try:
return await call_next(context)
except ValidationError as e:
raise ToolError(prettify_validation_error(e)) from e