"""
Databricks client wrapper with async support for Unity Catalog operations.
"""
import asyncio
import functools
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

from databricks.sdk import WorkspaceClient
from databricks.sdk.core import Config
from databricks.sdk.service.catalog import CatalogInfo, SchemaInfo, TableInfo
from databricks.sdk.service.sql import StatementState
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


class TableMetadata(BaseModel):
"""Enhanced table metadata model."""
name: str
catalog_name: str
schema_name: str
table_type: str
data_source_format: Optional[str] = None
comment: Optional[str] = None
owner: Optional[str] = None
created_at: Optional[str] = None
created_by: Optional[str] = None
updated_at: Optional[str] = None
updated_by: Optional[str] = None
columns: List[Dict[str, Any]] = Field(default_factory=list)
properties: Dict[str, Any] = Field(default_factory=dict)
    storage_location: Optional[str] = None


class QueryResult(BaseModel):
"""Query execution result model."""
status: str
data: List[Dict[str, Any]] = Field(default_factory=list)
columns: List[str] = Field(default_factory=list)
row_count: int = 0
execution_time_ms: Optional[int] = None
    error: Optional[str] = None


class DatabricksClient:
"""Async wrapper for Databricks SDK operations."""
def __init__(self, workspace_client: WorkspaceClient):
"""Private constructor. Use create() class method instead."""
self.workspace_client = workspace_client
self.executor = ThreadPoolExecutor(max_workers=4)
@classmethod
async def create(cls, config: Optional[Config] = None) -> "DatabricksClient":
"""Create and initialize a DatabricksClient instance."""
try:
# Run client initialization in thread pool to avoid blocking
            loop = asyncio.get_running_loop()
workspace_client = await loop.run_in_executor(
None, cls._create_workspace_client, config or Config()
)
logger.info("Databricks client initialized successfully")
return cls(workspace_client)
except Exception as e:
logger.error(f"Failed to initialize Databricks client: {e}")
raise
@staticmethod
def _create_workspace_client(config: Config) -> WorkspaceClient:
"""Create workspace client (runs in thread pool)."""
return WorkspaceClient(config=config)
    async def _run_sync(self, func, *args, **kwargs) -> Any:
        """Run a synchronous function in the thread pool."""
        loop = asyncio.get_running_loop()
        # run_in_executor does not forward keyword arguments, so bind them with partial first.
        return await loop.run_in_executor(self.executor, functools.partial(func, *args, **kwargs))
async def list_catalogs(self) -> List[CatalogInfo]:
"""List all catalogs in Unity Catalog."""
def _list_catalogs():
return list(self.workspace_client.catalogs.list())
return await self._run_sync(_list_catalogs)
async def get_catalog_info(self, catalog_name: str) -> CatalogInfo:
"""Get detailed information about a catalog."""
def _get_catalog():
return self.workspace_client.catalogs.get(catalog_name)
return await self._run_sync(_get_catalog)
async def list_schemas(self, catalog_name: str) -> List[SchemaInfo]:
"""List all schemas in a catalog."""
def _list_schemas():
return list(self.workspace_client.schemas.list(catalog_name=catalog_name))
return await self._run_sync(_list_schemas)
async def get_schema_info(self, catalog_name: str, schema_name: str) -> SchemaInfo:
"""Get detailed information about a schema."""
def _get_schema():
return self.workspace_client.schemas.get(
full_name=f"{catalog_name}.{schema_name}"
)
return await self._run_sync(_get_schema)
async def list_tables(self, catalog_name: str, schema_name: str) -> List[TableInfo]:
"""List all tables in a schema."""
def _list_tables():
return list(self.workspace_client.tables.list(
catalog_name=catalog_name,
schema_name=schema_name
))
return await self._run_sync(_list_tables)
async def get_table_info(self, catalog_name: str, schema_name: str, table_name: str) -> TableInfo:
"""Get detailed information about a table."""
def _get_table():
return self.workspace_client.tables.get(
full_name=f"{catalog_name}.{schema_name}.{table_name}"
)
return await self._run_sync(_get_table)
async def describe_table(self, catalog_name: str, schema_name: str, table_name: str) -> TableMetadata:
"""Get enhanced table metadata including columns and properties."""
table_info = await self.get_table_info(catalog_name, schema_name, table_name)
# Convert to our enhanced metadata format
columns = []
if table_info.columns:
for col in table_info.columns:
columns.append({
"name": col.name,
"type": col.type_name,
"nullable": col.nullable,
"comment": col.comment,
"type_precision": col.type_precision,
"type_scale": col.type_scale,
})
return TableMetadata(
name=table_info.name or "",
catalog_name=table_info.catalog_name or "",
schema_name=table_info.schema_name or "",
table_type=table_info.table_type.value if table_info.table_type else "UNKNOWN",
data_source_format=table_info.data_source_format.value if table_info.data_source_format else None,
comment=table_info.comment,
owner=table_info.owner,
created_at=str(table_info.created_at) if table_info.created_at else None,
created_by=table_info.created_by,
updated_at=str(table_info.updated_at) if table_info.updated_at else None,
updated_by=table_info.updated_by,
columns=columns,
properties=table_info.properties or {},
storage_location=table_info.storage_location,
)
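    # Example use of describe_table (an illustrative sketch; the catalog, schema, and
    # table names below are hypothetical):
    #     meta = await client.describe_table("main", "sales", "orders")
    #     print(meta.table_type, [c["name"] for c in meta.columns])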
    async def sample_table(self, catalog_name: str, schema_name: str, table_name: str, limit: int = 10) -> QueryResult:
        """Sample data from a table."""
        # Backtick-quote the identifiers and coerce the limit so unusual table names or a
        # non-integer limit cannot produce malformed SQL.
        full_table_name = f"`{catalog_name}`.`{schema_name}`.`{table_name}`"
        query = f"SELECT * FROM {full_table_name} LIMIT {int(limit)}"
        return await self.execute_query(query)
async def execute_query(self, query: str, warehouse_id: Optional[str] = None) -> QueryResult:
"""Execute a SQL query and return results."""
def _execute_query():
try:
# Use the first available warehouse if none specified
if not warehouse_id:
warehouses = list(self.workspace_client.warehouses.list())
if not warehouses:
raise ValueError("No SQL warehouses available")
selected_warehouse_id = warehouses[0].id
if not selected_warehouse_id:
raise ValueError("Selected warehouse has no ID")
else:
selected_warehouse_id = warehouse_id
# Execute the query
execution = self.workspace_client.statement_execution.execute_statement(
statement=query,
warehouse_id=selected_warehouse_id,
wait_timeout="30s"
)
# Ensure we have a statement_id
if not execution.statement_id:
raise ValueError("No statement ID returned from execution")
                # Poll until the statement reaches a terminal state. The state field is a
                # StatementState enum, so compare against enum members rather than strings.
                while execution.status and execution.status.state in (
                    StatementState.PENDING,
                    StatementState.RUNNING,
                ):
                    time.sleep(0.1)  # Small delay to avoid hammering the API
                    if not execution.statement_id:
                        raise ValueError("Statement ID became None during execution")
                    execution = self.workspace_client.statement_execution.get_statement(execution.statement_id)
                if execution.status and execution.status.state == StatementState.SUCCEEDED:
result_data = execution.result.data_array if execution.result else []
columns = []
if execution.manifest and execution.manifest.schema and execution.manifest.schema.columns:
columns = [col.name for col in execution.manifest.schema.columns if col.name]
# Convert result data to dict format
formatted_data = []
for row in result_data or []:
row_dict = {}
for i, value in enumerate(row):
if i < len(columns):
row_dict[columns[i]] = value
formatted_data.append(row_dict)
return QueryResult(
status="SUCCESS",
data=formatted_data,
columns=columns,
row_count=len(formatted_data),
execution_time_ms=getattr(execution.status, 'duration_ms', None)
)
else:
error_msg = "Unknown error"
if execution.status and hasattr(execution.status, 'error') and execution.status.error:
error_msg = getattr(execution.status.error, 'message', str(execution.status.error))
return QueryResult(
status="ERROR",
error=error_msg
)
except Exception as e:
return QueryResult(
status="ERROR",
error=str(e)
)
return await self._run_sync(_execute_query)
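    # Example use of execute_query (a sketch; assumes at least one SQL warehouse exists
    # and authentication is resolved from the environment):
    #     result = await client.execute_query("SELECT 1 AS one")
    #     if result.status == "SUCCESS":
    #         print(result.columns, result.data)
    #     else:
    #         print(result.error)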
async def get_table_lineage(self, catalog_name: str, schema_name: str, table_name: str) -> Dict[str, Any]:
"""Get lineage information for a table."""
def _get_lineage():
try:
full_name = f"{catalog_name}.{schema_name}.{table_name}"
# Note: Lineage API might not be available in all Databricks versions
# This is a placeholder implementation
return {
"table": full_name,
"upstream_tables": [],
"downstream_tables": [],
"lineage_available": False,
"message": "Lineage API not implemented in this version"
}
except Exception as e:
return {
"table": f"{catalog_name}.{schema_name}.{table_name}",
"error": str(e)
}
return await self._run_sync(_get_lineage)
async def search_tables(self, search_term: str, catalog_name: Optional[str] = None,
schema_name: Optional[str] = None) -> List[TableInfo]:
"""Search for tables by name or metadata."""
def _search_tables():
results = []
# Get list of catalogs to search
if catalog_name:
catalogs_to_search = [catalog_name]
else:
catalogs_to_search = [cat.name for cat in self.workspace_client.catalogs.list() if cat.name]
for cat_name in catalogs_to_search:
if not cat_name:
continue
try:
# Get schemas in this catalog
if schema_name:
schemas_to_search = [schema_name]
else:
schemas_to_search = [schema.name for schema in self.workspace_client.schemas.list(catalog_name=cat_name) if schema.name]
for schema_name_iter in schemas_to_search:
if not schema_name_iter:
continue
try:
# Get tables in this schema
tables = list(self.workspace_client.tables.list(
catalog_name=cat_name,
schema_name=schema_name_iter
))
# Filter tables by search term
for table in tables:
if table.name and (search_term.lower() in table.name.lower() or
(table.comment and search_term.lower() in table.comment.lower())):
results.append(table)
except Exception as e:
logger.warning(f"Error searching schema {cat_name}.{schema_name_iter}: {e}")
continue
except Exception as e:
logger.warning(f"Error searching catalog {cat_name}: {e}")
continue
return results
return await self._run_sync(_search_tables)
async def close(self) -> None:
"""Close the client and clean up resources."""
if self.executor:
self.executor.shutdown(wait=True)
logger.info("Databricks client closed")