_registration.py•10.9 kB
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from contextlib import suppress
import functools
import inspect
import logging
import sys
import types
import uuid
from typing import Annotated, Any, Union, cast, get_args, get_origin, get_type_hints
from mcp.server.fastmcp import FastMCP, Context as MCPContext
from mcp.types import ToolAnnotations
from schwab_mcp.context import SchwabContext
from schwab_mcp.approvals import ApprovalDecision, ApprovalRequest
logger = logging.getLogger(__name__)
# Shape of the tool callables handled in this module: any arguments, awaitable
# result. (The wrappers below also accept callables that return plain values,
# since they check ``inspect.isawaitable`` before awaiting.)
ToolFn = Callable[..., Awaitable[Any]]
# Seconds between keepalive progress notifications sent while a write tool is
# waiting for the reviewer's decision.
_APPROVAL_PROGRESS_INTERVAL: float = 5.0
# Progress message sent while waiting for approval; keepalive notifications
# append the elapsed time to it.
_APPROVAL_WAIT_MESSAGE: str = "Waiting for reviewer approval…"
def _is_context_annotation(annotation: Any) -> bool:
    if annotation in (inspect._empty, None):
        return False
    if annotation is SchwabContext:
        return True
    if annotation == "SchwabContext":
        return True
    if isinstance(annotation, str):
        return annotation == "SchwabContext"
    origin = get_origin(annotation)
    if origin is None:
        return False
    if origin in (Annotated,):
        args = get_args(annotation)
        return bool(args) and _is_context_annotation(args[0])
    if origin in (Union, types.UnionType):
        return any(_is_context_annotation(arg) for arg in get_args(annotation))
    return False
def _resolve_context_parameters(func: ToolFn) -> tuple[inspect.Signature, list[str]]:
    signature = inspect.signature(func)
    module = sys.modules.get(func.__module__)
    globalns = vars(module) if module else {}
    type_hints: dict[str, Any]
    try:
        type_hints = get_type_hints(func, globalns=globalns, include_extras=True)
    except TypeError:
        type_hints = get_type_hints(func, globalns=globalns)
    except Exception:
        type_hints = {}
    ctx_params = []
    for name, param in signature.parameters.items():
        annotation = type_hints.get(name, param.annotation)
        if _is_context_annotation(annotation):
            ctx_params.append(name)
    return signature, ctx_params
def _ensure_schwab_context(func: ToolFn) -> ToolFn:
    """Wrap *func* so its SchwabContext parameters always receive a SchwabContext.

    If *func* declares no such parameter it is returned unchanged. Otherwise
    the returned coroutine function inspects each context parameter that the
    caller actually supplied:

    * a ``SchwabContext`` is passed through as-is;
    * a plain MCP ``Context`` is replaced by ``SchwabContext.model_construct``
      built from its ``request_context`` and ``_fastmcp`` attributes;
    * anything else raises ``TypeError``.

    *func* may be sync or async; an awaitable result is awaited.
    """
    signature, context_names = _resolve_context_parameters(func)
    if not context_names:
        return func

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = signature.bind_partial(*args, **kwargs)
        supplied = bound.arguments  # mutations are reflected in bound.args/kwargs
        for param_name in context_names:
            if param_name not in supplied:
                continue
            current = supplied[param_name]
            if isinstance(current, SchwabContext):
                continue
            if not isinstance(current, MCPContext):
                raise TypeError(
                    f"Argument '{param_name}' must be an MCP context, got {type(current)!r}"
                )
            supplied[param_name] = SchwabContext.model_construct(
                _request_context=current.request_context,
                _fastmcp=getattr(current, "_fastmcp", None),
            )
        outcome = func(*bound.args, **bound.kwargs)
        return await outcome if inspect.isawaitable(outcome) else outcome

    # Ensure annotations referencing names from the original module remain resolvable.
    # NOTE(review): ``wrapper.__globals__`` is *this* module's namespace, so this
    # merges the tool module's globals into it; ``setdefault`` semantics mean the
    # first registered module wins on name clashes — confirm that is acceptable.
    target_ns = cast(dict[str, Any], getattr(wrapper, "__globals__", {}))
    source_module = inspect.getmodule(func)
    if source_module is not None:
        source_ns = vars(source_module)
        if source_ns is not target_ns:
            for key, value in source_ns.items():
                if key not in target_ns:
                    target_ns[key] = value
    return wrapper
def _format_argument(value: Any) -> str:
    text = repr(value)
    if len(text) > 256:
        return f"{text[:253]}..."
    return text
