"""Database connection management."""
import os
import hashlib
from typing import Optional
from urllib.parse import quote_plus
from .credentials import Credentials
from .dependencies import ensure_deps_once
try:
from sqlalchemy import create_engine, text, Engine
from sqlalchemy.engine.url import make_url
except ImportError:
ensure_deps_once()
from sqlalchemy import create_engine, text, Engine
from sqlalchemy.engine.url import make_url
class ConnectionManager:
    """Create, verify and cache SQLAlchemy engines.

    Two independent ways of obtaining an engine are offered:

    * ``get_engine_with_credentials`` builds an engine from a ``Credentials``
      object supplied by the caller and verifies it with ``SELECT 1``.
    * ``get_engine`` derives an engine from the ``MCP_SQL_*`` environment
      variables read once in ``__init__``.

    Both paths store their engines in the same ``_engines`` dict; their keys
    have different tuple lengths, so they never collide.
    """

    # Driver used whenever no (or an empty) SQL Server driver name is supplied.
    DEFAULT_MSSQL_DRIVER = "ODBC Driver 18 for SQL Server"

    # Query-string variants tried in order by ``_create_mssql_engine``.
    CONNECTION_ATTEMPTS = [
        {"name": "TrustServerCertificate=yes", "params": "TrustServerCertificate=yes&timeout=30&login_timeout=30"},
        {"name": "Encrypt=no", "params": "Encrypt=no&timeout=30&login_timeout=30"},
        {"name": "Encrypt=yes + TrustServerCertificate=yes", "params": "Encrypt=yes&TrustServerCertificate=yes&timeout=30&login_timeout=30"},
        {"name": "Basic (no encryption params)", "params": "timeout=30&login_timeout=30"},
    ]

    @staticmethod
    def _detect_db_type(driver: Optional[str]) -> str:
        """Map a driver name to ``"mysql"``, ``"postgresql"`` or ``"mssql"``.

        Matching is a case-insensitive substring test. A missing, empty or
        unrecognised driver name yields ``"mssql"`` (the backward-compatible
        default).
        """
        if not driver:
            return "mssql"  # Default to SQL Server for backward compatibility
        driver_lower = driver.lower()
        if "mysql" in driver_lower:
            return "mysql"
        # "postgres" is a substring of "postgresql", so one test covers both.
        if "postgres" in driver_lower:
            return "postgresql"
        # "sql server" / "mssql" / "odbc" and anything unknown all map to mssql.
        return "mssql"

    @staticmethod
    def _is_auth_failure(error_str: str) -> bool:
        """Return True if *error_str* contains SQL Server login error 18456."""
        return "18456" in error_str

    def __init__(self):
        """Read the ``MCP_SQL_*`` environment variables and prepare the cache."""
        self.env_dsn: Optional[str] = os.getenv("MCP_SQL_DSN")
        self.env_name: Optional[str] = os.getenv("MCP_SQL_NAME")
        self.env_user: Optional[str] = os.getenv("MCP_SQL_USER")
        self.env_password: Optional[str] = os.getenv("MCP_SQL_PASSWORD")
        self.env_server: Optional[str] = os.getenv("MCP_SQL_SERVER")
        self.env_database: Optional[str] = os.getenv("MCP_SQL_DATABASE")
        self.env_driver: Optional[str] = os.getenv("MCP_SQL_DRIVER", self.DEFAULT_MSSQL_DRIVER)
        # Parsed once; ``None`` means no usable environment configuration.
        self._dsn_url = self._build_dsn_url()
        self._engines: dict[tuple, Engine] = {}

    def _build_dsn_url(self):
        """Build the base URL from environment variables.

        ``MCP_SQL_DSN`` wins when set. Otherwise a ``mssql+pyodbc`` URL on
        port 1433 is assembled from user, password and server (all three are
        required); the database falls back to ``"master"``.

        Returns:
            The parsed URL, or ``None`` when nothing is configured or parsing
            fails (the error is printed, not raised).
        """
        if self.env_dsn:
            try:
                return make_url(self.env_dsn)
            except Exception as e:
                print(f"Error parsing MCP_SQL_DSN: {e}")
                return None

        if not (self.env_user and self.env_password and self.env_server):
            return None

        try:
            user_encoded = quote_plus(self.env_user)
            password_encoded = quote_plus(self.env_password)
            db_name = self.env_database or "master"
            # An empty MCP_SQL_DRIVER would otherwise produce "driver=";
            # fall back to the default exactly as _create_mssql_engine does.
            driver_encoded = (self.env_driver or self.DEFAULT_MSSQL_DRIVER).replace(' ', '+')
            dsn_string = (
                f"mssql+pyodbc://{user_encoded}:{password_encoded}@{self.env_server}:1433/{db_name}"
                f"?driver={driver_encoded}&TrustServerCertificate=yes"
            )
            print(f"Built DSN from components: Server={self.env_server}, Database={db_name}")
            return make_url(dsn_string)
        except Exception as e:
            print(f"Error building DSN from components: {e}")
            return None

    def get_engine_with_credentials(self, creds: Credentials) -> Optional[Engine]:
        """Create or retrieve an engine for caller-supplied credentials.

        The request is logged first. If ``creds.is_valid()`` is false, or no
        working connection can be established, ``None`` is returned and
        nothing is cached. Successful engines are cached and reused for
        identical credentials.
        """
        self._log_connection_attempt(creds)
        if not creds.is_valid():
            print("❌ Error: Missing required connection parameters")
            return None

        # Only a truncated digest of the password goes into the key, so the
        # plaintext password is not stored in the cache dict.
        password_hash = hashlib.sha256(creds.password.encode()).hexdigest()[:16]
        engine_key = (creds.server, creds.database, creds.user, password_hash, creds.driver, creds.port)

        cached = self._engines.get(engine_key)
        if cached is not None:
            return cached

        engine = self._create_engine(creds)
        if engine:
            self._engines[engine_key] = engine
        return engine

    def _log_connection_attempt(self, creds: Credentials):
        """Print the connection parameters, with the password masked."""
        db_type = self._detect_db_type(creds.driver)
        print("🔍 Connection request received:")
        print(f" Database Type: {db_type} (detected from driver)")
        print(f" Server: {creds.server}")
        print(f" Port: {creds.port}")
        print(f" Database: {creds.database}")
        print(f" User: {creds.user}")
        print(f" Password: {creds.mask_password()}")
        print(f" Driver: {creds.driver or '(not specified, will use default)'}")

    @staticmethod
    def _connect_and_verify(dsn_string: str) -> Engine:
        """Create an engine for *dsn_string* and prove it works with ``SELECT 1``.

        Any exception from URL parsing, engine creation or the probe query is
        propagated to the caller. If the probe fails, the engine is disposed
        first so a failed attempt does not leave a connection pool behind.
        """
        engine = create_engine(make_url(dsn_string))
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except Exception:
            engine.dispose()
            raise
        return engine

    def _create_engine(self, creds: Credentials) -> Optional[Engine]:
        """Create a verified engine for the backend implied by ``creds.driver``."""
        db_type = self._detect_db_type(creds.driver)
        user_encoded = quote_plus(creds.user)
        password_encoded = quote_plus(creds.password)
        # Build connection string based on database type
        if db_type == "mysql":
            return self._create_mysql_engine(creds, user_encoded, password_encoded)
        if db_type == "postgresql":
            return self._create_postgresql_engine(creds, user_encoded, password_encoded)
        # mssql (SQL Server)
        return self._create_mssql_engine(creds, user_encoded, password_encoded)

    def _create_mysql_engine(self, creds: Credentials, user_encoded: str, password_encoded: str) -> Optional[Engine]:
        """Create a verified MySQL engine, or return ``None`` on any failure.

        *user_encoded* and *password_encoded* must already be URL-quoted.
        """
        # MySQL connection string format: mysql+mysqlconnector://user:password@host:port/database
        dsn_string = (
            f"mysql+mysqlconnector://{user_encoded}:{password_encoded}@{creds.server}:{creds.port}/{creds.database}"
        )
        try:
            engine = self._connect_and_verify(dsn_string)
        except Exception as e:
            error_str = str(e)
            print(f"❌ MySQL connection failed: {error_str[:150]}")
            return None
        print("✅ MySQL connection successful")
        print(f" Server: {creds.server}:{creds.port}/{creds.database} (user: {creds.user})")
        return engine

    def _create_postgresql_engine(self, creds: Credentials, user_encoded: str, password_encoded: str) -> Optional[Engine]:
        """Create a verified PostgreSQL engine, or return ``None`` on any failure.

        *user_encoded* and *password_encoded* must already be URL-quoted.
        """
        # PostgreSQL connection string format: postgresql+psycopg2://user:password@host:port/database
        dsn_string = (
            f"postgresql+psycopg2://{user_encoded}:{password_encoded}@{creds.server}:{creds.port}/{creds.database}"
        )
        try:
            engine = self._connect_and_verify(dsn_string)
        except Exception as e:
            error_str = str(e)
            print(f"❌ PostgreSQL connection failed: {error_str[:150]}")
            return None
        print("✅ PostgreSQL connection successful")
        print(f" Server: {creds.server}:{creds.port}/{creds.database} (user: {creds.user})")
        return engine

    def _create_mssql_engine(self, creds: Credentials, user_encoded: str, password_encoded: str) -> Optional[Engine]:
        """Create a verified SQL Server engine, trying several option sets.

        Each entry of ``CONNECTION_ATTEMPTS`` is tried in order until one
        connects. The loop stops early when the error text contains login
        error 18456, because the remaining attempts only vary encryption and
        timeout options and would resend the same rejected credentials.

        *user_encoded* and *password_encoded* must already be URL-quoted.

        Returns:
            The first engine that passes the probe, or ``None``.
        """
        driver_encoded = (creds.driver or self.DEFAULT_MSSQL_DRIVER).replace(' ', '+')
        base_dsn = (
            f"mssql+pyodbc://{user_encoded}:{password_encoded}@{creds.server}:{creds.port}/{creds.database}"
            f"?driver={driver_encoded}"
        )
        for attempt in self.CONNECTION_ATTEMPTS:
            try:
                engine = self._connect_and_verify(f"{base_dsn}&{attempt['params']}")
            except Exception as e:
                error_str = str(e)
                print(f"❌ Attempt failed ({attempt['name']}): {error_str[:150]}")
                if self._is_auth_failure(error_str):
                    print(" 💡 Authentication error - check credentials")
                    # Repeating a rejected login cannot succeed and may
                    # trigger account lockout on the server.
                    break
                if "HYT00" in error_str or "timeout" in error_str.lower():
                    print(" 💡 Timeout - trying next configuration...")
            else:
                print(f"✅ SQL Server connection successful with: {attempt['name']}")
                print(f" Server: {creds.server}:{creds.port}/{creds.database} (user: {creds.user})")
                return engine
        print(f"❌ All connection attempts failed for {creds.server}/{creds.database}")
        return None

    def get_engine(self, server_name: str, database: str) -> Optional[Engine]:
        """Get an engine for *database* using the environment configuration.

        Args:
            server_name: Accepted for interface compatibility; the current
                implementation does not read it.
            database: Database name substituted into the environment URL.

        Returns:
            A cached or newly created engine, or ``None`` when no environment
            configuration exists or engine creation raises. Unlike the
            credential-based path, no test query is run here.
        """
        if self._dsn_url is None:
            print("Error: No environment configuration found.")
            return None

        env_name = self.env_name or (self._dsn_url.host or "env_dsn")
        engine_key = (env_name, database)
        if engine_key not in self._engines:
            try:
                target_url = self._dsn_url.set(database=database)
                self._engines[engine_key] = create_engine(target_url)
                print(f"Created engine (ENV) for: {env_name}, database: {database}")
            except Exception as e:
                print(f"Error creating engine: {e}")
                return None
        return self._engines[engine_key]

    def get_configured_server_names(self) -> list[str]:
        """Return the single environment-configured server name, if any.

        The name is ``MCP_SQL_NAME`` when set, otherwise the URL host,
        otherwise ``"env_server"``. An empty list means no environment
        configuration is available.
        """
        if self._dsn_url is None:
            return []
        return [self.env_name or (self._dsn_url.host or "env_server")]