import inspect
import logging
import re
import ssl
import time
from functools import wraps
from pathlib import Path
from typing import Dict, Any, Optional

# Import the Pydantic model
from mcp_db.db import DatabaseCredentials
logger = logging.getLogger(__name__)


class SecurityManager:
    """Singleton holding the rate-limit counter and the SQL deny-list.

    Every call to ``SecurityManager()`` returns the same object, so the
    rate-limit state in ``self.rate_limit`` is shared by all callers in the
    process. ``_initialize`` runs only once, when the first instance is built.
    """

    _instance = None

    def __new__(cls) -> "SecurityManager":
        if cls._instance is None:
            # Publish the singleton only after initialization succeeded, so a
            # failure in _initialize cannot leave a half-built shared instance.
            instance = super().__new__(cls)
            instance._initialize()
            cls._instance = instance
        return cls._instance

    def _initialize(self) -> None:
        """Initialize security settings (rate-limit state and blocked patterns)."""
        self.rate_limit = {
            'max_queries_per_minute': 100,
            'current_minute': None,
            'current_count': 0,
        }
        # Keep blocked patterns for query sanitization if needed elsewhere
        self.blocked_patterns = [
            r'(?i)(?:delete|drop|truncate)\s+(?:table|database)',
            r'(?i)(?:alter|create)\s+(?:table|database|user)',
            r'(?i)execute\s+(?:procedure|function)',
            r'(?i)grant\s+|revoke\s+',
            r'(?i)system_user\(\)',
            r'(?i)super\s+privilege',
        ]

    def get_ssl_context(self, ssl_ca: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Build the SSL option dict for a secure connection.

        Args:
            ssl_ca: Path to the CA certificate file, or a falsy value for no SSL.

        Returns:
            ``None`` when ``ssl_ca`` is falsy, otherwise ``{'ssl_ca': <path>}``.

        Raises:
            FileNotFoundError: ``ssl_ca`` does not point to an existing file.
            Exception: any error raised while inspecting the path is logged
                and re-raised unchanged.
        """
        if not ssl_ca:
            logger.debug("No SSL CA provided, SSL context not created.")
            return None

        # Only the path inspection sits in the try block, so an unexpected
        # failure is logged with a traceback while the "file missing" case
        # below is logged exactly once.
        try:
            ssl_ca_path = Path(ssl_ca)
            ca_exists = ssl_ca_path.is_file()
        except Exception as e:
            logger.error(f"Failed to prepare SSL context options: {e}", exc_info=True)
            raise

        if not ca_exists:
            logger.error(f"SSL CA file not found at path: {ssl_ca}")
            raise FileNotFoundError(f"SSL CA file not found: {ssl_ca}")

        # Configuration options for mysql.connector's SSL arguments.
        # Other options (ssl_cert, ssl_key, ssl_verify_cert) can be added here.
        ssl_options = {
            'ssl_ca': str(ssl_ca_path),
        }
        logger.info(f"SSL context options prepared using CA: {ssl_ca}")
        return ssl_options

    def sanitize_query(self, query: str) -> str:
        """Strip comments from a SQL query and reject blocked statements.

        Comment removal is purely textual: comment markers inside string
        literals are stripped as well.

        Args:
            query: Raw SQL text.

        Returns:
            The query without comments and without surrounding whitespace.

        Raises:
            ValueError: ``query`` is not a string, matches one of
                ``self.blocked_patterns``, or is empty after sanitization.
        """
        if not isinstance(query, str):
            raise ValueError("Query must be a string.")

        # A block comment separates tokens in SQL, so replace it with a space
        # rather than nothing: "DROP/**/TABLE" must become "DROP TABLE" (and
        # be caught below), not "DROPTABLE".
        query = re.sub(r'/\*.*?\*/', ' ', query, flags=re.DOTALL)
        # Line comments end at the newline, which remains as the separator.
        query = re.sub(r'--.*$', '', query, flags=re.MULTILINE)
        query = re.sub(r'#.*$', '', query, flags=re.MULTILINE)

        # Check for dangerous patterns
        for pattern in self.blocked_patterns:
            if re.search(pattern, query, re.IGNORECASE):
                logger.warning(f"Query blocked due to forbidden pattern: {pattern}")
                raise ValueError(f"Query contains forbidden pattern: {pattern}")

        sanitized_query = query.strip()
        if not sanitized_query:
            raise ValueError("Query cannot be empty after sanitization.")

        logger.debug("Query passed basic sanitization checks.")
        return sanitized_query

    def check_rate_limit(self) -> bool:
        """Count the current request and report whether it is within the limit.

        Requests are counted per wall-clock minute (``int(time.time() / 60)``);
        the counter resets when that minute number changes. Rejected requests
        are counted too.

        Returns:
            ``True`` while the count for the current minute is at most
            ``max_queries_per_minute``, ``False`` once it is exceeded.
        """
        state = self.rate_limit
        current_minute = int(time.time() / 60)

        if state['current_minute'] != current_minute:
            # Reset count for the new minute
            state['current_minute'] = current_minute
            state['current_count'] = 0
            logger.debug(f"Rate limit minute reset to {current_minute}")

        state['current_count'] += 1
        logger.debug(f"Rate limit count for minute {current_minute}: {state['current_count']}/{state['max_queries_per_minute']}")

        if state['current_count'] > state['max_queries_per_minute']:
            logger.warning(f"Rate limit exceeded ({state['current_count']}/{state['max_queries_per_minute']})")
            return False  # Rate limit exceeded
        return True  # Within rate limit

    # The annotation is quoted so it is not evaluated when the method is defined.
    def sanitize_connection_params(self, params: "DatabaseCredentials") -> None:
        """Validate database connection parameters on the credentials model.

        The model is only inspected, never modified, and nothing is returned.
        A missing SSL CA file is logged as a warning, not treated as an error.

        Args:
            params: Object exposing ``host``, ``port``, ``database``,
                ``username``, ``password`` (with ``get_secret_value()``),
                ``ssl_ca``, ``max_retries`` and ``retry_delay``.

        Raises:
            ValueError: any of the parameters fails validation.
        """
        logger.debug(f"Sanitizing connection parameters for host: {params.host}, db: {params.database}")

        # Validate host: dotted domain name, localhost, IPv4-shaped or
        # IPv6-shaped literal. The top-level label may be up to 63 letters so
        # names such as "db.example.internal" are accepted. re.fullmatch is
        # used because "$" with re.match also matches before a trailing "\n".
        host_pattern = r"(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}|localhost|(?:\d{1,3}\.){3}\d{1,3}|\[?[a-fA-F0-9:]+\]?)"
        if not re.fullmatch(host_pattern, params.host):
            logger.error(f"Invalid host format detected: {params.host}")
            raise ValueError(f"Invalid host format: {params.host}")

        # Validate port
        if not isinstance(params.port, int) or not (1 <= params.port <= 65535):
            logger.error(f"Invalid port number detected: {params.port}")
            raise ValueError(f"Invalid port number: {params.port}. Must be between 1 and 65535.")

        # Validate database name (allow alphanumeric, underscore, hyphen)
        db_name_pattern = r'[a-zA-Z0-9_-]+'
        if not re.fullmatch(db_name_pattern, params.database):
            logger.error(f"Invalid database name format detected: {params.database}")
            raise ValueError(f"Invalid database name format: {params.database}. Use alphanumeric, underscore, or hyphen.")

        # Validate username (database-name characters plus "@" and ".")
        username_pattern = r'[a-zA-Z0-9_@.-]+'
        if not re.fullmatch(username_pattern, params.username):
            logger.error(f"Invalid username format detected: {params.username}")
            raise ValueError(f"Invalid username format: {params.username}")

        # Validate password: must be present and non-empty once unwrapped
        if not params.password or not params.password.get_secret_value():
            logger.error("Database password is required but not provided.")
            raise ValueError("Database password is required.")

        # Validate SSL CA path if provided. Only a warning: the connection
        # attempt will fail later if the file is truly needed and missing.
        if params.ssl_ca and not Path(params.ssl_ca).is_file():
            logger.warning(f"Provided SSL CA path does not point to a file: {params.ssl_ca}")

        # Validate retries and delay
        if not isinstance(params.max_retries, int) or params.max_retries < 0:
            logger.error(f"Invalid max_retries value: {params.max_retries}")
            raise ValueError("max_retries must be a non-negative integer.")
        if not isinstance(params.retry_delay, (int, float)) or params.retry_delay < 0:
            logger.error(f"Invalid retry_delay value: {params.retry_delay}")
            raise ValueError("retry_delay must be a non-negative number.")

        logger.debug("Connection parameters passed validation.")
def enforce_security(func):
    """Decorator that applies rate limiting before calling ``func``.

    Works for both plain functions and coroutine functions: a coroutine
    function is wrapped in an ``async def`` wrapper so the result is still
    recognised as a coroutine function by ``inspect.iscoroutinefunction``.

    Query sanitization is intentionally not applied here: the decorated tools
    are expected to build parameterized queries, and validation of table and
    field names belongs inside the tool itself.

    Args:
        func: The callable to protect.

    Returns:
        A wrapper with the metadata of ``func`` (copied by ``functools.wraps``,
        which also sets ``__wrapped__`` for signature introspection).

    Raises:
        Exception: raised by the wrapper when the rate limit is exceeded.
    """

    def _pre_call_checks(kwargs):
        # SecurityManager is a singleton, so this returns the shared instance
        # and the rate-limit counter is common to all decorated functions.
        security_manager = SecurityManager()

        # --- Check Rate Limit ---
        if not security_manager.check_rate_limit():
            # Raise an exception that the caller's error handling can catch
            raise Exception("Rate limit exceeded. Please try again later.")

        # --- Sanitize Query (if applicable) ---
        # Placeholder: parameterized queries handle the SQL injection risk,
        # so for query_database only a debug trace is emitted.
        if func.__name__ == 'query_database' and 'query_params' in kwargs:
            logger.debug(f"Security checks passed for {func.__name__}.")

    if inspect.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            _pre_call_checks(kwargs)
            return await func(*args, **kwargs)
    else:
        @wraps(func)
        def wrapper(*args, **kwargs):
            _pre_call_checks(kwargs)
            return func(*args, **kwargs)

    # functools.wraps already copies __name__, __doc__ and the annotations,
    # so no manual attribute copying is needed.
    return wrapper