# databricks_manager.py
import base64
import logging
import os
from typing import Any, Dict, Optional
import requests
from dotenv import load_dotenv
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import DatabricksError
from databricks.sdk.service import compute, jobs
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
load_dotenv()
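# Example .env entries this module expects (values are illustrative placeholders):
#   DATABRICKS_HOST=https://<your-workspace>.cloud.databricks.com
#   DATABRICKS_TOKEN=dapi<personal-access-token>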
class DatabricksManager:
"""
Handles Databricks operations with robust error handling and logging.
Supports code execution, job management, and DLT pipeline operations.
"""
def __init__(self):
"""Initialize Databricks workspace client"""
self.host = os.getenv("DATABRICKS_HOST")
self.token = os.getenv("DATABRICKS_TOKEN")
if not self.host or not self.token:
raise RuntimeError("DATABRICKS_HOST and DATABRICKS_TOKEN must be set in environment variables.")
self.client = WorkspaceClient(host=self.host, token=self.token)
self.max_retries = 3
self.retry_delay = 30 # seconds
    def submit_code(self, code: str, cluster_id: str) -> Any:
        """
        Submit Python code to a Databricks cluster for synchronous execution.
        Args:
            code: Python code to execute
            cluster_id: Target cluster ID
        Returns:
            The command execution response returned by the SDK
        """
try:
logger.info(f"Submitting code to cluster {cluster_id}")
            result = self.client.command_execution.execute(
                cluster_id=cluster_id,
                language=compute.Language.PYTHON,
                command=code,
            )
logger.info("Code submitted successfully")
return result
except DatabricksError as e:
logger.error(f"Databricks error in submit_code: {str(e)}")
raise Exception(f"Databricks error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error in submit_code: {str(e)}")
raise Exception(f"Execution error: {str(e)}")
def create_job(self, job_config: Dict[str, Any]) -> Dict[str, Any]:
"""
        Create a new Databricks job via the Jobs 2.1 REST API, which accepts a plain dict config without SDK dataclasses.
Args:
job_config: Job configuration dictionary
Returns:
Dict with job details
"""
try:
headers = {
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json"
}
create_url = f"{self.host.rstrip('/')}/api/2.1/jobs/create"
response = requests.post(create_url, headers=headers, json=job_config)
if response.status_code == 200:
job_info = response.json()
job_id = job_info.get("job_id")
if not job_id:
raise Exception(f"No job_id in response: {job_info}")
logger.info(f"Job created successfully with ID: {job_id}")
return job_info
else:
raise Exception(f"Failed to create job: {response.text}")
except Exception as e:
logger.error(f"Unexpected error in create_job: {str(e)}")
raise Exception(f"Job creation error: {str(e)}")
    def run_job(self, job_id: str) -> Any:
        """
        Run an existing Databricks job.
        Args:
            job_id: Job ID to run
        Returns:
            The run response object returned by jobs.run_now()
        """
try:
logger.info(f"Starting job {job_id}")
run = self.client.jobs.run_now(job_id=job_id)
run_id = getattr(run, 'run_id', None)
if not run_id:
raise Exception("Failed to get run_id from job execution")
logger.info(f"Job started successfully: Job ID {job_id}, Run ID {run_id}")
return run
except DatabricksError as e:
logger.error(f"Databricks error in run_job: {str(e)}")
raise Exception(f"Databricks error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error in run_job: {str(e)}")
raise Exception(f"Job execution error: {str(e)}")
    def create_dlt_pipeline(self, pipeline_config: Dict[str, Any]) -> Any:
        """
        Create a Delta Live Tables pipeline.
        Args:
            pipeline_config: Pipeline configuration, expanded as keyword
                arguments to pipelines.create()
        Returns:
            The pipeline creation response returned by the SDK
        """
try:
logger.info(f"Creating DLT pipeline with config: {pipeline_config.get('name', 'unnamed')}")
pipeline = self.client.pipelines.create(**pipeline_config)
pipeline_id = getattr(pipeline, 'pipeline_id', None)
if not pipeline_id:
raise Exception("Failed to get pipeline_id from pipeline creation")
logger.info(f"DLT pipeline created successfully with ID: {pipeline_id}")
return pipeline
except DatabricksError as e:
logger.error(f"Databricks error in create_dlt_pipeline: {str(e)}")
raise Exception(f"Databricks error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error in create_dlt_pipeline: {str(e)}")
raise Exception(f"Pipeline creation error: {str(e)}")
def get_job_error(self, run_id: str) -> Optional[str]:
"""
Get error details for a failed job run.
Args:
run_id: Run ID to check
Returns:
Error message if job failed, None otherwise
"""
try:
logger.info(f"Checking job error for run {run_id}")
run = self.client.jobs.get_run(run_id=run_id)
            if run.state and run.state.result_state == jobs.RunResultState.FAILED:
error_msg = run.state.state_message
logger.info(f"Job failed with error: {error_msg}")
return error_msg
else:
logger.info("Job did not fail")
return None
except DatabricksError as e:
logger.error(f"Databricks error in get_job_error: {str(e)}")
raise Exception(f"Databricks error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error in get_job_error: {str(e)}")
raise Exception(f"Error retrieval error: {str(e)}")
def check_job_status(self, job_id: str, run_id: str) -> Dict[str, Any]:
"""
Check the status of a job run.
Args:
job_id: Job ID
run_id: Run ID
Returns:
Dict with job status details
"""
try:
logger.info(f"Checking status for job {job_id}, run {run_id}")
run = self.client.jobs.get_run(run_id=run_id)
            state = run.state
            status_info = {
                "job_id": job_id,
                "run_id": run_id,
                "state": state.life_cycle_state.value if state and state.life_cycle_state else "UNKNOWN",
                "result_state": state.result_state.value if state and state.result_state else "UNKNOWN",
                "state_message": (state.state_message or "") if state else "",
                # start_time / end_time are already epoch milliseconds on the Run object
                "start_time": run.start_time or 0,
                "end_time": run.end_time or 0,
            }
logger.info(f"Job status: {status_info['state']} - {status_info['result_state']}")
return status_info
except DatabricksError as e:
logger.error(f"Databricks error in check_job_status: {str(e)}")
raise Exception(f"Databricks error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error in check_job_status: {str(e)}")
raise Exception(f"Status check error: {str(e)}")
def upload_notebook(self, code: str, workspace_path: str) -> str:
"""
Upload code as a notebook to Databricks workspace using REST API.
Args:
code: Python code to upload
workspace_path: Path in Databricks workspace
Returns:
The workspace path where the notebook was uploaded
"""
try:
# Prepare headers for REST API
headers = {
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json"
}
# Convert code to base64
notebook_b64 = base64.b64encode(code.encode("utf-8")).decode("utf-8")
# Upload payload using REST API approach
upload_payload = {
"path": workspace_path,
"language": "PYTHON",
"overwrite": True,
"content": notebook_b64,
"format": "SOURCE"
}
upload_url = f"{self.host.rstrip('/')}/api/2.0/workspace/import"
logger.info(f"Uploading notebook to: {workspace_path}")
logger.info(f"Using URL: {upload_url}")
response = requests.post(upload_url, headers=headers, json=upload_payload)
if response.status_code == 200:
logger.info(f"✓ Notebook uploaded successfully: {workspace_path}")
return workspace_path
else:
raise Exception(f"Failed to upload notebook: {response.text}")
except Exception as e:
logger.error(f"Failed to upload notebook: {str(e)}")
raise Exception(f"Failed to upload notebook: {str(e)}")
# Global instance for backward compatibility
_manager: Optional[DatabricksManager] = None
def get_manager() -> DatabricksManager:
"""Get or create a global DatabricksManager instance"""
global _manager
if _manager is None:
_manager = DatabricksManager()
return _manager
# Backward compatibility functions
def submit_code(code: str, cluster_id: str):
return get_manager().submit_code(code, cluster_id)
def create_job(job_config: dict):
return get_manager().create_job(job_config)
def run_job(job_id: str):
return get_manager().run_job(job_id)
def create_dlt_pipeline(pipeline_config: dict):
return get_manager().create_dlt_pipeline(pipeline_config)
def get_job_error(run_id: str):
return get_manager().get_job_error(run_id)
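# Minimal end-to-end sketch tying the helpers together. The workspace path and
# cluster ID are placeholders, and the block assumes a reachable workspace; it is
# illustrative usage, not part of the module's API.
if __name__ == "__main__":
    manager = get_manager()
    notebook_path = "/Users/me@example.com/hello_notebook"  # placeholder
    cluster_id = "0123-456789-abcde123"  # placeholder
    # Upload a trivial notebook, wrap it in a single-task job, run it, and check status once.
    manager.upload_notebook("print('hello from Databricks')", notebook_path)
    job = manager.create_job({
        "name": "hello-job",
        "tasks": [{
            "task_key": "main",
            "existing_cluster_id": cluster_id,
            "notebook_task": {"notebook_path": notebook_path},
        }],
    })
    run = manager.run_job(job["job_id"])
    run_id = getattr(run, "run_id", None)
    if run_id is not None:
        print(manager.check_job_status(job["job_id"], run_id))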