"""
Utility functions for the Databricks MCP Server.
"""
import json
import logging
import os
import re
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional
def setup_logging(level: str = "INFO", log_file: Optional[str] = None) -> None:
    """Configure the root logger with a stderr handler and an optional file handler.

    Any handlers already attached to the root logger are detached and closed
    first. Calling this repeatedly therefore neither duplicates output nor
    leaves previously opened log files open.

    Args:
        level: Logging level name, case-insensitive (e.g. ``"DEBUG"``,
            ``"info"``). Names that do not resolve to a numeric level fall
            back to ``INFO``.
        log_file: Optional path of a file that should also receive log
            records. It is written as UTF-8.
    """
    # Resolve the level name. getattr() can also return non-level module
    # attributes (e.g. "BASIC_FORMAT" is a str, "raiseExceptions" a bool), so
    # accept only genuine integer levels and fall back to INFO otherwise.
    numeric_level = getattr(logging, level.upper(), logging.INFO)
    if not isinstance(numeric_level, int) or isinstance(numeric_level, bool):
        numeric_level = logging.INFO

    formatter = logging.Formatter(
        fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    root_logger = logging.getLogger()
    root_logger.setLevel(numeric_level)

    # Detach AND close existing handlers: removing a handler without closing
    # it leaves e.g. a previous FileHandler's file descriptor open.
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        handler.close()

    # Console output goes to stderr.
    console_handler = logging.StreamHandler(sys.stderr)
    console_handler.setLevel(numeric_level)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    if log_file:
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setLevel(numeric_level)
        file_handler.setFormatter(formatter)
        root_logger.addHandler(file_handler)

    # Quiet down chatty third-party libraries regardless of the chosen level.
    logging.getLogger("databricks").setLevel(logging.WARNING)
    logging.getLogger("urllib3").setLevel(logging.WARNING)
    logging.getLogger("httpx").setLevel(logging.WARNING)
def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Build the server configuration from an optional JSON file and the environment.

    Precedence, lowest to highest:

    1. The built-in default ``log_level = "INFO"``.
    2. Values from the JSON file at *config_path*.
    3. Environment variables that are actually set.

    Args:
        config_path: Path to a JSON file whose top-level value is an object.
            A falsy path, or one that is not an existing regular file, is
            ignored.

    Returns:
        A dict of configuration values. ``"log_level"`` is always present.

    Raises:
        ValueError: If the file is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``) or its top-level value is not an object.
    """
    config: Dict[str, Any] = {}

    # is_file() rather than exists(): a directory path would make open() fail.
    if config_path and Path(config_path).is_file():
        with open(config_path, 'r', encoding='utf-8') as f:
            file_config = json.load(f)
        if not isinstance(file_config, dict):
            raise ValueError(
                f"Config file {config_path} must contain a JSON object at the top level"
            )
        config.update(file_config)

    # Only variables that are actually set override file values. LOG_LEVEL
    # deliberately has no default here; otherwise the default would always
    # clobber a log_level provided by the config file.
    env_config = {
        "databricks_host": os.getenv("DATABRICKS_HOST"),
        "databricks_token": os.getenv("DATABRICKS_TOKEN"),
        "databricks_warehouse_id": os.getenv("DATABRICKS_WAREHOUSE_ID"),
        "log_level": os.getenv("LOG_LEVEL"),
        "log_file": os.getenv("LOG_FILE"),
    }
    config.update({k: v for k, v in env_config.items() if v is not None})

    # Apply the default only when no layer supplied a usable value.
    if config.get("log_level") is None:
        config["log_level"] = "INFO"
    return config
def format_table_info(table_info: Any) -> str:
    """Render a table-info object as pretty-printed JSON for display.

    Objects with a ``__dict__`` have their public attributes (names not
    starting with ``_``) serialized as indented JSON:

    - attribute values that themselves have a ``__dict__`` are converted
      with ``str()``;
    - lists become lists of strings;
    - any other value is handed to ``json.dumps`` with ``default=str``.

    Objects without a ``__dict__`` are returned as ``str(table_info)``. Any
    exception is reported as an error string instead of being raised.
    """
    def _simplify(field_value: Any) -> Any:
        # Nested objects and list items are flattened to their string form.
        if hasattr(field_value, '__dict__'):
            return str(field_value)
        if isinstance(field_value, list):
            return [str(element) for element in field_value]
        return field_value

    try:
        if not hasattr(table_info, '__dict__'):
            return str(table_info)
        public_fields = {
            attr_name: _simplify(attr_value)
            for attr_name, attr_value in vars(table_info).items()
            if not attr_name.startswith('_')
        }
        return json.dumps(public_fields, indent=2, default=str)
    except Exception as exc:
        return f"Error formatting table info: {exc}"
def format_query_result(result: Any) -> str:
    """Render a query result as pretty-printed JSON for display.

    The result is handled by the first rule that applies:

    - Objects exposing ``model_dump()`` (e.g. Pydantic models) are
      serialized from its return value.
    - Other objects with a ``__dict__`` are serialized from that dict,
      including underscore-prefixed attributes.
    - Anything else is returned as ``str(result)``.

    Non-JSON values are stringified via ``default=str``. Any exception is
    reported as an error string instead of being raised.
    """
    try:
        if hasattr(result, 'model_dump'):
            payload = result.model_dump()
        elif hasattr(result, '__dict__'):
            payload = result.__dict__
        else:
            return str(result)
        return json.dumps(payload, indent=2, default=str)
    except Exception as exc:
        return f"Error formatting query result: {exc}"
def sanitize_table_name(name: str) -> str:
    """Quote *name* as a backtick-delimited SQL identifier.

    Every embedded backtick is doubled so it cannot terminate the quoted
    identifier early, e.g. ``a`b`` becomes ```a``b```.
    """
    escaped = "``".join(name.split("`"))
    return "`" + escaped + "`"
