"""Credential management for database connections."""
import os
from dataclasses import dataclass, field
from typing import Optional
try:
from fastmcp import Context
except ImportError:
from .dependencies import ensure_deps_once
ensure_deps_once()
from fastmcp import Context
@dataclass
class Credentials:
    """Data class for database credentials.

    Fields left as ``None`` have not been provided by any source yet.
    ``port`` defaults to 1433 but may be ``None`` when a caller passes
    ``port=None`` explicitly. The password is excluded from the generated
    ``__repr__`` so that printing or logging an instance does not leak it.
    """

    user: Optional[str] = None
    # repr=False: keep the secret out of the dataclass-generated __repr__.
    password: Optional[str] = field(default=None, repr=False)
    server: Optional[str] = None
    database: Optional[str] = None
    driver: Optional[str] = None
    port: Optional[int] = 1433

    # Passwords shorter than this are fully masked by mask_password(); longer
    # ones reveal 3 leading and 3 trailing characters, i.e. at most half.
    # (Unannotated, so the dataclass machinery does not treat it as a field.)
    _MASK_MIN_LENGTH = 12

    def is_valid(self) -> bool:
        """Return True when user, password and server are all non-empty."""
        return bool(self.user and self.password and self.server)

    def mask_password(self) -> str:
        """Return a masked version of the password that is safe to log.

        Returns ``"***"`` when the password is missing or shorter than
        ``_MASK_MIN_LENGTH`` characters; otherwise the first and last three
        characters joined by ``"..."``.
        """
        if not self.password or len(self.password) < self._MASK_MIN_LENGTH:
            return "***"
        return f"{self.password[:3]}...{self.password[-3:]}"
