#!/usr/bin/env python3
"""
Databricks MCP Server
A Model Context Protocol server that provides integration with Databricks Unity Catalog.
Supports querying metadata, sampling data, and using the Databricks SDK.
"""
import asyncio
import logging
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from mcp.server import Server
from mcp.server.lowlevel.server import NotificationOptions
from mcp.server.models import InitializationOptions
from mcp.server.stdio import stdio_server
from mcp.types import (
Resource,
Tool,
TextContent,
LoggingLevel,
Prompt,
GetPromptResult
)
from pydantic import Field, AnyUrl
from .databricks_client import DatabricksClient
from .unity_catalog import UnityCatalogManager
from .prompts import PromptsManager
from .utils import setup_logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class DatabricksMCPServer:
    """Main MCP server for Databricks integration.

    Owns the MCP ``Server`` object and registers the resource, tool and
    prompt handlers on it. The Databricks client is created lazily by
    ``initialize()``; until then the resource and tool handlers refuse to run.
    """

    def __init__(self):
        """Create the MCP server object and register all handlers on it."""
        self.server = Server("databricks-mcp-server")
        # Both are populated by initialize(); None means "not connected yet".
        self.databricks_client: Optional[DatabricksClient] = None
        self.unity_catalog: Optional[UnityCatalogManager] = None
        self.prompts_manager = PromptsManager()
        self._setup_handlers()

    @staticmethod
    def _resource_path_parts(uri: str) -> List[str]:
        """Split a ``databricks://`` resource URI into its path segments.

        ``list_resources`` advertises URIs shaped like
        ``databricks://catalog/<name>``. ``urlparse`` puts the ``catalog``
        part of such a URI into ``netloc`` rather than ``path``, so it is
        re-attached here as the first segment. URIs that carry ``catalog``
        inside the path (e.g. ``databricks:///catalog/<name>``) yield the
        same segments. Empty segments are dropped.

        Args:
            uri: The resource URI as a string.

        Returns:
            The segments, e.g. ``["catalog", "main", "sales"]``.

        Raises:
            ValueError: If the URI scheme is not ``databricks``.
        """
        parsed_uri = urlparse(uri)
        if parsed_uri.scheme != "databricks":
            raise ValueError(f"Unsupported URI scheme: {parsed_uri.scheme}")
        segments = [part for part in parsed_uri.path.split("/") if part]
        if parsed_uri.netloc == "catalog":
            segments.insert(0, parsed_uri.netloc)
        return segments

    def _setup_handlers(self) -> None:
        """Register the resource, tool and prompt handlers on ``self.server``."""

        @self.server.list_resources()
        async def list_resources() -> List[Resource]:
            """List the catalogs as resources.

            Returns an empty list when the client is not initialized. If
            listing fails, the error is logged and whatever was collected
            so far is returned.
            """
            resources = []
            if self.databricks_client:
                try:
                    catalogs = await self.databricks_client.list_catalogs()
                    for catalog_info in catalogs:
                        resources.append(
                            Resource(
                                uri=AnyUrl(f"databricks://catalog/{catalog_info.name}"),
                                name=f"Catalog: {catalog_info.name}",
                                description=catalog_info.comment or "Databricks catalog",
                                mimeType="application/json",
                            )
                        )
                except Exception as e:
                    logger.error("Failed to list catalogs: %s", e)
            return resources

        @self.server.read_resource()
        async def read_resource(uri: AnyUrl) -> str:
            """Read a catalog, schema or table resource.

            Accepted shapes are ``catalog/<catalog>``,
            ``catalog/<catalog>/<schema>`` and
            ``catalog/<catalog>/<schema>/<table>``; segments after the table
            name are ignored.

            Raises:
                ValueError: If the client is not initialized, the scheme is
                    not ``databricks``, or the URI has none of the shapes above.
            """
            if not self.databricks_client:
                raise ValueError("Databricks client not initialized")

            path_parts = self._resource_path_parts(str(uri))
            if len(path_parts) >= 2 and path_parts[0] == "catalog":
                catalog_name = path_parts[1]
                if len(path_parts) == 2:
                    catalog_info = await self.databricks_client.get_catalog_info(
                        catalog_name
                    )
                    return str(catalog_info)

                schema_name = path_parts[2]
                if len(path_parts) == 3:
                    schema_info = await self.databricks_client.get_schema_info(
                        catalog_name, schema_name
                    )
                    return str(schema_info)

                table_name = path_parts[3]
                table_info = await self.databricks_client.get_table_info(
                    catalog_name, schema_name, table_name
                )
                return str(table_info)

            raise ValueError(f"Invalid resource URI: {uri}")

        @self.server.list_tools()
        async def list_tools() -> List[Tool]:
            """Describe every tool that ``call_tool`` can dispatch."""

            def text_field(description: str) -> Dict[str, Any]:
                # JSON-schema entry for one string argument.
                return {"type": "string", "description": description}

            def object_schema(
                properties: Dict[str, Any], required: List[str]
            ) -> Dict[str, Any]:
                # JSON-schema wrapper for a tool's argument object.
                return {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                }

            def table_fields() -> Dict[str, Any]:
                # Fresh dict per call so the schemas never share mutable state.
                return {
                    "catalog_name": text_field("Name of the catalog"),
                    "schema_name": text_field("Name of the schema"),
                    "table_name": text_field("Name of the table"),
                }

            table_keys = ("catalog_name", "schema_name", "table_name")

            return [
                Tool(
                    name="list_catalogs",
                    description="List all Unity Catalog catalogs",
                    inputSchema=object_schema({}, []),
                ),
                Tool(
                    name="list_schemas",
                    description="List schemas in a catalog",
                    inputSchema=object_schema(
                        {"catalog_name": text_field("Name of the catalog")},
                        ["catalog_name"],
                    ),
                ),
                Tool(
                    name="list_tables",
                    description="List tables in a schema",
                    inputSchema=object_schema(
                        {
                            "catalog_name": text_field("Name of the catalog"),
                            "schema_name": text_field("Name of the schema"),
                        },
                        ["catalog_name", "schema_name"],
                    ),
                ),
                Tool(
                    name="describe_table",
                    description="Get detailed information about a table",
                    inputSchema=object_schema(table_fields(), list(table_keys)),
                ),
                Tool(
                    name="sample_table",
                    description="Sample data from a table",
                    inputSchema=object_schema(
                        {
                            **table_fields(),
                            "limit": {
                                "type": "integer",
                                "description": "Number of rows to sample",
                                "default": 10,
                            },
                        },
                        list(table_keys),
                    ),
                ),
                Tool(
                    name="execute_query",
                    description="Execute a SQL query",
                    inputSchema=object_schema(
                        {
                            "query": text_field("SQL query to execute"),
                            "warehouse_id": text_field(
                                "Warehouse ID to use for query execution"
                            ),
                        },
                        ["query"],
                    ),
                ),
                Tool(
                    name="get_table_lineage",
                    description="Get lineage information for a table",
                    inputSchema=object_schema(table_fields(), list(table_keys)),
                ),
                Tool(
                    name="search_tables",
                    description="Search for tables by name or metadata",
                    inputSchema=object_schema(
                        {
                            "search_term": text_field(
                                "Search term to look for in table names or metadata"
                            ),
                            "catalog_name": text_field(
                                "Optional catalog name to limit search"
                            ),
                            "schema_name": text_field(
                                "Optional schema name to limit search"
                            ),
                        },
                        ["search_term"],
                    ),
                ),
            ]

        @self.server.call_tool()
        async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
            """Dispatch a tool call to the Databricks client.

            The client's result is returned as its ``str()`` form in a single
            text item. Any exception raised while handling the call
            (including an unknown tool name or a missing required argument)
            is logged and returned as text instead of being propagated.

            Raises:
                ValueError: If the Databricks client is not initialized.
            """
            if not self.databricks_client:
                raise ValueError("Databricks client not initialized")

            try:
                if name == "list_catalogs":
                    result = await self.databricks_client.list_catalogs()
                elif name == "list_schemas":
                    result = await self.databricks_client.list_schemas(
                        arguments["catalog_name"]
                    )
                elif name == "list_tables":
                    result = await self.databricks_client.list_tables(
                        arguments["catalog_name"], arguments["schema_name"]
                    )
                elif name == "describe_table":
                    result = await self.databricks_client.describe_table(
                        arguments["catalog_name"],
                        arguments["schema_name"],
                        arguments["table_name"],
                    )
                elif name == "sample_table":
                    catalog_name = arguments["catalog_name"]
                    schema_name = arguments["schema_name"]
                    table_name = arguments["table_name"]
                    limit = arguments.get("limit", 10)
                    result = await self.databricks_client.sample_table(
                        catalog_name, schema_name, table_name, limit
                    )
                elif name == "execute_query":
                    query = arguments["query"]
                    warehouse_id = arguments.get("warehouse_id")
                    result = await self.databricks_client.execute_query(
                        query, warehouse_id
                    )
                elif name == "get_table_lineage":
                    result = await self.databricks_client.get_table_lineage(
                        arguments["catalog_name"],
                        arguments["schema_name"],
                        arguments["table_name"],
                    )
                elif name == "search_tables":
                    search_term = arguments["search_term"]
                    catalog_name = arguments.get("catalog_name")
                    schema_name = arguments.get("schema_name")
                    result = await self.databricks_client.search_tables(
                        search_term, catalog_name, schema_name
                    )
                else:
                    raise ValueError(f"Unknown tool: {name}")
                return [TextContent(type="text", text=str(result))]
            except Exception as e:
                # NOTE(review): failures come back as ordinary text content;
                # confirm callers do not need an explicit error signal.
                error_msg = f"Error executing tool {name}: {str(e)}"
                logger.error(error_msg)
                return [TextContent(type="text", text=error_msg)]

        @self.server.list_prompts()
        async def list_prompts() -> List[Prompt]:
            """List the prompts provided by the prompts manager."""
            return self.prompts_manager.list_prompts()

        @self.server.get_prompt()
        async def get_prompt(
            name: str, arguments: Optional[Dict[str, str]] = None
        ) -> GetPromptResult:
            """Return the named prompt, built by the prompts manager."""
            return self.prompts_manager.get_prompt(name, arguments)

    async def initialize(self) -> None:
        """Create the Databricks client and the Unity Catalog manager.

        Raises:
            Exception: Whatever client creation raised, after logging it.
        """
        try:
            self.databricks_client = await DatabricksClient.create()
            self.unity_catalog = UnityCatalogManager(self.databricks_client)
            logger.info("Databricks MCP server initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Databricks client: %s", e)
            raise

    async def run(self) -> None:
        """Initialize the client, then run the MCP server over stdio streams."""
        await self.initialize()
        async with stdio_server() as (read_stream, write_stream):
            await self.server.run(
                read_stream,
                write_stream,
                InitializationOptions(
                    server_name="databricks-mcp-server",
                    server_version="0.1.0",
                    capabilities=self.server.get_capabilities(
                        notification_options=NotificationOptions(),
                        experimental_capabilities={},
                    ),
                ),
            )
async def main() -> None:
    """Set up logging, then build the server and run it."""
    setup_logging()
    await DatabricksMCPServer().run()
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())