Skip to main content
Glama
databricks_client.py — 13.9 kB
""" Databricks client wrapper with async support for Unity Catalog operations. """ import asyncio import logging from typing import Any, Dict, List, Optional from concurrent.futures import ThreadPoolExecutor from databricks.sdk import WorkspaceClient from databricks.sdk.core import Config from databricks.sdk.service.catalog import CatalogInfo, SchemaInfo, TableInfo, ColumnInfo from databricks.sdk.service.sql import StatementStatus from pydantic import BaseModel, Field logger = logging.getLogger(__name__) class TableMetadata(BaseModel): """Enhanced table metadata model.""" name: str catalog_name: str schema_name: str table_type: str data_source_format: Optional[str] = None comment: Optional[str] = None owner: Optional[str] = None created_at: Optional[str] = None created_by: Optional[str] = None updated_at: Optional[str] = None updated_by: Optional[str] = None columns: List[Dict[str, Any]] = Field(default_factory=list) properties: Dict[str, Any] = Field(default_factory=dict) storage_location: Optional[str] = None class QueryResult(BaseModel): """Query execution result model.""" status: str data: List[Dict[str, Any]] = Field(default_factory=list) columns: List[str] = Field(default_factory=list) row_count: int = 0 execution_time_ms: Optional[int] = None error: Optional[str] = None class DatabricksClient: """Async wrapper for Databricks SDK operations.""" def __init__(self, workspace_client: WorkspaceClient): """Private constructor. 
Use create() class method instead.""" self.workspace_client = workspace_client self.executor = ThreadPoolExecutor(max_workers=4) @classmethod async def create(cls, config: Optional[Config] = None) -> "DatabricksClient": """Create and initialize a DatabricksClient instance.""" try: # Run client initialization in thread pool to avoid blocking loop = asyncio.get_event_loop() workspace_client = await loop.run_in_executor( None, cls._create_workspace_client, config or Config() ) logger.info("Databricks client initialized successfully") return cls(workspace_client) except Exception as e: logger.error(f"Failed to initialize Databricks client: {e}") raise @staticmethod def _create_workspace_client(config: Config) -> WorkspaceClient: """Create workspace client (runs in thread pool).""" return WorkspaceClient(config=config) async def _run_sync(self, func, *args, **kwargs) -> Any: """Run a synchronous function in the thread pool.""" loop = asyncio.get_event_loop() return await loop.run_in_executor(self.executor, func, *args, **kwargs) async def list_catalogs(self) -> List[CatalogInfo]: """List all catalogs in Unity Catalog.""" def _list_catalogs(): return list(self.workspace_client.catalogs.list()) return await self._run_sync(_list_catalogs) async def get_catalog_info(self, catalog_name: str) -> CatalogInfo: """Get detailed information about a catalog.""" def _get_catalog(): return self.workspace_client.catalogs.get(catalog_name) return await self._run_sync(_get_catalog) async def list_schemas(self, catalog_name: str) -> List[SchemaInfo]: """List all schemas in a catalog.""" def _list_schemas(): return list(self.workspace_client.schemas.list(catalog_name=catalog_name)) return await self._run_sync(_list_schemas) async def get_schema_info(self, catalog_name: str, schema_name: str) -> SchemaInfo: """Get detailed information about a schema.""" def _get_schema(): return self.workspace_client.schemas.get( full_name=f"{catalog_name}.{schema_name}" ) return await 
self._run_sync(_get_schema) async def list_tables(self, catalog_name: str, schema_name: str) -> List[TableInfo]: """List all tables in a schema.""" def _list_tables(): return list(self.workspace_client.tables.list( catalog_name=catalog_name, schema_name=schema_name )) return await self._run_sync(_list_tables) async def get_table_info(self, catalog_name: str, schema_name: str, table_name: str) -> TableInfo: """Get detailed information about a table.""" def _get_table(): return self.workspace_client.tables.get( full_name=f"{catalog_name}.{schema_name}.{table_name}" ) return await self._run_sync(_get_table) async def describe_table(self, catalog_name: str, schema_name: str, table_name: str) -> TableMetadata: """Get enhanced table metadata including columns and properties.""" table_info = await self.get_table_info(catalog_name, schema_name, table_name) # Convert to our enhanced metadata format columns = [] if table_info.columns: for col in table_info.columns: columns.append({ "name": col.name, "type": col.type_name, "nullable": col.nullable, "comment": col.comment, "type_precision": col.type_precision, "type_scale": col.type_scale, }) return TableMetadata( name=table_info.name or "", catalog_name=table_info.catalog_name or "", schema_name=table_info.schema_name or "", table_type=table_info.table_type.value if table_info.table_type else "UNKNOWN", data_source_format=table_info.data_source_format.value if table_info.data_source_format else None, comment=table_info.comment, owner=table_info.owner, created_at=str(table_info.created_at) if table_info.created_at else None, created_by=table_info.created_by, updated_at=str(table_info.updated_at) if table_info.updated_at else None, updated_by=table_info.updated_by, columns=columns, properties=table_info.properties or {}, storage_location=table_info.storage_location, ) async def sample_table(self, catalog_name: str, schema_name: str, table_name: str, limit: int = 10) -> QueryResult: """Sample data from a table.""" 
full_table_name = f"{catalog_name}.{schema_name}.{table_name}" query = f"SELECT * FROM {full_table_name} LIMIT {limit}" return await self.execute_query(query) async def execute_query(self, query: str, warehouse_id: Optional[str] = None) -> QueryResult: """Execute a SQL query and return results.""" def _execute_query(): try: # Use the first available warehouse if none specified if not warehouse_id: warehouses = list(self.workspace_client.warehouses.list()) if not warehouses: raise ValueError("No SQL warehouses available") selected_warehouse_id = warehouses[0].id if not selected_warehouse_id: raise ValueError("Selected warehouse has no ID") else: selected_warehouse_id = warehouse_id # Execute the query execution = self.workspace_client.statement_execution.execute_statement( statement=query, warehouse_id=selected_warehouse_id, wait_timeout="30s" ) # Ensure we have a statement_id if not execution.statement_id: raise ValueError("No statement ID returned from execution") # Wait for completion import time while execution.status and execution.status.state in ["PENDING", "RUNNING"]: time.sleep(0.1) # Small delay to avoid hammering the API if not execution.statement_id: raise ValueError("Statement ID became None during execution") execution = self.workspace_client.statement_execution.get_statement(execution.statement_id) if execution.status and execution.status.state == "SUCCEEDED": result_data = execution.result.data_array if execution.result else [] columns = [] if execution.manifest and execution.manifest.schema and execution.manifest.schema.columns: columns = [col.name for col in execution.manifest.schema.columns if col.name] # Convert result data to dict format formatted_data = [] for row in result_data or []: row_dict = {} for i, value in enumerate(row): if i < len(columns): row_dict[columns[i]] = value formatted_data.append(row_dict) return QueryResult( status="SUCCESS", data=formatted_data, columns=columns, row_count=len(formatted_data), 
execution_time_ms=getattr(execution.status, 'duration_ms', None) ) else: error_msg = "Unknown error" if execution.status and hasattr(execution.status, 'error') and execution.status.error: error_msg = getattr(execution.status.error, 'message', str(execution.status.error)) return QueryResult( status="ERROR", error=error_msg ) except Exception as e: return QueryResult( status="ERROR", error=str(e) ) return await self._run_sync(_execute_query) async def get_table_lineage(self, catalog_name: str, schema_name: str, table_name: str) -> Dict[str, Any]: """Get lineage information for a table.""" def _get_lineage(): try: full_name = f"{catalog_name}.{schema_name}.{table_name}" # Note: Lineage API might not be available in all Databricks versions # This is a placeholder implementation return { "table": full_name, "upstream_tables": [], "downstream_tables": [], "lineage_available": False, "message": "Lineage API not implemented in this version" } except Exception as e: return { "table": f"{catalog_name}.{schema_name}.{table_name}", "error": str(e) } return await self._run_sync(_get_lineage) async def search_tables(self, search_term: str, catalog_name: Optional[str] = None, schema_name: Optional[str] = None) -> List[TableInfo]: """Search for tables by name or metadata.""" def _search_tables(): results = [] # Get list of catalogs to search if catalog_name: catalogs_to_search = [catalog_name] else: catalogs_to_search = [cat.name for cat in self.workspace_client.catalogs.list() if cat.name] for cat_name in catalogs_to_search: if not cat_name: continue try: # Get schemas in this catalog if schema_name: schemas_to_search = [schema_name] else: schemas_to_search = [schema.name for schema in self.workspace_client.schemas.list(catalog_name=cat_name) if schema.name] for schema_name_iter in schemas_to_search: if not schema_name_iter: continue try: # Get tables in this schema tables = list(self.workspace_client.tables.list( catalog_name=cat_name, schema_name=schema_name_iter )) # Filter 
tables by search term for table in tables: if table.name and (search_term.lower() in table.name.lower() or (table.comment and search_term.lower() in table.comment.lower())): results.append(table) except Exception as e: logger.warning(f"Error searching schema {cat_name}.{schema_name_iter}: {e}") continue except Exception as e: logger.warning(f"Error searching catalog {cat_name}: {e}") continue return results return await self._run_sync(_search_tables) async def close(self) -> None: """Close the client and clean up resources.""" if self.executor: self.executor.shutdown(wait=True) logger.info("Databricks client closed")

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/knustx/databricks-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.