class CredentialsManager:
    """Manages credential retrieval from multiple sources with priority.

    Each source only fills fields that are still empty, so the order in which
    ``get_from_context`` consults the sources defines their priority.
    """

    # String fields merged identically from every source. The per-source key
    # is derived from the field name, e.g. ``user`` -> header
    # ``x-mcp-sql-user`` / variable ``MCP_SQL_USER``.
    _TEXT_FIELDS = ('user', 'password', 'server', 'database', 'driver')

    # Port used when no source provides one (the SQL Server default).
    _DEFAULT_PORT = 1433

    # (substring of the lower-cased driver name, port) pairs that replace the
    # SQL Server default; checked in order. 'postgres' also covers 'postgresql'.
    _DRIVER_PORTS = (('mysql', 3306), ('postgres', 5432))

    # Driver used when no source provides one.
    _DEFAULT_DRIVER = 'ODBC Driver 18 for SQL Server'

    @staticmethod
    def _auto_detect_port(creds: Credentials) -> Credentials:
        """Auto-detect port based on driver if port is default (1433).

        A driver name containing 'mysql' maps to 3306, one containing
        'postgres' to 5432. Any port other than 1433, or a missing driver,
        leaves ``creds`` unchanged. Mutates and returns ``creds``.
        """
        if creds.port == CredentialsManager._DEFAULT_PORT and creds.driver:
            driver_lower = creds.driver.lower()
            for marker, detected_port in CredentialsManager._DRIVER_PORTS:
                if marker in driver_lower:
                    creds.port = detected_port
                    break
        return creds

    @staticmethod
    def _parse_port(raw) -> Optional[int]:
        """Convert a raw port value to ``int``; return None if empty or not numeric."""
        if not raw:
            return None
        try:
            return int(raw)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _fill_text_fields(creds: Credentials, source, key_template: str) -> bool:
        """Fill empty string fields of ``creds`` from a dict-like ``source``.

        ``key_template`` is formatted with ``lower``/``upper`` placeholders set
        to the field name in that case (e.g. ``'MCP_SQL_{upper}'``).
        Returns True if at least one field received a non-empty value.
        """
        supplied = False
        for name in CredentialsManager._TEXT_FIELDS:
            if getattr(creds, name):
                continue  # already set by a higher-priority source
            value = source.get(key_template.format(lower=name, upper=name.upper()))
            setattr(creds, name, value)
            supplied = supplied or bool(value)
        return supplied

    @staticmethod
    def _fill_port(creds: Credentials, raw) -> bool:
        """Set ``creds.port`` from ``raw`` if no port is set yet and ``raw`` parses.

        Invalid values are ignored. Returns True if the port was set.
        """
        if creds.port is not None:
            return False
        parsed = CredentialsManager._parse_port(raw)
        if parsed is None:
            return False
        creds.port = parsed
        return True

    @staticmethod
    def get_from_context(
        ctx: Optional[Context],
        user: Optional[str] = None,
        password: Optional[str] = None,
        server: Optional[str] = None,
        database: Optional[str] = None,
        driver: Optional[str] = None,
        port: Optional[int] = None
    ) -> Credentials:
        """
        Extract credentials with priority:
        1. Function parameters (highest)
        2. HTTP Headers (X-MCP-SQL-*)
        3. Session environment variables
        4. Server environment variables (lowest)

        Missing driver/port fall back to 'ODBC Driver 18 for SQL Server' and
        1433, with the port adjusted for MySQL/PostgreSQL drivers. Malformed
        port values from any source are ignored rather than raised.
        """
        creds = Credentials(
            user=user,
            password=password,
            server=server,
            database=database,
            driver=driver,
            port=port,
        )
        creds = CredentialsManager._get_from_headers(ctx, creds)
        creds = CredentialsManager._get_from_client_params(ctx, creds)
        # The environment step also applies the defaults and the driver-based
        # port detection, so no separate _auto_detect_port call is needed.
        return CredentialsManager._get_from_environment(creds)

    @staticmethod
    def _get_from_headers(ctx: Optional[Context], creds: Credentials) -> Credentials:
        """Fill missing fields from ``x-mcp-sql-*`` HTTP request headers.

        Does nothing when there is no context, no request, or the request has
        no headers. Prints a notice to stdout when a header supplied a value.
        """
        if not ctx:
            return creds
        try:
            req_ctx = ctx.request_context
        except (AttributeError, LookupError, ValueError):
            # NOTE(review): depending on the fastmcp/MCP SDK version, reading
            # request_context outside an active request may raise rather than
            # the attribute being absent; treat both as "no headers".
            return creds
        # request may exist but be None (e.g. a non-HTTP transport).
        request = getattr(req_ctx, 'request', None)
        headers = getattr(request, 'headers', None)
        if headers is None:
            return creds

        from_text = CredentialsManager._fill_text_fields(creds, headers, 'x-mcp-sql-{lower}')
        from_port = CredentialsManager._fill_port(creds, headers.get('x-mcp-sql-port'))
        # Only announce header use when a header actually contributed a value.
        if from_text or from_port:
            print(f"🔑 Credentials loaded from HTTP headers (client: {headers.get('user-agent', 'unknown')})")
        return creds

    @staticmethod
    def _get_from_client_params(ctx: Optional[Context], creds: Credentials) -> Credentials:
        """Fill missing fields from the ``env`` dict inside ``ctx.client_params``.

        Keys are ``MCP_SQL_*``; a missing or non-dict ``client_params`` or
        ``env`` is ignored.
        """
        if not ctx:
            return creds
        client_params = getattr(ctx, 'client_params', None)
        client_env = client_params.get('env') if isinstance(client_params, dict) else None
        if not isinstance(client_env, dict):
            return creds
        CredentialsManager._fill_text_fields(creds, client_env, 'MCP_SQL_{upper}')
        CredentialsManager._fill_port(creds, client_env.get('MCP_SQL_PORT'))
        return creds

    @staticmethod
    def _get_from_environment(creds: Credentials) -> Credentials:
        """Fill missing fields from ``MCP_SQL_*`` process environment variables.

        Also applies the final defaults: the SQL Server ODBC driver and port
        1433, the latter adjusted by ``_auto_detect_port``. A non-numeric
        ``MCP_SQL_PORT`` is ignored, matching the header/client-param sources.
        """
        CredentialsManager._fill_text_fields(creds, os.environ, 'MCP_SQL_{upper}')
        creds.driver = creds.driver or os.getenv('MCP_SQL_DRIVER', CredentialsManager._DEFAULT_DRIVER)
        if not creds.port:  # None (or 0) means no earlier source set a port
            env_port = CredentialsManager._parse_port(os.getenv('MCP_SQL_PORT'))
            creds.port = env_port or CredentialsManager._DEFAULT_PORT
        return CredentialsManager._auto_detect_port(creds)