Databricks MCP Server

by JordiNeil
Verified
"""MCP server exposing a Databricks SQL warehouse and the Databricks Jobs API.

Credentials are read from the environment (optionally via a ``.env`` file):
``DATABRICKS_HOST``, ``DATABRICKS_TOKEN`` and ``DATABRICKS_HTTP_PATH``.
"""

import datetime
import json
import logging
import os
from contextlib import closing
from typing import Any, Dict, Iterable, List, Optional, Sequence

import requests
from databricks.sql import connect
from databricks.sql.client import Connection
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP

# Diagnostics go through ``logging`` rather than ``print``.
# NOTE(review): if the server is run over a stdio transport, anything written
# to stdout would interleave with protocol messages - confirm the transport.
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Get Databricks credentials from environment variables
DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.getenv("DATABRICKS_TOKEN")
DATABRICKS_HTTP_PATH = os.getenv("DATABRICKS_HTTP_PATH")

# Upper bound (seconds) for a single REST call so a stalled endpoint cannot
# block a tool invocation forever.
REQUEST_TIMEOUT_SECONDS = 30

# Set up the MCP server
mcp = FastMCP("Databricks API Explorer")


def _md_cell(value: Any) -> str:
    """Render *value* as text that is safe inside one markdown table cell."""
    # A raw pipe would start a new column and a newline would end the row.
    text = str(value).replace("|", "\\|")
    return " ".join(text.splitlines())


def _md_table(
    headers: Sequence[Any],
    rows: Iterable[Sequence[Any]],
    *,
    rule_width: Optional[int] = None,
) -> str:
    """Build a markdown table; every line, including the last, ends in a newline.

    Args:
        headers: Column titles.
        rows: Row values; each cell is stringified and escaped.
        rule_width: Number of dashes per column in the separator row. When
            omitted, each column uses the length of its own header.
    """
    lines = [
        "| " + " | ".join(_md_cell(h) for h in headers) + " |",
        "| "
        + " | ".join("-" * (rule_width or len(str(h))) for h in headers)
        + " |",
    ]
    lines.extend(
        "| " + " | ".join(_md_cell(cell) for cell in row) + " |" for row in rows
    )
    return "\n".join(lines) + "\n"


def _format_epoch_ms(epoch_ms: Optional[int]) -> str:
    """Format an epoch-milliseconds value as local time, or "N/A" if missing/zero."""
    if not epoch_ms:
        return "N/A"
    moment = datetime.datetime.fromtimestamp(epoch_ms / 1000)
    return moment.strftime('%Y-%m-%d %H:%M:%S')


# Helper function to get a Databricks SQL connection
def get_databricks_connection() -> Connection:
    """Create and return a Databricks SQL connection.

    Raises:
        ValueError: If the host, token or HTTP path is not configured.
    """
    if not all([DATABRICKS_HOST, DATABRICKS_TOKEN, DATABRICKS_HTTP_PATH]):
        raise ValueError("Missing required Databricks connection details in .env file")

    return connect(
        server_hostname=DATABRICKS_HOST,
        http_path=DATABRICKS_HTTP_PATH,
        access_token=DATABRICKS_TOKEN,
    )


# Helper function for Databricks REST API requests
def databricks_api_request(
    endpoint: str, method: str = "GET", data: Optional[Dict] = None
) -> Dict:
    """Make a request to the Databricks REST API (``/api/2.0``).

    Args:
        endpoint: Path below ``/api/2.0/``, e.g. ``"jobs/list"``.
        method: ``"GET"`` or ``"POST"`` (case-insensitive).
        data: For GET, sent as query-string parameters; for POST, sent as the
            JSON body.

    Returns:
        The decoded JSON response.

    Raises:
        ValueError: If credentials are missing or *method* is unsupported.
        requests.HTTPError: If the server answers with an error status.
    """
    if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]):
        raise ValueError("Missing required Databricks API credentials in .env file")

    headers = {
        "Authorization": f"Bearer {DATABRICKS_TOKEN}",
        "Content-Type": "application/json",
    }
    url = f"https://{DATABRICKS_HOST}/api/2.0/{endpoint}"

    verb = method.upper()
    if verb == "GET":
        # Previously ``data`` was silently dropped for GET, so filters such as
        # ``job_id`` never reached the server.
        response = requests.get(
            url, headers=headers, params=data, timeout=REQUEST_TIMEOUT_SECONDS
        )
    elif verb == "POST":
        response = requests.post(
            url, headers=headers, json=data, timeout=REQUEST_TIMEOUT_SECONDS
        )
    else:
        raise ValueError(f"Unsupported HTTP method: {method}")

    response.raise_for_status()
    return response.json()


@mcp.resource("schema://tables")
def get_schema() -> str:
    """Provide the list of tables in the Databricks SQL warehouse as a resource

    Returns one line per table, or an error message if the lookup fails.
    """
    try:
        # Opening the connection inside the ``try`` means connection failures
        # are reported the same way as query failures; ``closing`` guarantees
        # both cursor and connection are released.
        with closing(get_databricks_connection()) as conn, \
                closing(conn.cursor()) as cursor:
            tables = cursor.tables().fetchall()
            return "\n".join(
                f"Database: {table.TABLE_CAT}, Schema: {table.TABLE_SCHEM}, "
                f"Table: {table.TABLE_NAME}"
                for table in tables
            )
    except Exception as e:  # tool boundary: report instead of crashing
        logger.exception("Failed to retrieve tables")
        return f"Error retrieving tables: {str(e)}"


