Skip to main content
Glama
database_scope_checker.py7.73 kB
""" 数据库范围检查器 用于检测和限制SQL查询中的跨数据库访问 """ import re import logging from typing import Set, Optional, List, Tuple from enum import Enum logger = logging.getLogger("mysql_server") class DatabaseAccessLevel(Enum): """数据库访问级别""" STRICT = "strict" # 严格模式:只能访问指定数据库 RESTRICTED = "restricted" # 限制模式:允许访问指定数据库和系统库 PERMISSIVE = "permissive" # 宽松模式:允许访问所有数据库(默认) class DatabaseScopeViolation(Exception): """数据库范围违规异常""" pass class DatabaseScopeChecker: """数据库范围检查器""" # 系统数据库列表 SYSTEM_DATABASES = { 'information_schema', 'mysql', 'performance_schema', 'sys' } # 跨数据库查询模式 CROSS_DB_PATTERNS = [ # database.table 格式 r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b', # SHOW TABLES FROM database r'\bSHOW\s+(?:FULL\s+)?TABLES\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\b', # USE database r'\bUSE\s+([a-zA-Z_][a-zA-Z0-9_]*)\b', # SELECT ... FROM database.table r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b', # JOIN database.table r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b', # INSERT INTO database.table r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b', # UPDATE database.table r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b', # DELETE FROM database.table r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b', ] def __init__(self, allowed_database: Optional[str] = None, access_level: DatabaseAccessLevel = DatabaseAccessLevel.PERMISSIVE): """ 初始化数据库范围检查器 Args: allowed_database: 允许访问的数据库名称 access_level: 访问级别 """ self.allowed_database = allowed_database self.access_level = access_level self.is_enabled = allowed_database is not None and access_level != DatabaseAccessLevel.PERMISSIVE logger.debug(f"数据库范围检查器初始化: 允许数据库={allowed_database}, 访问级别={access_level.value}, 启用={self.is_enabled}") def check_query(self, sql_query: str) -> Tuple[bool, List[str]]: """ 检查SQL查询是否违反数据库范围限制 Args: sql_query: SQL查询语句 Returns: (是否允许, 违规详情列表) """ if not self.is_enabled: return True, [] violations = [] # 提取查询中涉及的数据库 referenced_databases = self._extract_databases(sql_query) for db_name in referenced_databases: if not self._is_database_allowed(db_name): violations.append(f"不允许访问数据库: {db_name}") # 检查特殊查询类型 special_violations = self._check_special_queries(sql_query) violations.extend(special_violations) is_allowed = len(violations) == 0 if violations: logger.warning(f"数据库范围检查失败: {violations}") return is_allowed, violations def _extract_databases(self, sql_query: str) -> Set[str]: """提取SQL查询中涉及的数据库名称""" databases = set() # 标准化SQL(转换为大写,去除多余空格) normalized_sql = re.sub(r'\s+', ' ', sql_query.upper().strip()) for pattern in self.CROSS_DB_PATTERNS: matches = re.finditer(pattern, normalized_sql, re.IGNORECASE) for match in matches: # 第一个捕获组通常是数据库名 if match.groups(): db_name = match.group(1).lower() # 过滤掉非数据库名的匹配(如函数名等) if self._is_valid_database_name(db_name): databases.add(db_name) return databases def _is_valid_database_name(self, name: str) -> bool: """检查是否是有效的数据库名称""" # 数据库名称规则:字母、数字、下划线,不能以数字开头 return bool(re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name)) def _is_database_allowed(self, db_name: str) -> bool: """检查数据库是否被允许访问""" db_name_lower = db_name.lower() # 检查是否是允许的主数据库 if self.allowed_database and db_name_lower == self.allowed_database.lower(): return True # 根据访问级别决定是否允许系统数据库 if self.access_level == DatabaseAccessLevel.RESTRICTED: if db_name_lower in self.SYSTEM_DATABASES: return True return False def _check_special_queries(self, sql_query: str) -> List[str]: """检查特殊类型的查询""" violations = [] normalized_sql = sql_query.upper().strip() # 检查SHOW DATABASES查询 if re.search(r'\bSHOW\s+DATABASES\b', normalized_sql): if self.access_level == DatabaseAccessLevel.STRICT: violations.append("严格模式下不允许执行 SHOW DATABASES") # 检查USE语句 if re.search(r'\bUSE\s+', normalized_sql): violations.append("不允许使用 USE 语句切换数据库") # 检查系统表访问 system_table_patterns = [ r'\bmysql\.user\b', r'\bmysql\.db\b', r'\binformation_schema\.', r'\bperformance_schema\.', r'\bsys\.' ] for pattern in system_table_patterns: if re.search(pattern, normalized_sql, re.IGNORECASE): if self.access_level == DatabaseAccessLevel.STRICT: violations.append(f"严格模式下不允许访问系统表") break return violations def get_allowed_databases(self) -> Set[str]: """获取允许访问的数据库列表""" allowed = set() if self.allowed_database: allowed.add(self.allowed_database.lower()) if self.access_level == DatabaseAccessLevel.RESTRICTED: allowed.update(self.SYSTEM_DATABASES) return allowed def is_cross_database_query(self, sql_query: str) -> bool: """检查是否是跨数据库查询""" referenced_dbs = self._extract_databases(sql_query) return len(referenced_dbs) > 0 # 便捷函数 def create_database_checker(allowed_database: Optional[str] = None, access_level: str = "permissive") -> DatabaseScopeChecker: """ 创建数据库范围检查器的便捷函数 Args: allowed_database: 允许访问的数据库名称 access_level: 访问级别字符串 (strict/restricted/permissive) Returns: DatabaseScopeChecker实例 """ try: level = DatabaseAccessLevel(access_level.lower()) except ValueError: logger.warning(f"无效的访问级别: {access_level},使用默认的 permissive") level = DatabaseAccessLevel.PERMISSIVE return DatabaseScopeChecker(allowed_database, level)

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/mangooer/mysql-mcp-server-sse'

If you have feedback or need assistance with the MCP directory API, please join our Discord server