def _wrap_with_approval(func: ToolFn) -> ToolFn:
    """Gate *func* behind a reviewer decision before it executes.

    The returned coroutine function:

    1. finds the call's SchwabContext (upgrading a plain MCP context);
    2. submits an :class:`ApprovalRequest` through
       ``context.approvals.require``, listing the effective non-context
       arguments (explicitly passed values plus declared defaults for omitted
       parameters), each rendered with ``_format_argument``;
    3. while waiting, sends best-effort progress notifications when the
       client supplied a progress token;
    4. runs *func* only when the decision is ``ApprovalDecision.APPROVED``.

    Raises:
        TypeError: immediately, if *func* declares no SchwabContext parameter.

    The returned wrapper raises:
        RuntimeError: if no context argument was supplied for the call.
        PermissionError: if the reviewer denied the request.
        TimeoutError: for any other non-approved decision (reported as expired).
    """
    signature, ctx_params = _resolve_context_parameters(func)
    if not ctx_params:
        raise TypeError(
            f"Write tool '{func.__name__}' must accept a SchwabContext parameter for approval gating."
        )

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = signature.bind_partial(*args, **kwargs)
        context: SchwabContext | None = None
        for name in ctx_params:
            if name not in bound.arguments:
                continue
            value = bound.arguments[name]
            if isinstance(value, SchwabContext):
                context = value
                continue
            if isinstance(value, MCPContext):
                converted = SchwabContext.model_construct(
                    _request_context=value.request_context,
                    _fastmcp=getattr(value, "_fastmcp", None),
                )
                bound.arguments[name] = converted
                context = converted
        if context is None:
            raise RuntimeError(
                f"Write tool '{func.__name__}' missing SchwabContext during invocation."
            )

        # Show the reviewer the effective call: parameters the caller omitted
        # still execute with their declared defaults, so list those too.
        # Context parameters are plumbing and are never shown.
        arguments: dict[str, str] = {}
        for name, param in signature.parameters.items():
            if name in ctx_params:
                continue
            if name in bound.arguments:
                arguments[name] = _format_argument(bound.arguments[name])
            elif param.default is not inspect.Parameter.empty:
                arguments[name] = _format_argument(param.default)

        request = ApprovalRequest(
            id=str(uuid.uuid4()),
            tool_name=func.__name__,
            request_id=context.request_id,
            client_id=context.client_id,
            arguments=arguments,
        )

        if _has_progress_token(context):
            # Best effort, like the keepalive and completion notifications: a
            # failed progress message must not abort the tool call.
            try:
                await context.report_progress(0, 1, _APPROVAL_WAIT_MESSAGE)
            except Exception:
                logger.debug("Failed to send initial approval progress", exc_info=True)
        keepalive_task = _start_approval_keepalive(context)
        try:
            decision = await context.approvals.require(request)
        finally:
            # Stop the keepalive whether approval returned, raised, or was cancelled.
            if keepalive_task is not None:
                keepalive_task.cancel()
                with suppress(asyncio.CancelledError):
                    await keepalive_task
        await _report_approval_completion(context, decision)
        logger.info(
            "Approval decision %s for tool '%s' (approval_id=%s, client_id=%s, request_id=%s)",
            decision.value,
            func.__name__,
            request.id,
            request.client_id or "<unknown>",
            request.request_id,
        )

        if decision is ApprovalDecision.APPROVED:
            result = func(*bound.args, **bound.kwargs)
            if inspect.isawaitable(result):
                return await result
            return result

        denied = decision is ApprovalDecision.DENIED
        message = (
            f"Write operation for tool '{func.__name__}' denied by reviewer."
            if denied
            else f"Approval request for tool '{func.__name__}' expired."
        )
        try:
            await context.warning(message)
        except Exception:
            # The denial/expiry must still reach the caller as the error below.
            logger.debug("Failed to send approval outcome warning", exc_info=True)
        if denied:
            raise PermissionError(message)
        raise TimeoutError(message)

    # Keep names from the tool's module resolvable through wrapper.__globals__.
    # NOTE(review): this writes into *this* module's globals; first registration
    # wins on name clashes (setdefault) — confirm that is acceptable.
    wrapper_globals = cast(dict[str, Any], getattr(wrapper, "__globals__", {}))
    module = inspect.getmodule(func)
    if module is not None:
        module_globals = vars(module)
        if wrapper_globals is not module_globals:
            for key, value in module_globals.items():
                wrapper_globals.setdefault(key, value)
    return wrapper
