"""MCP Databricks Server."""

import asyncio
import os
from typing import Any, Dict, Optional

import httpx
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Configuration constants
DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST", "")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN", "")
DATABRICKS_SQL_WAREHOUSE_ID = os.environ.get("DATABRICKS_SQL_WAREHOUSE_ID", "")

# API endpoints
STATEMENTS_API = "/api/2.0/sql/statements"
STATEMENT_API = "/api/2.0/sql/statements/{statement_id}"

# Per-request HTTP timeout, in seconds.
REQUEST_TIMEOUT_SECONDS = 30.0

# Polling defaults: 60 polls * 10 seconds = 10 minutes.
STATEMENT_POLL_INTERVAL_SECONDS = 10.0
STATEMENT_MAX_POLLS = 60

# States after which the statement can no longer succeed, so further polling
# is pointless.
STATEMENT_FAILURE_STATES = ("FAILED", "CANCELED", "CLOSED")


async def make_databricks_request(
    method: str,
    endpoint: str,
    json_data: Optional[Dict[str, Any]] = None,
    params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Make a request to the Databricks API with proper error handling.

    Args:
        method: HTTP verb, case-insensitive. Only "get" and "post" are supported.
        endpoint: API path appended to ``DATABRICKS_HOST`` (e.g. ``STATEMENTS_API``).
        json_data: JSON body, sent for POST requests only.
        params: Query-string parameters, sent for GET requests only.

    Returns:
        The decoded JSON body of the response.

    Raises:
        ValueError: If ``method`` is unsupported or ``DATABRICKS_HOST`` is empty.
        Exception: If the server answers with an error status, or the request
            fails for any other reason (network error, invalid JSON, ...).
            The original exception is attached as ``__cause__``.
    """
    # Validate arguments up front so that caller mistakes surface as
    # ValueError instead of being re-wrapped as an API failure below.
    verb = method.lower()
    if verb not in ("get", "post"):
        raise ValueError(f"Unsupported HTTP method: {method}")

    # Read the module global at call time so later reconfiguration is honoured.
    host = DATABRICKS_HOST.strip().rstrip("/")
    if not host:
        raise ValueError(
            "Databricks host is required. Set DATABRICKS_HOST environment variable."
        )
    if not host.startswith(("http://", "https://")):
        # Tolerate a bare hostname; httpx rejects URLs without a scheme.
        host = f"https://{host}"

    url = f"{host}{endpoint}"
    headers = {
        "Authorization": f"Bearer {DATABRICKS_TOKEN}",
        "Content-Type": "application/json",
    }

    async with httpx.AsyncClient() as client:
        try:
            if verb == "get":
                response = await client.get(
                    url,
                    headers=headers,
                    params=params,
                    timeout=REQUEST_TIMEOUT_SECONDS,
                )
            else:
                response = await client.post(
                    url,
                    headers=headers,
                    json=json_data,
                    timeout=REQUEST_TIMEOUT_SECONDS,
                )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            error_message = f"HTTP error: {e.response.status_code}"
            # Best effort: enrich the message with the server-provided detail
            # when the error body is a JSON object carrying a "message" field.
            try:
                error_detail = e.response.json()
            except ValueError:
                error_detail = None
            if isinstance(error_detail, dict) and error_detail.get("message"):
                error_message += f" - {error_detail['message']}"
            raise Exception(error_message) from e
        except Exception as e:
            raise Exception(f"Error making request to Databricks API: {str(e)}") from e


async def execute_statement(
    sql: str,
    warehouse_id: Optional[str] = None,
    *,
    max_retries: int = STATEMENT_MAX_POLLS,
    poll_interval: float = STATEMENT_POLL_INTERVAL_SECONDS,
) -> Dict[str, Any]:
    """Execute a SQL statement and wait for its completion.

    The statement is submitted without waiting (``wait_timeout`` of ``"0s"``)
    and its status is then polled until it reaches a terminal state or the
    polling budget is exhausted. On timeout this function only stops polling;
    it does not send a cancel request for the statement.

    Args:
        sql: The SQL text to execute.
        warehouse_id: SQL warehouse to run on. Falls back to
            ``DATABRICKS_SQL_WAREHOUSE_ID`` when omitted or empty.
        max_retries: Maximum number of status polls before giving up.
        poll_interval: Seconds to sleep between polls.

    Returns:
        The status response of the statement once its state is ``SUCCEEDED``.

    Raises:
        ValueError: If no warehouse ID is available.
        Exception: If the submission response has no statement ID, the
            statement ends in ``FAILED``, ``CANCELED`` or ``CLOSED``, polling
            times out, or an underlying API request fails.
    """
    if not warehouse_id:
        warehouse_id = DATABRICKS_SQL_WAREHOUSE_ID

    if not warehouse_id:
        raise ValueError(
            "Warehouse ID is required. Set DATABRICKS_SQL_WAREHOUSE_ID "
            "environment variable or provide it as a parameter."
        )

    # Create the statement
    statement_data = {
        "statement": sql,
        "warehouse_id": warehouse_id,
        "wait_timeout": "0s",  # Don't wait for completion in the initial request
    }

    response = await make_databricks_request(
        "post", STATEMENTS_API, json_data=statement_data
    )
    statement_id = response.get("statement_id")
    if not statement_id:
        raise Exception("Failed to get statement ID from response")

    # Poll for statement completion
    status_endpoint = STATEMENT_API.format(statement_id=statement_id)
    for _ in range(max_retries):
        statement_status = await make_databricks_request("get", status_endpoint)

        # "status" / "error" may be absent or null; treat both as empty.
        status_info = statement_status.get("status") or {}
        state = status_info.get("state")

        if state == "SUCCEEDED":
            return statement_status
        if state in STATEMENT_FAILURE_STATES:
            error_info = status_info.get("error") or {}
            error_message = (
                error_info.get("message") or f"Unknown error (state: {state})"
            )
            raise Exception(f"Statement execution failed: {error_message}")

        # Still pending/running: wait before polling again
        await asyncio.sleep(poll_interval)

    raise Exception("Statement execution timed out")