Skip to main content
Glama

dbt-mcp

Official
by dbt-labs
client.py (14.9 kB)
from collections.abc import Callable
from contextlib import AbstractContextManager
from typing import Any, Protocol

import pyarrow as pa
from dbtsl.api.shared.query_params import (
    GroupByParam,
    OrderByGroupBy,
    OrderByMetric,
    OrderBySpec,
)
from dbtsl.client.sync import SyncSemanticLayerClient
from dbtsl.error import QueryFailedError

from dbt_mcp.config.config_providers import ConfigProvider, SemanticLayerConfig
from dbt_mcp.errors import InvalidParameterError
from dbt_mcp.semantic_layer.gql.gql import GRAPHQL_QUERIES
from dbt_mcp.semantic_layer.gql.gql_request import submit_request
from dbt_mcp.semantic_layer.levenshtein import get_misspellings
from dbt_mcp.semantic_layer.types import (
    DimensionToolResponse,
    EntityToolResponse,
    GetMetricsCompiledSqlError,
    GetMetricsCompiledSqlResult,
    GetMetricsCompiledSqlSuccess,
    MetricToolResponse,
    OrderByParam,
    QueryMetricsError,
    QueryMetricsResult,
    QueryMetricsSuccess,
    SavedQueryToolResponse,
)


def DEFAULT_RESULT_FORMATTER(table: pa.Table) -> str:
    """Serialize an Arrow table to an indented JSON array of records.

    Goes through pandas so that the Semantic Layer's Arrow result can be
    rendered as plain JSON text for tool output.
    """
    frame = table.to_pandas()
    return frame.to_json(orient="records", indent=2)


class SemanticLayerClientProtocol(Protocol):
    """Structural interface for a synchronous Semantic Layer SDK client.

    Matches the subset of ``SyncSemanticLayerClient`` that this module uses,
    so alternative clients can be substituted in tests.
    """

    def session(self) -> AbstractContextManager[Any]:
        """Open a connection session; must be used as a context manager."""
        ...

    def query(
        self,
        metrics: list[str],
        group_by: list[GroupByParam | str] | None = None,
        limit: int | None = None,
        order_by: list[str | OrderByGroupBy | OrderByMetric] | None = None,
        where: list[str] | None = None,
        read_cache: bool = True,
    ) -> pa.Table:
        """Execute a metrics query and return the result as an Arrow table."""
        ...

    def compile_sql(
        self,
        metrics: list[str],
        group_by: list[str] | None = None,
        limit: int | None = None,
        order_by: list[str | OrderByGroupBy | OrderByMetric] | None = None,
        where: list[str] | None = None,
        read_cache: bool = True,
    ) -> str:
        """Return the SQL the Semantic Layer would run, without executing it."""
        ...


class SemanticLayerClientProvider(Protocol):
    """Structural interface for asynchronously obtaining a client instance."""

    async def get_client(self) -> SemanticLayerClientProtocol:
        """Build (or fetch) a Semantic Layer client ready for use."""
        ...
class DefaultSemanticLayerClientProvider:
    """Default provider that builds a sync SDK client from resolved config."""

    def __init__(self, config_provider: ConfigProvider[SemanticLayerConfig]):
        self.config_provider = config_provider

    async def get_client(self) -> SemanticLayerClientProtocol:
        """Resolve configuration and construct a ``SyncSemanticLayerClient``."""
        config = await self.config_provider.get_config()
        return SyncSemanticLayerClient(
            environment_id=config.prod_environment_id,
            auth_token=config.token,
            host=config.host,
        )