def _start_approval_keepalive(context: SchwabContext) -> asyncio.Task[None] | None:
    """Start a background task that periodically re-sends the approval wait message.

    Returns ``None`` without starting anything when the request carries no
    progress token. Otherwise the task sleeps ``_APPROVAL_PROGRESS_INTERVAL``
    seconds, then reports progress ``0/1`` with the elapsed seconds appended,
    repeating until cancelled. The first failed notification is logged at
    debug level and ends the task.
    """
    if not _has_progress_token(context):
        return None

    async def _send_periodic_progress() -> None:
        waited = 0.0
        try:
            while True:
                await asyncio.sleep(_APPROVAL_PROGRESS_INTERVAL)
                waited += _APPROVAL_PROGRESS_INTERVAL
                await context.report_progress(
                    0,
                    1,
                    f"{_APPROVAL_WAIT_MESSAGE} ({int(waited)}s elapsed)",
                )
        except Exception:  # pragma: no cover - best effort keepalive
            # asyncio.CancelledError derives from BaseException (3.8+), so
            # cancellation is not swallowed here and still stops the task.
            logger.debug("Failed to send approval progress keepalive", exc_info=True)

    return asyncio.create_task(_send_periodic_progress())
async def _report_approval_completion(
    context: SchwabContext, decision: ApprovalDecision
) -> None:
    """Send a final ``1/1`` progress notification summarising *decision*.

    Does nothing when the request carries no progress token. Delivery is best
    effort: failures are logged at debug level and otherwise ignored.
    """
    if _has_progress_token(context):
        approved = decision is ApprovalDecision.APPROVED
        if approved:
            summary = "Reviewer approved the request."
        else:
            summary = "Approval flow finished without approval."
        try:
            await context.report_progress(1, 1, summary)
        except Exception:  # pragma: no cover - best effort completion
            logger.debug("Failed to send approval completion progress", exc_info=True)
def _has_progress_token(context: SchwabContext) -> bool:
    try:
        progress_token = getattr(context.request_context.meta, "progressToken", None)
    except ValueError:
        return False
    return bool(progress_token)
def _wrap_result_transform(
    func: ToolFn, transform: Callable[[Any], Any]
) -> ToolFn:
    """Return an async wrapper that passes *func*'s result through *transform*.

    *func* may be sync or async; an awaitable result is awaited before being
    handed to *transform*. Like the other wrappers in this module, the
    defining module's globals are merged (without overwriting existing names)
    into ``wrapper.__globals__`` so names from that module stay resolvable.
    """
    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        raw = func(*args, **kwargs)
        value = await raw if inspect.isawaitable(raw) else raw
        return transform(value)

    own_globals = cast(dict[str, Any], getattr(wrapper, "__globals__", {}))
    source_module = inspect.getmodule(func)
    if source_module is not None and vars(source_module) is not own_globals:
        for key, value in vars(source_module).items():
            own_globals.setdefault(key, value)
    return wrapper
def register_tool(
    server: FastMCP,
    func: ToolFn,
    *,
    write: bool = False,
    annotations: ToolAnnotations | None = None,
    result_transform: Callable[[Any], Any] | None = None,
) -> None:
    """Register a Schwab tool using FastMCP's decorator plumbing.

    *func* is wrapped, innermost first, so that:

    1. its SchwabContext parameters receive a SchwabContext;
    2. when ``write`` is true, it only runs after reviewer approval;
    3. when *result_transform* is given, its result is post-processed.

    Without explicit *annotations*, read tools are marked ``readOnlyHint=True``
    and write tools ``readOnlyHint=False, destructiveHint=True``. Explicit
    annotations are kept, with only their unset (``None``) hints filled in
    from *write*. The tool's name and description are taken from the wrapped
    callable's ``__name__`` and ``__doc__``.
    """
    wrapped = _ensure_schwab_context(func)
    if write:
        wrapped = _wrap_with_approval(wrapped)
    if result_transform is not None:
        wrapped = _wrap_result_transform(wrapped, result_transform)

    if annotations is None:
        effective = (
            ToolAnnotations(readOnlyHint=False, destructiveHint=True)
            if write
            else ToolAnnotations(readOnlyHint=True)
        )
    else:
        missing: dict[str, Any] = {}
        if annotations.readOnlyHint is None:
            missing["readOnlyHint"] = not write
        if write and annotations.destructiveHint is None:
            missing["destructiveHint"] = True
        effective = annotations.model_copy(update=missing) if missing else annotations

    register = server.tool(
        name=wrapped.__name__,
        description=wrapped.__doc__,
        annotations=effective,
    )
    register(wrapped)
__all__ = ["register_tool"]