Snowflake MCP Service

by davidamom
Verified
#!/usr/bin/env python import os import asyncio import logging import json import time import snowflake.connector from dotenv import load_dotenv import mcp.server.stdio from mcp.server import Server from mcp.types import Tool, ServerResult, TextContent from contextlib import closing from typing import Optional, Any, List, Dict # Configure logging logging.basicConfig( level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger('snowflake_mcp') # Load environment variables from .env file load_dotenv() class SnowflakeConnection: """ Snowflake database connection management class """ def __init__(self): # Initialize configuration from environment variables self.config = { "user": os.getenv("SNOWFLAKE_USER"), "account": os.getenv("SNOWFLAKE_ACCOUNT"), "database": os.getenv("SNOWFLAKE_DATABASE"), "warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"), } # Determine authentication method private_key_file = os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE") # Priority 1: Key pair authentication if file is provided and exists if private_key_file and os.path.exists(private_key_file): # Check if using passphrase or not passphrase = os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE") if passphrase: logger.info("Using key pair authentication with passphrase") else: logger.info("Using key pair authentication without passphrase") # Try to setup key pair authentication auth_success = self._setup_key_pair_auth(private_key_file, passphrase) # If key pair auth failed, fall back to password if not auth_success: logger.info("Falling back to password authentication") password = os.getenv("SNOWFLAKE_PASSWORD") if password: self.config["password"] = password else: logger.error("No password provided as fallback. Authentication will likely fail.") else: # Priority 2: Password authentication password = os.getenv("SNOWFLAKE_PASSWORD") if password: self.config["password"] = password logger.info("Using password authentication") else: logger.error("No authentication method configured. Please provide either a private key or password.") self.conn: Optional[snowflake.connector.SnowflakeConnection] = None # Log config (excluding sensitive info) safe_config = {k: v for k, v in self.config.items() if k not in ['password', 'private_key', 'private_key_passphrase']} logger.info(f"Initialized with config: {json.dumps(safe_config)}") def _setup_key_pair_auth(self, private_key_file: str, passphrase: str = None) -> bool: """ Set up key pair authentication Args: private_key_file (str): Path to private key file passphrase (str, optional): Passphrase for the private key Returns: bool: True if key pair authentication was set up successfully, False otherwise """ try: # Read private key file with open(private_key_file, "rb") as key_file: private_key = key_file.read() # Try to load the key using snowflake's recommended approach from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_pem_private_key logger.info(f"Loading private key from {private_key_file}") # Use passphrase only if provided p_key = load_pem_private_key( private_key, password=passphrase.encode() if passphrase else None, backend=default_backend() ) # Convert key to DER format from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption pkb = p_key.private_bytes( encoding=Encoding.DER, format=PrivateFormat.PKCS8, encryption_algorithm=NoEncryption() ) # Add to config (this is what Snowflake expects) self.config["private_key"] = pkb # If we had a passphrase, add it to config if passphrase: self.config["private_key_passphrase"] = passphrase logger.info("Private key loaded successfully") return True except Exception as e: logger.error(f"Error setting up key pair authentication: {str(e)}") logger.error("Details:", exc_info=True) return False def ensure_connection(self) -> snowflake.connector.SnowflakeConnection: """ Ensure database connection is available, create new connection if it doesn't exist or is disconnected """ try: # Check if connection needs to be re-established if self.conn is None: logger.info("Creating new Snowflake connection...") self.conn = snowflake.connector.connect( **self.config, client_session_keep_alive=True, network_timeout=15, login_timeout=15 ) self.conn.cursor().execute("ALTER SESSION SET TIMEZONE = 'UTC'") logger.info("New connection established and configured") # Test if connection is valid try: self.conn.cursor().execute("SELECT 1") except: logger.info("Connection lost, reconnecting...") self.conn = None return self.ensure_connection() return self.conn except Exception as e: logger.error(f"Connection error: {str(e)}") raise def execute_query(self, query: str) -> List[Dict[str, Any]]: """ Execute SQL query and return results Args: query (str): SQL query statement Returns: List[Dict[str, Any]]: List of query results """ start_time = time.time() logger.info(f"Executing query: {query[:200]}...") # Log only first 200 characters try: conn = self.ensure_connection() with conn.cursor() as cursor: # For write operations use transaction if any(query.strip().upper().startswith(word) for word in ['INSERT', 'UPDATE', 'DELETE', 'CREATE', 'DROP', 'ALTER']): cursor.execute("BEGIN") try: cursor.execute(query) conn.commit() logger.info(f"Write query executed in {time.time() - start_time:.2f}s") return [{"affected_rows": cursor.rowcount}] except Exception as e: conn.rollback() raise else: # Read operations cursor.execute(query) if cursor.description: columns = [col[0] for col in cursor.description] rows = cursor.fetchall() results = [dict(zip(columns, row)) for row in rows] logger.info(f"Read query returned {len(results)} rows in {time.time() - start_time:.2f}s") return results return [] except snowflake.connector.errors.ProgrammingError as e: logger.error(f"SQL Error: {str(e)}") logger.error(f"Error Code: {getattr(e, 'errno', 'unknown')}") raise except Exception as e: logger.error(f"Query error: {str(e)}") logger.error(f"Error type: {type(e).__name__}") raise def close(self): """ Close database connection """ if self.conn: try: self.conn.close() logger.info("Connection closed") except Exception as e: logger.error(f"Error closing connection: {str(e)}") finally: self.conn = None class SnowflakeMCPServer(Server): """ Snowflake MCP server class, handles client interactions """ def __init__(self): super().__init__(name="snowflake-mcp-server") self.db = SnowflakeConnection() logger.info("SnowflakeMCPServer initialized") @self.list_tools() async def handle_tools(): """ Return list of available tools """ return [ Tool( name="execute_query", description="Execute a SQL query on Snowflake", inputSchema={ "type": "object", "properties": { "query": { "type": "string", "description": "SQL query to execute" } }, "required": ["query"] } ) ] @self.call_tool() async def handle_call_tool(name: str, arguments: dict): """ Handle tool call requests Args: name (str): Tool name arguments (dict): Tool arguments Returns: list[TextContent]: Execution results """ if name == "execute_query": start_time = time.time() try: result = self.db.execute_query(arguments["query"]) execution_time = time.time() - start_time return [TextContent( type="text", text=f"Results (execution time: {execution_time:.2f}s):\n{result}" )] except Exception as e: error_message = f"Error executing query: {str(e)}" logger.error(error_message) return [TextContent( type="text", text=error_message )] def __del__(self): """ Clean up resources, close database connection """ if hasattr(self, 'db'): self.db.close() async def main(): """ Main function, starts server and handles requests """ try: server = SnowflakeMCPServer() initialization_options = server.create_initialization_options() logger.info("Starting server") async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, initialization_options ) except Exception as e: logger.critical(f"Server failed: {str(e)}", exc_info=True) raise finally: logger.info("Server shutting down") if __name__ == "__main__": asyncio.run(main())