class SemanticLayerFetcher:
    """Fetches metadata and runs metric queries against the dbt Semantic Layer.

    Dimension and entity lookups are cached per (metrics, search) combination
    to avoid repeated GraphQL round trips within a session.
    """

    def __init__(
        self,
        config_provider: ConfigProvider[SemanticLayerConfig],
        client_provider: SemanticLayerClientProvider,
    ):
        self.client_provider = client_provider
        self.config_provider = config_provider
        # Keys are built by _cache_key and encode both the sorted metric list
        # and the search term, so different searches never share stale results.
        self.entities_cache: dict[str, list[EntityToolResponse]] = {}
        self.dimensions_cache: dict[str, list[DimensionToolResponse]] = {}

    @staticmethod
    def _cache_key(metrics: list[str], search: str | None) -> str:
        """Build a cache key covering both the metric set and the search term.

        Previously only the metrics were keyed, so a second call with a
        different ``search`` returned the first call's cached results.
        """
        return f"{','.join(sorted(metrics))}|search={search}"

    async def list_metrics(self, search: str | None = None) -> list[MetricToolResponse]:
        """Fetch metrics from the Semantic Layer API, optionally filtered."""
        metrics_result = submit_request(
            await self.config_provider.get_config(),
            {"query": GRAPHQL_QUERIES["metrics"], "variables": {"search": search}},
        )
        return [
            MetricToolResponse(
                name=m.get("name"),
                type=m.get("type"),
                label=m.get("label"),
                description=m.get("description"),
                metadata=(m.get("config") or {}).get("meta", ""),
            )
            for m in metrics_result["data"]["metricsPaginated"]["items"]
        ]

    async def list_saved_queries(
        self, search: str | None = None
    ) -> list[SavedQueryToolResponse]:
        """Fetch all saved queries from the Semantic Layer API."""
        saved_queries_result = submit_request(
            await self.config_provider.get_config(),
            {
                "query": GRAPHQL_QUERIES["saved_queries"],
                "variables": {"search": search},
            },
        )
        responses = []
        for sq in saved_queries_result["data"]["savedQueriesPaginated"]["items"]:
            # Hoist the nested lookup once instead of repeating it per field.
            query_params = sq.get("queryParams", {})
            responses.append(
                SavedQueryToolResponse(
                    name=sq.get("name"),
                    label=sq.get("label"),
                    description=sq.get("description"),
                    metrics=[m.get("name") for m in query_params.get("metrics", [])]
                    if query_params.get("metrics")
                    else None,
                    group_by=[g.get("name") for g in query_params.get("groupBy", [])]
                    if query_params.get("groupBy")
                    else None,
                    where=query_params.get("where", {}).get("whereSqlTemplate")
                    if query_params.get("where")
                    else None,
                )
            )
        return responses

    async def get_dimensions(
        self, metrics: list[str], search: str | None = None
    ) -> list[DimensionToolResponse]:
        """Fetch (and cache) the dimensions queryable for the given metrics."""
        cache_key = self._cache_key(metrics, search)
        if cache_key not in self.dimensions_cache:
            dimensions_result = submit_request(
                await self.config_provider.get_config(),
                {
                    "query": GRAPHQL_QUERIES["dimensions"],
                    "variables": {
                        "metrics": [{"name": m} for m in metrics],
                        "search": search,
                    },
                },
            )
            dimensions = []
            for d in dimensions_result["data"]["dimensionsPaginated"]["items"]:
                dimensions.append(
                    DimensionToolResponse(
                        name=d.get("name"),
                        type=d.get("type"),
                        description=d.get("description"),
                        label=d.get("label"),
                        # Either granularity list may be absent/None in the
                        # API response; default to [] so concatenation never
                        # raises TypeError.
                        granularities=(d.get("queryableGranularities") or [])
                        + (d.get("queryableTimeGranularities") or []),
                    )
                )
            self.dimensions_cache[cache_key] = dimensions
        return self.dimensions_cache[cache_key]

    async def get_entities(
        self, metrics: list[str], search: str | None = None
    ) -> list[EntityToolResponse]:
        """Fetch (and cache) the entities queryable for the given metrics."""
        cache_key = self._cache_key(metrics, search)
        if cache_key not in self.entities_cache:
            entities_result = submit_request(
                await self.config_provider.get_config(),
                {
                    "query": GRAPHQL_QUERIES["entities"],
                    "variables": {
                        "metrics": [{"name": m} for m in metrics],
                        "search": search,
                    },
                },
            )
            entities = [
                EntityToolResponse(
                    name=e.get("name"),
                    type=e.get("type"),
                    description=e.get("description"),
                )
                for e in entities_result["data"]["entitiesPaginated"]["items"]
            ]
            self.entities_cache[cache_key] = entities
        return self.entities_cache[cache_key]

    def _format_semantic_layer_error(self, error: Exception) -> str:
        """Format semantic layer errors by cleaning up common error message patterns."""
        error_str = str(error)
        return (
            error_str.replace("QueryFailedError(", "")
            .rstrip(")")
            .lstrip("[")
            .rstrip("]")
            .lstrip('"')
            .rstrip('"')
            .replace("INVALID_ARGUMENT: [FlightSQL]", "")
            .replace("(InvalidArgument; Prepare)", "")
            .replace("(InvalidArgument; ExecuteQuery)", "")
            .replace("Failed to prepare statement:", "")
            .replace("com.dbt.semanticlayer.exceptions.DataPlatformException:", "")
            .strip()
        )

    def _format_get_metrics_compiled_sql_error(
        self, compile_error: Exception
    ) -> GetMetricsCompiledSqlError:
        """Format get compiled SQL errors using the shared error formatter."""
        return GetMetricsCompiledSqlError(
            error=self._format_semantic_layer_error(compile_error)
        )

    async def validate_query_metrics_params(
        self, metrics: list[str], group_by: list[GroupByParam] | None
    ) -> str | None:
        """Validate metric and group-by names, suggesting close matches.

        Returns a user-facing error string when any name is unknown, or
        ``None`` when all names validate. Group-by validation is skipped
        when a metric name is invalid, since dimension/entity lookups
        require valid metrics.
        """
        errors = []
        available_metrics_names = [m.name for m in await self.list_metrics()]
        metric_misspellings = get_misspellings(
            targets=metrics,
            words=available_metrics_names,
            top_k=5,
        )
        for metric_misspelling in metric_misspellings:
            recommendations = (
                " Did you mean: " + ", ".join(metric_misspelling.similar_words) + "?"
            )
            errors.append(
                f"Metric {metric_misspelling.word} not found."
                + (recommendations if metric_misspelling.similar_words else "")
            )
        if errors:
            return f"Errors: {', '.join(errors)}"

        available_group_by = [d.name for d in await self.get_dimensions(metrics)] + [
            e.name for e in await self.get_entities(metrics)
        ]
        group_by_misspellings = get_misspellings(
            targets=[g.name for g in group_by or []],
            words=available_group_by,
            top_k=5,
        )
        for group_by_misspelling in group_by_misspellings:
            recommendations = (
                " Did you mean: " + ", ".join(group_by_misspelling.similar_words) + "?"
            )
            errors.append(
                f"Group by {group_by_misspelling.word} not found."
                + (recommendations if group_by_misspelling.similar_words else "")
            )
        if errors:
            return f"Errors: {', '.join(errors)}"
        return None

    # TODO: move this to the SDK
    def _format_query_failed_error(self, query_error: Exception) -> QueryMetricsError:
        """Wrap a query exception in QueryMetricsError, cleaning SDK noise."""
        if isinstance(query_error, QueryFailedError):
            return QueryMetricsError(
                error=self._format_semantic_layer_error(query_error)
            )
        else:
            return QueryMetricsError(error=str(query_error))

    def _get_order_bys(
        self,
        order_by: list[OrderByParam] | None,
        # None (not a mutable [] default) — list defaults are shared across
        # calls; treated as an empty metric list when omitted.
        metrics: list[str] | None = None,
        group_by: list[GroupByParam] | None = None,
    ) -> list[OrderBySpec]:
        """Resolve order-by names to SDK OrderByMetric/OrderByGroupBy specs.

        Raises:
            InvalidParameterError: when a name is in neither the queried
                metrics nor the queried group-bys.
        """
        result: list[OrderBySpec] = []
        if order_by is None:
            return result
        queried_group_by = {g.name: g for g in group_by} if group_by else {}
        queried_metrics = set(metrics or [])
        for o in order_by:
            if o.name in queried_metrics:
                result.append(OrderByMetric(name=o.name, descending=o.descending))
            elif o.name in queried_group_by:
                selected_group_by = queried_group_by[o.name]
                result.append(
                    OrderByGroupBy(
                        name=selected_group_by.name,
                        descending=o.descending,
                        grain=selected_group_by.grain,
                    )
                )
            else:
                raise InvalidParameterError(
                    f"Order by `{o.name}` not found in metrics or group by"
                )
        return result

    async def get_metrics_compiled_sql(
        self,
        metrics: list[str],
        group_by: list[GroupByParam] | None = None,
        order_by: list[OrderByParam] | None = None,
        where: str | None = None,
        limit: int | None = None,
    ) -> GetMetricsCompiledSqlResult:
        """
        Get compiled SQL for the given metrics and group by parameters using the SDK.

        Args:
            metrics: List of metric names to get compiled SQL for
            group_by: List of group by parameters (dimensions/entities with
                optional grain)
            order_by: List of order by parameters
            where: Optional SQL WHERE clause to filter results
            limit: Optional limit for number of results

        Returns:
            GetMetricsCompiledSqlResult with either the compiled SQL or an error
        """
        validation_error = await self.validate_query_metrics_params(
            metrics=metrics,
            group_by=group_by,
        )
        if validation_error:
            return GetMetricsCompiledSqlError(error=validation_error)
        try:
            sl_client = await self.client_provider.get_client()
            with sl_client.session():
                parsed_order_by: list[OrderBySpec] = self._get_order_bys(
                    order_by=order_by, metrics=metrics, group_by=group_by
                )
                compiled_sql = sl_client.compile_sql(
                    metrics=metrics,
                    group_by=group_by,  # type: ignore
                    order_by=parsed_order_by,  # type: ignore
                    where=[where] if where else None,
                    limit=limit,
                    read_cache=True,
                )
                return GetMetricsCompiledSqlSuccess(sql=compiled_sql)
        except Exception as e:
            return self._format_get_metrics_compiled_sql_error(e)

    async def query_metrics(
        self,
        metrics: list[str],
        group_by: list[GroupByParam] | None = None,
        order_by: list[OrderByParam] | None = None,
        where: str | None = None,
        limit: int | None = None,
        result_formatter: Callable[[pa.Table], str] = DEFAULT_RESULT_FORMATTER,
    ) -> QueryMetricsResult:
        """Run a metrics query and return formatted results or an error.

        Validates names first (returning misspelling suggestions on failure),
        then executes the query inside an SDK session.
        """
        validation_error = await self.validate_query_metrics_params(
            metrics=metrics,
            group_by=group_by,
        )
        if validation_error:
            return QueryMetricsError(error=validation_error)

        try:
            query_error = None
            sl_client = await self.client_provider.get_client()
            with sl_client.session():
                # Catching any exception within the session
                # to ensure it is closed properly
                try:
                    parsed_order_by: list[OrderBySpec] = self._get_order_bys(
                        order_by=order_by, metrics=metrics, group_by=group_by
                    )
                    query_result = sl_client.query(
                        metrics=metrics,
                        group_by=group_by,  # type: ignore
                        order_by=parsed_order_by,  # type: ignore
                        where=[where] if where else None,
                        limit=limit,
                    )
                except Exception as e:
                    query_error = e
            if query_error:
                return self._format_query_failed_error(query_error)
            json_result = result_formatter(query_result)
            return QueryMetricsSuccess(result=json_result or "")
        except Exception as e:
            return self._format_query_failed_error(e)

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/dbt-labs/dbt-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.