# db_client.py
import asyncio
import hashlib
import json
import logging
import os
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import (
    Encoding,
    NoEncryption,
    PrivateFormat,
    load_pem_private_key,
)
from snowflake.snowpark import Session
from snowflake.snowpark.exceptions import SnowparkSQLException
logger = logging.getLogger("aai_mcp_snowflake_server.db")
class SnowflakeDB:
def __init__(self, cfg: dict):
self.cfg = cfg
self.session: Session | None = None
self._query_cache = {} # Simple query cache
self._cache_ttl = 300 # 5 minutes default TTL
def _build_connection_options(self):
base = {
"user": self.cfg["user"],
"account": self.cfg["account"],
"role": self.cfg.get("role"),
"warehouse": self.cfg.get("warehouse"),
"database": self.cfg.get("database"),
"schema": self.cfg.get("schema"),
"login_timeout": self.cfg.get("login_timeout"),
"network_timeout": self.cfg.get("network_timeout"),
"host": self.cfg.get("host"),
"port": self.cfg.get("port"),
}
if self.cfg.get("private_key_path"):
key_path = Path(self.cfg["private_key_path"])
if not key_path.exists():
raise FileNotFoundError(f"Private key not found: {key_path}")
with key_path.open("rb") as f:
p_key = load_pem_private_key(
f.read(),
password=(self.cfg.get("private_key_passphrase") or "").encode() or None,
backend=default_backend(),
)
private_key_bytes = p_key.private_bytes(
encoding=Encoding.DER,
format=PrivateFormat.PKCS8,
encryption_algorithm=NoEncryption(),
)
base["private_key"] = private_key_bytes
# For key auth (Snowpark auto-detects if private_key present)
else:
if not self.cfg.get("password"):
raise ValueError("Either password or private key auth must be provided.")
base["password"] = self.cfg["password"]
# Trim None
return {k: v for k, v in base.items() if v is not None}
def connect(self):
if self.session:
return self.session
opts = self._build_connection_options()
logger.info("Connecting to Snowflake account=%s role=%s wh=%s",
opts.get("account"), opts.get("role"), opts.get("warehouse"))
self.session = Session.builder.configs(opts).create()
return self.session
async def ensure(self):
return await asyncio.to_thread(self.connect)
def _get_cache_key(self, sql: str, params: Dict = None) -> str:
"""Generate cache key for SQL query"""
import hashlib
key_str = f"{sql}:{json.dumps(params or {}, sort_keys=True)}"
return hashlib.md5(key_str.encode()).hexdigest()
def _is_cache_valid(self, timestamp: datetime) -> bool:
"""Check if cached result is still valid"""
return (datetime.now() - timestamp).total_seconds() < self._cache_ttl
def _should_cache_query(self, sql: str) -> bool:
"""Determine if query results should be cached"""
sql_upper = sql.upper().strip()
# Cache SELECT queries for metadata, schemas, tables, etc.
cacheable_patterns = [
"SHOW DATABASES",
"SHOW SCHEMAS",
"SHOW TABLES",
"SHOW VIEWS",
"DESCRIBE TABLE",
"SELECT * FROM INFORMATION_SCHEMA",
"SELECT * FROM AZDMAND.SDW_DPE_DASH_DB.AAI_MD",
"SELECT * FROM AZDMAND.SDW_DPE_DASH_DB.AAI_ACCESS"
]
return any(pattern in sql_upper for pattern in cacheable_patterns)
async def run_sql(self, sql: str, to_pandas: bool = False, use_cache: bool = True) -> List[Any]:
"""
Execute SQL query with optional caching and enhanced error handling.
Args:
sql: SQL query string
to_pandas: Return as pandas DataFrame if True
use_cache: Use query caching if True
Returns:
Query results as list of Row objects or pandas DataFrame
"""
# Check cache first
if use_cache and self._should_cache_query(sql):
cache_key = self._get_cache_key(sql)
if cache_key in self._query_cache:
cached_result, timestamp = self._query_cache[cache_key]
if self._is_cache_valid(timestamp):
logger.debug(f"Cache hit for query: {sql[:50]}...")
return cached_result
sess = await self.ensure()
def _exec():
try:
df = sess.sql(sql)
result = df.to_pandas() if to_pandas else df.collect()
# Cache result if applicable
if use_cache and self._should_cache_query(sql):
cache_key = self._get_cache_key(sql)
self._query_cache[cache_key] = (result, datetime.now())
logger.debug(f"Cached result for query: {sql[:50]}...")
return result
except SnowparkSQLException as e:
logger.error(f"Snowpark SQL error: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error executing SQL: {e}")
raise
return await asyncio.to_thread(_exec)
async def list_databases(self):
rows = await self.run_sql("SHOW DATABASES")
# Each row is a Snowpark Row object; name is the database name column
return [row['name'] for row in rows]
async def list_schemas(self, database: str | None = None):
if database:
return [r[1] for r in await self.run_sql(f"SHOW TERSE SCHEMAS IN DATABASE {database}")]
return [r[1] for r in await self.run_sql("SHOW TERSE SCHEMAS")]
async def list_tables(self, database: str | None = None, schema: str | None = None):
scope = ""
if database and schema:
scope = f" IN {database}.{schema}"
elif database:
scope = f" IN DATABASE {database}"
return [r[1] for r in await self.run_sql(f"SHOW TERSE TABLES{scope}")]
async def query(self, sql: str):
return await self.run_sql(sql, to_pandas=False)
async def close(self):
if self.session:
await asyncio.to_thread(self.session.close)
self.session = None