@mcp.tool()
def run_sql_query(sql: str) -> str:
    """Execute SQL queries on Databricks SQL warehouse

    Args:
        sql: The statement to run, passed to the warehouse unchanged.

    Returns:
        A markdown table of the result set, a success notice when there is no
        result set, or an error message if execution fails.
    """
    # NOTE(review): the statement is executed verbatim with the configured
    # token's privileges; restrict that token if callers are not trusted.
    logger.info("Executing SQL: %s", sql)
    try:
        with closing(get_databricks_connection()) as conn, \
                closing(conn.cursor()) as cursor:
            logger.info("connected")
            result = cursor.execute(sql)

            if not result.description:
                return "Query executed successfully. No results returned."

            # Get column names
            columns = [col[0] for col in result.description]
            rows = result.fetchall()
            if not rows:
                return "Query executed successfully. No results returned."

            # Format as markdown table
            return _md_table(columns, rows, rule_width=3)
    except Exception as e:  # tool boundary: report instead of crashing
        logger.exception("Query execution failed")
        return f"Error executing query: {str(e)}"


@mcp.tool()
def list_jobs() -> str:
    """List all Databricks jobs

    Returns a markdown table with job ID, name and creator, or an error message.
    """
    try:
        response = databricks_api_request("jobs/list")
        jobs = response.get("jobs")
        if not jobs:
            return "No jobs found."

        rows = [
            (
                job.get("job_id", "N/A"),
                job.get("settings", {}).get("name", "N/A"),
                # Same field ``get_job_details`` reads; ``created_by`` is kept
                # as a fallback for backward compatibility.
                job.get("creator_user_name", job.get("created_by", "N/A")),
            )
            for job in jobs
        ]
        return _md_table(["Job ID", "Job Name", "Created By"], rows)
    except Exception as e:  # tool boundary: report instead of crashing
        logger.exception("Failed to list jobs")
        return f"Error listing jobs: {str(e)}"


@mcp.tool()
def get_job_status(job_id: int) -> str:
    """Get the status of a specific Databricks job

    Args:
        job_id: Identifier of the job whose runs should be listed.

    Returns:
        A markdown table of the job's runs (state, start, end, duration), or
        an error message.
    """
    try:
        response = databricks_api_request("jobs/runs/list", data={"job_id": job_id})
        runs = response.get("runs")
        if not runs:
            return f"No runs found for job ID {job_id}."

        rows = []
        for run in runs:
            state_info = run.get("state", {})
            # Runs without a ``result_state`` fall back to the life-cycle
            # state instead of showing "N/A".
            state = (
                state_info.get("result_state")
                or state_info.get("life_cycle_state")
                or "N/A"
            )

            start_time = run.get("start_time", 0)
            end_time = run.get("end_time", 0)
            if start_time and end_time:
                duration = f"{(end_time - start_time) / 1000:.2f}s"
            else:
                duration = "N/A"

            rows.append(
                (
                    run.get("run_id", "N/A"),
                    state,
                    _format_epoch_ms(start_time),
                    _format_epoch_ms(end_time),
                    duration,
                )
            )

        return _md_table(
            ["Run ID", "State", "Start Time", "End Time", "Duration"], rows
        )
    except Exception as e:  # tool boundary: report instead of crashing
        logger.exception("Failed to get job status")
        return f"Error getting job status: {str(e)}"


@mcp.tool()
def get_job_details(job_id: int) -> str:
    """Get detailed information about a specific Databricks job

    Args:
        job_id: Identifier of the job to describe.

    Returns:
        A markdown summary (name, ID, creation time, creator and, if present,
        a task table), or an error message.
    """
    try:
        response = databricks_api_request(
            "jobs/get", method="GET", data={"job_id": job_id}
        )
        settings = response.get("settings", {})
        job_name = settings.get("name", "N/A")
        created_time_str = _format_epoch_ms(response.get("created_time", 0))

        result = (
            f"## Job Details: {job_name}\n\n"
            f"- **Job ID:** {job_id}\n"
            f"- **Created:** {created_time_str}\n"
            f"- **Creator:** {response.get('creator_user_name', 'N/A')}\n\n"
        )

        tasks = settings.get("tasks", [])
        if tasks:
            rows = [
                (
                    task.get("task_key", "N/A"),
                    # The task type is taken from the first key ending in
                    # "_task" (e.g. "notebook_task").
                    next((key for key in task if key.endswith("_task")), "N/A"),
                    task.get("description", "N/A"),
                )
                for task in tasks
            ]
            result += "### Tasks:\n\n"
            result += _md_table(["Task Key", "Task Type", "Description"], rows)

        return result
    except Exception as e:  # tool boundary: report instead of crashing
        logger.exception("Failed to get job details")
        return f"Error getting job details: {str(e)}"


if __name__ == "__main__":
    mcp.run()