import asyncio
import json
import logging
import os
from pathlib import Path
import sys
from typing import Any, Dict, List
from mcp.server import FastMCP
from mcp.types import TextContent
from src.api import sql
from src.core.utils import DatabricksAPIError
# Configure logging to a file; the LOG_LEVEL env var selects verbosity.
logging.basicConfig(
    # Normalize case and fall back to INFO for unset/unknown level names:
    # a bare getattr(logging, ...) would raise AttributeError at import time
    # for values like "debug" or "verbose".
    level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO),
    filename="databricks_mcp.log",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
class DatabricksMCPServer(FastMCP):
    """An MCP server for Databricks APIs.

    Exposes resources pointing at local schema-reference documentation for
    the gold and silver catalogs, a resource producing SQL GRANT statements
    for resolving permission errors, and an ``execute_sql`` tool that runs
    statements against a Databricks SQL warehouse.
    """

    # Catalog -> schemas covered by the grant-catalog-access resource.
    _GRANT_SCHEMAS: Dict[str, List[str]] = {
        "us_dpe_production_gold": ["bi", "core", "bi_marketing"],
        "us_dpe_production_silver": [
            "base",
            "intermediate",
            "datamarts",
            "fed_base",
            "pre_calc",
            "ref_all_fed",
        ],
    }

    def __init__(self):
        """Initialize the Databricks MCP server and register its tools/resources."""
        super().__init__(
            name="databricks-mcp",
            version="1.0.0",
            instructions="Use this server to manage Databricks resources",
        )
        logger.info("Initializing Databricks MCP server")
        # Register tools and resources
        self._register_tools()
        self._register_resources()

    @staticmethod
    def _schema_reference_path(layer: str) -> str:
        """Return the absolute path to a layer's schema reference markdown file.

        Args:
            layer: Data layer name ("gold" or "silver"); selects the
                skills/databricks-schemas-<layer>.md file.

        Returns:
            The absolute file path as a string, so the agent can read the
            file as needed.
        """
        # This module sits three directory levels below the project root
        # (assumes src/<pkg>/<module>.py layout — TODO confirm).
        project_root = Path(__file__).resolve().parent.parent.parent
        schema_path = project_root / "skills" / f"databricks-schemas-{layer}.md"
        logger.info("Returning %s schema reference path: %s", layer, schema_path)
        return str(schema_path)

    @classmethod
    def _grant_statements(cls, catalog: str, email: str) -> str:
        """Build GRANT statements for *catalog* and its configured schemas.

        Emits one USE CATALOG grant followed by USE SCHEMA + SELECT grants
        for each schema listed in ``_GRANT_SCHEMAS[catalog]``; every
        statement is newline-terminated.
        """
        statements = [f"GRANT USE CATALOG ON CATALOG {catalog} TO `{email}`;\n"]
        for schema in cls._GRANT_SCHEMAS[catalog]:
            statements.append(
                f"GRANT USE SCHEMA ON SCHEMA {catalog}.{schema} TO `{email}`;\n"
            )
            statements.append(
                f"GRANT SELECT ON SCHEMA {catalog}.{schema} TO `{email}`;\n"
            )
        return "".join(statements)

    def _register_resources(self):
        """Register all Databricks MCP resources."""

        @self.resource(
            uri="databricks://schemas/gold-catalog-reference",
            name="databricks_gold_schema_reference",
            description=(
                "Reference documentation for table schemas in the gold data layer "
                "(us_dpe_production_gold catalog) in Databricks. Contains table names, "
                "column definitions, data types, and nullability for all gold-layer "
                "tables. Returns the local file path so the agent can read the file "
                "as needed."
            ),
            mime_type="text/plain",
        )
        def get_gold_schema_reference() -> str:
            """Return the local file path to the gold catalog schema reference."""
            return self._schema_reference_path("gold")

        @self.resource(
            uri="databricks://schemas/silver-catalog-reference",
            name="databricks_silver_schema_reference",
            description=(
                "Reference documentation for table schemas in the silver data layer "
                "(us_dpe_production_silver catalog) in Databricks. Contains table names, "
                "column definitions, data types, and nullability for all silver-layer "
                "tables. Returns the local file path so the agent can read the file "
                "as needed."
            ),
            mime_type="text/plain",
        )
        def get_silver_schema_reference() -> str:
            """Return the local file path to the silver catalog schema reference."""
            return self._schema_reference_path("silver")

        @self.resource(
            uri="databricks://permissions/grant-catalog-access",
            name="databricks_grant_catalog_access",
            description=(
                "Instructions for granting the current user access to the "
                "us_dpe_production_gold and us_dpe_production_silver catalogs "
                "and their schemas in Databricks. "
                "Use this resource when a SQL query fails due to insufficient "
                "permissions, such as missing access to a table, schema, or catalog. "
                "Returns a set of SQL GRANT statements that should be executed via "
                "the execute_sql tool to resolve the permission error."
            ),
            mime_type="text/plain",
        )
        def get_grant_catalog_access() -> str:
            """Return GRANT statements for the current user's catalog access.

            The user's email is derived from the USER environment variable;
            an error message string is returned when it is unset (common on
            Windows) rather than raising.
            """
            try:
                email = os.environ["USER"] + "@forhims.com"
            except KeyError:
                return "Error: USER environment variable is not set."
            logger.info("Returning grant catalog access SQL for user: %s", email)
            gold_grants = self._grant_statements("us_dpe_production_gold", email)
            silver_grants = self._grant_statements("us_dpe_production_silver", email)
            # Blank line between the gold and silver grant groups.
            return gold_grants + "\n" + silver_grants

    def _register_tools(self):
        """Register all Databricks MCP tools."""

        # SQL tools
        @self.tool(
            name="execute_sql",
            description="Execute a SQL statement with parameters: statement (required), catalog (optional), schema (optional)",
        )
        async def execute_sql(params: Dict[str, Any]) -> List[TextContent]:
            """Run a SQL statement on the configured SQL warehouse.

            Args:
                params: Dict with required "statement" and optional
                    "catalog"/"schema" overrides.

            Returns:
                A single-item list with a TextContent whose text is the
                JSON-encoded result, or a JSON-encoded {"error": ...}
                payload on failure.

            Raises:
                ValueError: If DATABRICKS_WAREHOUSE_ID is not set.
            """
            logger.info("Executing SQL with params: %s", params)
            warehouse_id = os.environ.get("DATABRICKS_WAREHOUSE_ID")
            if not warehouse_id:
                raise ValueError(
                    "DATABRICKS_WAREHOUSE_ID environment variable is not set"
                )
            statement = params.get("statement")
            if not statement:
                # The tool contract declares "statement" required; fail fast
                # with a structured error instead of sending None downstream.
                return [
                    TextContent(
                        type="text",
                        text=json.dumps(
                            {"error": "Missing required parameter: statement"}
                        ),
                    )
                ]
            try:
                result = await sql.execute_and_wait(
                    statement,
                    warehouse_id,
                    params.get("catalog"),
                    params.get("schema"),
                )
                # Fix: the declared return type is List[TextContent]; bare
                # {"text": ...} dicts lacked the required "type" field.
                return [TextContent(type="text", text=json.dumps(result))]
            except (DatabricksAPIError, TimeoutError, ValueError) as e:
                logger.error("Error executing SQL: %s", e)
                return [TextContent(type="text", text=json.dumps({"error": str(e)}))]
async def main():
    """Main entry point for the MCP server.

    Constructs a DatabricksMCPServer and serves it over stdio; any failure
    during startup or serving is logged with a traceback and re-raised.
    """
    try:
        logger.info("Starting Databricks MCP server")
        # stdio is the recommended transport for MCP servers launched as
        # child processes.
        await DatabricksMCPServer().run_stdio_async()
    except Exception as e:
        logger.error(f"Error in Databricks MCP server: {str(e)}", exc_info=True)
        raise
if __name__ == "__main__":
# Turn off buffering in stdout
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(line_buffering=True)
asyncio.run(main())