tools.py
import logging
from collections.abc import Sequence
from contextlib import AsyncExitStack
from typing import (
    Annotated,
    Any,
    cast,
)

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.tools.base import Tool as InternalTool
from mcp.server.fastmcp.utilities.func_metadata import (
    ArgModelBase,
    FuncMetadata,
    _get_typed_annotation,
)
from mcp.shared.message import SessionMessage
from mcp.types import (
    ContentBlock,
    Tool,
)
from pydantic import Field, WithJsonSchema, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

from dbt_mcp.config.config_providers import ConfigProvider, ProxiedToolConfig
from dbt_mcp.errors import RemoteToolError
from dbt_mcp.tools.tool_names import ToolName
from dbt_mcp.tools.toolsets import Toolset, proxied_tools, toolsets

logger = logging.getLogger(__name__)


# Based on this: https://github.com/modelcontextprotocol/python-sdk/blob/9ae4df85fbab97bf476ddd160b766ca4c208cd13/src/mcp/server/fastmcp/utilities/func_metadata.py#L105
def get_remote_tool_fn_metadata(tool: Tool) -> FuncMetadata:
    dynamic_pydantic_model_params: dict[str, Any] = {}
    for key in tool.inputSchema["properties"]:
        # Remote tools shouldn't have type annotations or default values
        # for their arguments. So, we set them to defaults.
        field_info = FieldInfo.from_annotated_attribute(
            annotation=_get_typed_annotation(
                annotation=Annotated[
                    Any,
                    Field(),
                    WithJsonSchema({"title": key, "type": "string"}),
                ],
                globalns={},
            ),
            default=PydanticUndefined,
        )
        dynamic_pydantic_model_params[key] = (field_info.annotation, None)
    return FuncMetadata(
        arg_model=create_model(
            f"{tool.name}Arguments",
            **dynamic_pydantic_model_params,
            __base__=ArgModelBase,
        )
    )


async def get_proxied_tools(
    session: ClientSession,
    configured_proxied_tools: set[ToolName],
) -> list[Tool]:
    tools = (await session.list_tools()).tools
    normalized_configured_proxied_tools = {
        t.value.lower() for t in configured_proxied_tools
    }
    return [t for t in tools if t.name.lower() in normalized_configured_proxied_tools]


class ProxiedToolsManager:
    _stack = AsyncExitStack()

    async def get_remote_mcp_session(
        self, url: str, headers: dict[str, str]
    ) -> ClientSession:
        streamablehttp_client_context: tuple[
            MemoryObjectReceiveStream[SessionMessage | Exception],
            MemoryObjectSendStream[SessionMessage],
            GetSessionIdCallback,
        ] = await self._stack.enter_async_context(
            streamablehttp_client(
                url=url,
                headers=headers,
            )
        )
        read_stream, write_stream, _ = streamablehttp_client_context
        return await self._stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )

    @classmethod
    async def close(cls) -> None:
        await cls._stack.aclose()


def resolve_proxied_tools_configuration(
    config: ProxiedToolConfig,
    exclude_tools: Sequence[ToolName],
) -> set[ToolName]:
    configured_proxied_tools = cast(set[ToolName], proxied_tools) - set(exclude_tools)
    if config.are_sql_tools_disabled:
        configured_proxied_tools = configured_proxied_tools - toolsets[Toolset.SQL]
    if config.are_discovery_tools_disabled:
        configured_proxied_tools = (
            configured_proxied_tools - toolsets[Toolset.DISCOVERY]
        )
    return configured_proxied_tools


async def register_proxied_tools(
    dbt_mcp: FastMCP,
    config_provider: ConfigProvider[ProxiedToolConfig],
    exclude_tools: Sequence[ToolName] = [],
) -> None:
    """
    Register proxied MCP tools.

    Proxied tools are hosted remotely, so their definitions aren't found in this repo.
    """
    config = await config_provider.get_config()
    configured_proxied_tools = resolve_proxied_tools_configuration(
        config, exclude_tools
    )
    if not configured_proxied_tools:
        return
    headers = config.headers_provider.get_headers()
    if config.prod_environment_id:
        headers["x-dbt-prod-environment-id"] = str(config.prod_environment_id)
    if config.dev_environment_id:
        headers["x-dbt-dev-environment-id"] = str(config.dev_environment_id)
    if config.user_id:
        headers["x-dbt-user-id"] = str(config.user_id)
    proxied_tools_manager = ProxiedToolsManager()
    try:
        session = await proxied_tools_manager.get_remote_mcp_session(
            config.url, headers
        )
        await session.initialize()
        tools = await get_proxied_tools(session, configured_proxied_tools)
    except BaseException as e:
        logger.error(f"Error getting proxied tools: {e}")
        return
    logger.info(f"Loaded proxied tools: {', '.join([tool.name for tool in tools])}")
    for tool in tools:
        # Create a new function using a factory to avoid closure issues
        def create_tool_function(tool_name: str):
            async def tool_function(*args, **kwargs) -> Sequence[ContentBlock]:
                tool_call_result = await session.call_tool(
                    tool_name,
                    kwargs,
                )
                if tool_call_result.isError:
                    raise RemoteToolError(
                        f"Tool {tool_name} reported an error: "
                        + f"{tool_call_result.content}"
                    )
                return tool_call_result.content

            return tool_function

        dbt_mcp._tool_manager._tools[tool.name] = InternalTool(
            fn=create_tool_function(tool.name),
            title=tool.title,
            name=tool.name,
            annotations=tool.annotations,
            description=tool.description or "",
            parameters=tool.inputSchema,
            fn_metadata=get_remote_tool_fn_metadata(tool),
            is_async=True,
            context_kwarg=None,
        )
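# Minimal usage sketch (assumptions noted inline): it shows how a caller might
# wire register_proxied_tools into a FastMCP server. `my_config_provider` is a
# hypothetical ConfigProvider[ProxiedToolConfig]; how it is constructed depends
# on dbt_mcp's configuration setup and is not shown here.
#
#     import asyncio
#     from mcp.server.fastmcp import FastMCP
#
#     async def main() -> None:
#         dbt_mcp = FastMCP("dbt")
#         # my_config_provider is assumed to yield a ProxiedToolConfig with the
#         # remote MCP URL, headers provider, and environment/user IDs set.
#         await register_proxied_tools(dbt_mcp, my_config_provider)
#         # ... run the server, then tear down the shared remote session:
#         await ProxiedToolsManager.close()
#
#     asyncio.run(main())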
