# sql_service.py • 8.94 kB
"""
SQL Service for Databricks operations.
This service handles:
- Database connection management
- Query execution 
- Result formatting for language models
- Error handling
"""
import os
from typing import List, Dict, Any, Optional
from databricks.sdk.core import Config  
from databricks import sql
from dotenv import load_dotenv
# Load environment variables at import time so the os.getenv() lookups in this
# module (e.g. DATABRICKS_WAREHOUSE_ID) can see values supplied through dotenv.
load_dotenv()
class SQLService:
    """Service for executing SQL queries and formatting results for language models.

    A new connection is opened for every query against the SQL warehouse
    whose id is read from the DATABRICKS_WAREHOUSE_ID environment variable.
    """

    def __init__(self, config: Optional[Config] = None):
        """Initialize the SQL service with Databricks configuration.

        Args:
            config: Databricks SDK configuration. When omitted (or falsy), a
                default ``Config()`` is constructed.

        Raises:
            ValueError: If DATABRICKS_WAREHOUSE_ID is unset or empty.
        """
        self.config = config or Config()
        self.warehouse_id = os.getenv('DATABRICKS_WAREHOUSE_ID')

        if not self.warehouse_id:
            raise ValueError("DATABRICKS_WAREHOUSE_ID environment variable is required")

    def execute_query(self, query: str) -> List[tuple]:
        """
        Execute a SQL query and return raw results.

        Args:
            query: SQL query string to execute

        Returns:
            List of rows as returned by ``cursor.fetchall()``

        Raises:
            Exception: If connecting or executing fails. The original error
                is attached as ``__cause__`` so the full traceback survives.
        """
        try:
            with sql.connect(
                server_hostname=self.config.host,
                http_path=f"/sql/1.0/warehouses/{self.warehouse_id}",
                # The lambda hands back the bound ``authenticate`` method
                # itself (it is not called here).
                # NOTE(review): presumably the connector invokes it to build
                # auth headers - confirm against the connector docs.
                credentials_provider=lambda: self.config.authenticate,
            ) as connection:
                with connection.cursor() as cursor:
                    cursor.execute(query)
                    return cursor.fetchall()
        except Exception as e:
            # Same exception type and message as before; ``from e`` records
            # the underlying error as the explicit cause.
            raise Exception(f"Error querying Databricks Warehouse: {e}") from e

    def execute_query_with_formatting(self, query: str, formatter_func) -> str:
        """
        Execute a SQL query and format results using a custom formatter function.

        This method does not raise for query or formatter failures: any
        ``Exception`` is caught and its message is returned as the result.

        Args:
            query: SQL query string to execute
            formatter_func: Function that takes query results and returns formatted string

        Returns:
            Formatted string output suitable for language models, a fixed
            "no data" message when the query returns no rows, or the error
            text when something fails.
        """
        try:
            results = self.execute_query(query)

            if not results:
                return "No data found for the specified query"

            return formatter_func(results)
        except Exception as e:
            # Deliberate best-effort: the caller receives the error as text.
            return str(e)