def build_full_table_name(catalog: str, schema: str, table: str) -> str:
    """Return ``catalog.schema.table`` with each part backtick-quoted via ``sanitize_table_name``."""
    quoted_parts = (sanitize_table_name(part) for part in (catalog, schema, table))
    return ".".join(quoted_parts)
def parse_full_table_name(full_name: str) -> tuple[str, str, str]:
    """Split a fully qualified ``catalog.schema.table`` name into its three parts.

    This is the inverse of ``build_full_table_name``. Each part may be a bare
    identifier or a backtick-quoted one. Inside backticks:

    - dots belong to the identifier;
    - a doubled backtick stands for one literal backtick.

    Surrounding quotes are removed from the returned parts.

    >>> parse_full_table_name("`my.cat`.sales.`or``ders`")
    ('my.cat', 'sales', 'or`ders')

    Raises:
        ValueError: If there are not exactly three non-empty parts, or a
            backtick quote is left unterminated.
    """
    parts: list[str] = []
    current: list[str] = []
    in_quotes = False
    i = 0
    length = len(full_name)
    while i < length:
        ch = full_name[i]
        if in_quotes:
            if ch != "`":
                current.append(ch)
            elif i + 1 < length and full_name[i + 1] == "`":
                # Doubled backtick inside a quoted identifier -> literal backtick.
                current.append("`")
                i += 1
            else:
                in_quotes = False
        elif ch == "`":
            in_quotes = True
        elif ch == ".":
            # Only an unquoted dot separates name parts.
            parts.append("".join(current))
            current = []
        else:
            current.append(ch)
        i += 1

    if in_quotes:
        raise ValueError(
            f"Invalid table name format: {full_name}. Unterminated backtick quote"
        )
    parts.append("".join(current))

    if len(parts) != 3 or not all(parts):
        raise ValueError(f"Invalid table name format: {full_name}. Expected catalog.schema.table")
    catalog, schema, table = parts
    return catalog, schema, table
