# main.py.backup (2.87 kB)
from typing import Optional
from mcp.server.fastmcp import FastMCP
from dbapi import execute_statement
from databricks_formatter import format_query_results
from dotenv import load_dotenv
import os
# Read the .env file first so the lookups below can see its values.
load_dotenv()

# Databricks connection settings; each one is an empty string when unset.
DATABRICKS_HOST = os.getenv("DATABRICKS_HOST", "")
DATABRICKS_TOKEN = os.getenv("DATABRICKS_TOKEN", "")
DATABRICKS_SQL_WAREHOUSE_ID = os.getenv("DATABRICKS_SQL_WAREHOUSE_ID", "")

# Server object that the @mcp.tool() functions in this module attach to.
mcp = FastMCP("databricks")
@mcp.tool()
async def execute_sql_query(sql: str, warehouse_id: Optional[str] = None) -> str:
    """Execute a SQL query on Databricks and return the results.

    Args:
        sql: The SQL query to execute
        warehouse_id: Optional warehouse ID to use (defaults to DATABRICKS_SQL_WAREHOUSE_ID environment variable)

    Returns:
        The query results formatted as text, or a message starting with
        "Error executing SQL query: " if execution or formatting fails.
    """
    try:
        result = await execute_statement(sql, warehouse_id)
        return format_query_results(result)
    except Exception as exc:
        # Failures are reported as text rather than raised, so the tool
        # always hands a string back. The message uses the same
        # "Error <action>: <details>" shape as the other tools in this module.
        return f"Error executing SQL query: {exc!s}"
@mcp.tool()
async def list_schemas(catalog: str, warehouse_id: Optional[str] = None) -> str:
    """List all available schemas in a Databricks catalog.
    Args:
        catalog: The catalog name to list schemas from
        warehouse_id: Optional warehouse ID to use (defaults to DATABRICKS_SQL_WAREHOUSE_ID environment variable)
    """
    # NOTE(review): the docstring above is presumably surfaced by FastMCP as
    # the tool description -- confirm before rewording it.
    # The catalog name is inserted into the statement exactly as given; no
    # quoting or escaping is applied here.
    statement = "SHOW SCHEMAS IN {}".format(catalog)
    try:
        rows = await execute_statement(statement, warehouse_id)
        formatted = format_query_results(rows)
    except Exception as exc:
        # Any failure (execution or formatting) is returned as text.
        return f"Error listing schemas: {exc!s}"
    return formatted
@mcp.tool()
async def list_tables(schema: str, warehouse_id: Optional[str] = None) -> str:
    """List all tables in a specific schema.
    
    Args:
        schema: The schema name to list tables from
        warehouse_id: Optional warehouse ID to use (defaults to DATABRICKS_SQL_WAREHOUSE_ID environment variable)
    """
    # NOTE(review): the docstring above is presumably surfaced by FastMCP as
    # the tool description -- confirm before rewording it.
    # The schema name is inserted into the statement exactly as given; no
    # quoting or escaping is applied here.
    statement = "SHOW TABLES IN {}".format(schema)
    try:
        rows = await execute_statement(statement, warehouse_id)
        formatted = format_query_results(rows)
    except Exception as exc:
        # Any failure (execution or formatting) is returned as text.
        return f"Error listing tables: {exc!s}"
    return formatted
@mcp.tool()
async def describe_table(table_name: str, warehouse_id: Optional[str] = None) -> str:
    """Describe a table's schema.
    
    Args:
        table_name: The fully qualified table name (e.g., schema.table_name)
        warehouse_id: Optional warehouse ID to use (defaults to DATABRICKS_SQL_WAREHOUSE_ID environment variable)
    """
    # NOTE(review): the docstring above is presumably surfaced by FastMCP as
    # the tool description -- confirm before rewording it.
    # The table name is inserted into the statement exactly as given; no
    # quoting or escaping is applied here.
    statement = "DESCRIBE TABLE {}".format(table_name)
    try:
        rows = await execute_statement(statement, warehouse_id)
        formatted = format_query_results(rows)
    except Exception as exc:
        # Any failure (execution or formatting) is returned as text.
        return f"Error describing table: {exc!s}"
    return formatted
if __name__ == "__main__":
    # Script entry point: start the server using the "stdio" transport.
    mcp.run(transport="stdio")