class QueryFormatter:
    """Utility class for formatting SQL query results into text for language models.

    Every formatter is a static method that takes the rows of a query result
    (sequences indexed by column position) and returns a plain-text summary.
    Where a column has a fallback, any falsy value (None, 0, empty string)
    is replaced by that fallback.
    """

    @staticmethod
    def format_table_formats(results: List[tuple]) -> str:
        """Format table format query results.

        Args:
            results: Rows of (format_name, table_count).

        Returns:
            A header line followed by one line per row.
        """
        lines = ["Table formats:"]
        for row in results:
            format_name = row[0] or "Unknown"
            lines.append(f"{format_name}: {row[1]} tables")
        return "\n".join(lines).strip()

    @staticmethod
    def format_table_types_distribution(results: List[tuple]) -> str:
        """Format table types distribution query results.

        Args:
            results: Rows of (table_type, percentage).

        Returns:
            A header line followed by one line per row.
        """
        lines = ["Table types distribution:"]
        for row in results:
            table_type = row[0] or "Unknown"
            lines.append(f"{table_type}: {row[1]}%")
        return "\n".join(lines).strip()

    @staticmethod
    def format_jobs_on_all_purpose_clusters(results: List[tuple]) -> str:
        """Format jobs running on all-purpose clusters query results.

        Args:
            results: Rows of (workspace_id, job_id, cluster_id, cluster_name,
                owned_by). Column 0 is not included in the output text.

        Returns:
            A header line followed by one line per job.
        """
        lines = ["Jobs running on all-purpose clusters (last 30 days):"]
        for row in results:
            job_id = row[1] or "Unknown"
            cluster_id = row[2] or "Unknown"
            cluster_name = row[3] or "Unknown"
            owned_by = row[4] or "Unknown"
            lines.append(
                f"Job {job_id} on cluster '{cluster_name}' (ID: {cluster_id}) - Owner: {owned_by}"
            )
        return "\n".join(lines).strip()

    @staticmethod
    def format_sql_vs_all_purpose(results: List[tuple]) -> str:
        """Format SQL vs All Purpose compute usage query results.

        Args:
            results: Rows of (billing_product, dbu_usage).

        Returns:
            A header line followed by one line per product, with usage shown
            to two decimal places.
        """
        lines = ["SQL vs All Purpose compute usage (last 30 days):"]
        for row in results:
            billing_product = row[0] or "Unknown"
            dbu_usage = row[1] or 0
            lines.append(f"{billing_product}: {dbu_usage:.2f} DBU")
        return "\n".join(lines).strip()

    @staticmethod
    def format_dbr_versions(results: List[tuple]) -> str:
        """Format Databricks Runtime versions query results.

        Args:
            results: Rows of (version, cluster_count).

        Returns:
            A header line followed by one line per version.
        """
        lines = ["Databricks Runtime versions in use:"]
        for row in results:
            version = row[0] or "Unknown"
            count = row[1] or 0
            lines.append(f"DBR {version}: {count} clusters")
        return "\n".join(lines).strip()

    @staticmethod
    def format_serverless_percentage(results: List[tuple]) -> str:
        """Format serverless compute percentage query results.

        Args:
            results: A single row whose first column is the percentage.
                Empty results are reported as 0.

        Returns:
            A one-line summary with the percentage to two decimal places.
        """
        first = results[0] if results else (None,)
        serverless_percent = first[0] or 0
        return f"Serverless compute usage: {serverless_percent:.2f}% of total compute (last 28 days)"

    @staticmethod
    def format_sql_compute_costs(results: List[tuple]) -> str:
        """Format SQL compute costs by type query results.

        Args:
            results: Rows of (sql_type, cost).

        Returns:
            A header line followed by one line per type, with cost shown to
            two decimal places.
        """
        lines = ["SQL compute costs by type (last 30 days):"]
        for row in results:
            sql_type = row[0] or "Unknown"
            cost = row[1] or 0
            lines.append(f"{sql_type}: ${cost:.2f}")
        return "\n".join(lines).strip()

    @staticmethod
    def format_cluster_utilization(results: List[tuple]) -> str:
        """Format cluster utilization query results.

        Args:
            results: A single row of (cpu_p75, memory_p75). Empty results
                are reported as 0 for both values.

        Returns:
            A three-line summary with both values to two decimal places.
        """
        first = results[0] if results else (None, None)
        cpu_p75 = first[0] or 0
        memory_p75 = first[1] or 0
        return (
            "Cluster utilization (75th percentile, last 28 days):\n"
            f"CPU: {cpu_p75:.2f}%\n"
            f"Memory: {memory_p75:.2f}%"
        )

    @staticmethod
    def format_autoscaling_percentage(results: List[tuple]) -> str:
        """Format autoscaling percentage query results.

        Args:
            results: A single row whose first column is the percentage.
                Empty results are reported as 0.

        Returns:
            A one-line summary with the percentage to two decimal places.
        """
        first = results[0] if results else (None,)
        autoscaling_percent = first[0] or 0
        return f"Clusters with autoscaling enabled: {autoscaling_percent:.2f}%"

    @staticmethod
    def format_auto_termination_analysis(results: List[tuple]) -> str:
        """Format auto-termination analysis query results.

        Args:
            results: A single row of (p75_minutes, max_minutes,
                clusters_without, clusters_with, percent_without). Empty
                results are treated like a row of missing values instead of
                raising IndexError, matching the other single-row formatters.

        Returns:
            A five-line summary. Falsy minute values (including 0) are
            shown as "N/A"; falsy counts and percentage are shown as 0.
        """
        row = results[0] if results else (None,) * 5
        p75_minutes = row[0] or "N/A"
        max_minutes = row[1] or "N/A"
        without_auto = row[2] or 0
        with_auto = row[3] or 0
        percent_without = row[4] or 0

        lines = [
            "Auto-termination analysis:",
            f"75th percentile auto-termination: {p75_minutes} minutes",
            f"Max auto-termination: {max_minutes} minutes",
            f"Clusters without auto-termination: {without_auto} ({percent_without:.1f}%)",
            f"Clusters with auto-termination: {with_auto}",
        ]
        return "\n".join(lines)

    @staticmethod
    def format_billing_table_access(results: List[tuple]) -> str:
        """Format billing table access query results.

        Args:
            results: Rows of (event_date, usage_count).

        Returns:
            A header line, one line for each of the first 10 rows, and a
            trailing "... and N more days" line when more rows exist.
        """
        lines = ["Daily billing table access (last 90 days):"]
        for row in results[:10]:  # Limit to first 10 days for readability
            event_date = row[0] or "Unknown"
            usage_count = row[1] or 0
            lines.append(f"{event_date}: {usage_count} reads")

        if len(results) > 10:
            lines.append(f"... and {len(results) - 10} more days")

        return "\n".join(lines).strip()

    @staticmethod
    def format_cluster_tagging_distribution(results: List[tuple]) -> str:
        """Format cluster tagging distribution query results.

        Args:
            results: Rows of (tag_count, cluster_count).

        Returns:
            A header line followed by one line per row.
        """
        lines = ["Cluster tagging distribution:"]
        for row in results:
            tag_count = row[0] or 0
            cluster_count = row[1] or 0
            lines.append(f"{tag_count} tags: {cluster_count} clusters")
        return "\n".join(lines).strip()

    @staticmethod
    def format_popular_tags(results: List[tuple]) -> str:
        """Format popular cluster tags query results.

        Args:
            results: Rows of (tag_name, percentage).

        Returns:
            A header line followed by one line for each of the first 10
            rows, with the percentage shown to one decimal place.
        """
        lines = ["Most popular cluster tags:"]
        for row in results[:10]:  # Limit to top 10 tags
            tag_name = row[0] or "Unknown"
            percentage = row[1] or 0
            lines.append(f"{tag_name}: {percentage:.1f}% of clusters")
        return "\n".join(lines).strip()
# Module-wide SQLService instance, created on first request
_sql_service = None
def get_sql_service() -> SQLService:
    """Return the shared SQL service, constructing it on the first call."""
    global _sql_service
    if _sql_service is not None:
        return _sql_service
    # First call: build the instance and remember it for later callers.
    _sql_service = SQLService()
    return _sql_service