def estimate_query_cost(query: str) -> Dict[str, Any]:
    """Give a rough, keyword-based estimate of a SQL query's complexity.

    This is a case-insensitive pattern scan, not a SQL parser, so it can
    still be fooled by keywords inside string literals or comments.

    Returns:
        A dict with these keys:

        - ``has_join``, ``has_aggregation``, ``has_window_function``,
          ``has_subquery``: boolean flags.
        - ``approximate_complexity``: ``"low"``, ``"medium"`` or ``"high"``,
          derived from a weighted score: join 2, aggregation 1, window
          function 3, subquery 2. A score of 5 or more is "high"; 2 or more
          is "medium".
    """
    query_lower = query.lower()

    def _matches(*patterns: str) -> bool:
        return any(re.search(pattern, query_lower) for pattern in patterns)

    cost_factors: Dict[str, Any] = {
        # Word boundaries so identifiers such as "joined_at" are not a JOIN.
        "has_join": _matches(r"\bjoin\b"),
        "has_aggregation": _matches(
            r"\b(?:sum|count|avg|max|min)\s*\(",
            r"\bgroup\s+by\b",
        ),
        # Allow whitespace before "(" so "OVER (" is recognized as well.
        "has_window_function": _matches(
            r"\bover\s*\(",
            r"\bpartition\s+by\b",
            r"\b(?:row_number|rank|dense_rank|percent_rank)\s*\(",
        ),
        # A SELECT opened directly inside parentheses. Merely containing "("
        # (e.g. count(*)) is not a subquery.
        "has_subquery": _matches(r"\(\s*select\b"),
        "approximate_complexity": "low",
    }

    weights = {
        "has_join": 2,
        "has_aggregation": 1,
        "has_window_function": 3,
        "has_subquery": 2,
    }
    complexity_score = sum(weight for key, weight in weights.items() if cost_factors[key])

    if complexity_score >= 5:
        cost_factors["approximate_complexity"] = "high"
    elif complexity_score >= 2:
        cost_factors["approximate_complexity"] = "medium"
    return cost_factors
def validate_databricks_config() -> Dict[str, Any]:
    """Check the Databricks-related environment variables and report problems.

    Variables that are set but contain only whitespace are treated as
    missing.

    Returns:
        A dict with these keys:

        - ``"valid"``: ``False`` if any required variable
          (``DATABRICKS_HOST``, ``DATABRICKS_TOKEN``) is missing.
        - ``"errors"``: one message per missing required variable.
        - ``"warnings"``: problems with ``DATABRICKS_HOST``, namely no
          protocol, or plain ``http://``, which would send the token
          unencrypted.
        - ``"recommendations"``: optional settings worth configuring.
    """
    validation_result: Dict[str, Any] = {
        "valid": True,
        "errors": [],
        "warnings": [],
        "recommendations": [],
    }

    def _env(name: str) -> str:
        # Unset and whitespace-only values both normalize to "".
        return (os.getenv(name) or "").strip()

    for var in ("DATABRICKS_HOST", "DATABRICKS_TOKEN"):
        if not _env(var):
            validation_result["valid"] = False
            validation_result["errors"].append(f"Missing required environment variable: {var}")

    # URL schemes are case-insensitive, so compare on a lowered copy.
    host = _env("DATABRICKS_HOST").lower()
    if host:
        if host.startswith("http://"):
            validation_result["warnings"].append(
                "DATABRICKS_HOST uses http://; use https:// so the access token is not sent unencrypted"
            )
        elif not host.startswith("https://"):
            validation_result["warnings"].append(
                "DATABRICKS_HOST should include the protocol (https://)"
            )

    if not _env("DATABRICKS_WAREHOUSE_ID"):
        validation_result["recommendations"].append(
            "Consider setting DATABRICKS_WAREHOUSE_ID for better query performance"
        )

    return validation_result
def create_sample_config() -> Dict[str, Any]:
    """Return a fresh dict of placeholder settings suitable for a sample config file.

    A new dict, with new nested dicts, is built on every call, so callers
    may mutate the result freely.
    """
    server_section = {"name": "databricks-mcp-server", "version": "0.1.0"}
    cache_section = {"enable_caching": True, "cache_ttl_seconds": 3600}
    return dict(
        databricks_host="https://your-workspace.cloud.databricks.com",
        databricks_token="your-access-token",
        databricks_warehouse_id="your-warehouse-id",
        log_level="INFO",
        server_config=server_section,
        cache_config=cache_section,
    )
def get_timestamp() -> str:
    """Return the current time as a timezone-aware UTC ISO 8601 string (``+00:00`` offset)."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.isoformat()
def safe_json_dump(obj: Any, indent: int = 2) -> str:
    """Serialize *obj* to a JSON string without raising.

    Behavior:

    - Non-ASCII characters are kept as-is rather than escaped.
    - Values ``json`` cannot encode natively are converted with ``str()``.
    - If serialization still fails (e.g. circular references), an error
      message string is returned instead.
    """
    dump_options = {"indent": indent, "default": str, "ensure_ascii": False}
    try:
        return json.dumps(obj, **dump_options)
    except Exception as err:
        return f"Error serializing object: {err}"