Skip to main content
Glama

Doris MCP Server

Official
by apache
security.py (50.8 kB)
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Doris Security Management Module
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
"""

import logging
import re
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Optional

import sqlparse
from sqlparse.sql import Statement
from sqlparse.tokens import Keyword, Name

from .logger import get_logger
from .config import DatabaseConfig


class SecurityLevel(Enum):
    """Security level enumeration (ordered PUBLIC < INTERNAL < CONFIDENTIAL < SECRET where compared below)"""

    PUBLIC = "public"
    INTERNAL = "internal"
    CONFIDENTIAL = "confidential"
    SECRET = "secret"


@dataclass
class AuthContext:
    """Authentication context for audit and session tracking"""
    token_id: str = ""  # Token identifier for audit logging
    user_id: str = ""  # User identifier
    roles: list[str] = field(default_factory=list)  # User roles
    permissions: list[str] = field(default_factory=list)  # User permissions
    security_level: 'SecurityLevel' = field(default_factory=lambda: SecurityLevel.INTERNAL)  # Security level
    client_ip: str = "unknown"  # Client IP address
    session_id: str = ""  # Session identifier
    login_time: datetime = field(default_factory=datetime.utcnow)  # naive UTC timestamp
    last_activity: datetime | None = None
    token: str = ""  # Raw token for token-bound database configuration


@dataclass
class ValidationResult:
    """Outcome of a security validation step.

    ``risk_level`` takes the values "low", "medium" or "high" in this module.
    ``blocked_operations`` lists the blocked keywords that caused a rejection.
    """

    is_valid: bool
    error_message: str | None = None
    risk_level: str = "low"
    blocked_operations: list[str] = field(default_factory=list)

    def __post_init__(self):
        # Callers may still pass blocked_operations=None explicitly; normalize it.
        if self.blocked_operations is None:
            self.blocked_operations = []


@dataclass
class MaskingRule:
    """Data masking rule"""

    column_pattern: str  # regex matched (case-insensitively) against column names
    algorithm: str  # key into DataMaskingProcessor.masking_algorithms
    parameters: dict[str, Any]
    security_level: SecurityLevel


class DorisSecurityManager:
    """Doris security manager

    Provides complete security control functionality, including authentication,
    authorization, SQL security validation and data masking. It is a facade that
    delegates to AuthenticationProvider, AuthorizationProvider,
    SQLSecurityValidator and DataMaskingProcessor.
    """

    def __init__(self, config, connection_manager=None):
        self.config = config
        self.logger = get_logger(__name__)
        self.connection_manager = connection_manager

        # Initialize security components
        self.auth_provider = AuthenticationProvider(config, self)
        self.authz_provider = AuthorizationProvider(config)
        self.sql_validator = SQLSecurityValidator(config)
        self.masking_processor = DataMaskingProcessor(config)

        # Security rule configuration
        self.blocked_keywords = self._load_blocked_keywords()
        self.sensitive_tables = self._load_sensitive_tables()
        self.masking_rules = self._load_masking_rules()

        # Track initialization state
        self._initialized = False

    async def initialize(self):
        """Initialize security manager components (idempotent)"""
        if self._initialized:
            return

        try:
            # Initialize authentication provider (for JWT setup)
            await self.auth_provider.initialize()
            self._initialized = True
            self.logger.info("DorisSecurityManager initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize DorisSecurityManager: {e}")
            raise

    async def shutdown(self):
        """Shutdown security manager components"""
        try:
            await self.auth_provider.shutdown()
            self._initialized = False
            self.logger.info("DorisSecurityManager shutdown completed")
        except Exception as e:
            self.logger.error(f"Error during DorisSecurityManager shutdown: {e}")
            raise

    def _load_blocked_keywords(self) -> set[str]:
        """Load blocked SQL keywords from configuration.

        Dict-style configs use the "blocked_keywords" key, and config objects use
        ``config.security.blocked_keywords``. Otherwise a built-in default list
        is used.
        """
        if hasattr(self.config, 'get'):
            # Dictionary-style configuration
            blocked_keywords = self.config.get("blocked_keywords", [])
        elif hasattr(self.config, 'security') and hasattr(self.config.security, 'blocked_keywords'):
            # DorisConfig object, get through security.blocked_keywords
            blocked_keywords = self.config.security.blocked_keywords
        else:
            # Fallback to default if no configuration available
            blocked_keywords = [
                "DROP", "CREATE", "ALTER", "TRUNCATE",
                "DELETE", "INSERT", "UPDATE",
                "GRANT", "REVOKE", "EXEC", "EXECUTE", "SHUTDOWN", "KILL"
            ]

        return set(blocked_keywords)

    def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
        """Load the table -> SecurityLevel map.

        Built-in defaults are overlaid with the "sensitive_tables" entry of a
        dict-style config. String levels that are not valid SecurityLevel values
        fall back to INTERNAL.
        """
        default_tables = {
            "user_info": SecurityLevel.CONFIDENTIAL,
            "payment_records": SecurityLevel.SECRET,
            "employee_data": SecurityLevel.CONFIDENTIAL,
            "public_reports": SecurityLevel.PUBLIC,
        }

        if hasattr(self.config, 'get'):
            config_tables = self.config.get("sensitive_tables", {})
            # Convert string values to SecurityLevel enum
            for table_name, level in config_tables.items():
                if isinstance(level, str):
                    try:
                        default_tables[table_name] = SecurityLevel(level.lower())
                    except ValueError:
                        default_tables[table_name] = SecurityLevel.INTERNAL
                else:
                    default_tables[table_name] = level
            return default_tables
        else:
            return default_tables

    def _load_masking_rules(self) -> list[MaskingRule]:
        """Return the built-in masking rules followed by configured custom rules.

        Custom rules may be MaskingRule instances or dicts of MaskingRule fields.
        In a dict rule, a string ``security_level`` (e.g. "confidential") is
        converted to SecurityLevel; unknown values fall back to INTERNAL.
        """
        default_rules = [
            MaskingRule(
                column_pattern=r".*phone.*|.*mobile.*",
                algorithm="phone_mask",
                parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*email.*",
                algorithm="email_mask",
                parameters={"mask_char": "*"},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*id_card.*|.*identity.*",
                algorithm="id_mask",
                parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
                security_level=SecurityLevel.CONFIDENTIAL,
            ),
        ]

        # Load custom rules from configuration
        custom_rules = []
        if hasattr(self.config, 'get'):
            custom_rules = self.config.get("masking_rules", [])
        elif hasattr(self.config, 'security') and hasattr(self.config.security, 'masking_rules'):
            custom_rules = self.config.security.masking_rules

        for rule_config in custom_rules:
            if isinstance(rule_config, dict):
                # Work on a copy so the caller's configuration is not mutated,
                # and accept string security levels like DataMaskingProcessor does.
                rule_kwargs = dict(rule_config)
                level = rule_kwargs.get("security_level")
                if isinstance(level, str):
                    try:
                        rule_kwargs["security_level"] = SecurityLevel(level.lower())
                    except ValueError:
                        rule_kwargs["security_level"] = SecurityLevel.INTERNAL
                default_rules.append(MaskingRule(**rule_kwargs))
            elif isinstance(rule_config, MaskingRule):
                default_rules.append(rule_config)

        return default_rules

    async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
        """Validate request authentication information

        Tries the enabled authentication methods in order: Token -> JWT -> OAuth.
        The first method that succeeds determines the returned AuthContext.
        If all methods are disabled, an anonymous context (PUBLIC level,
        "read" permission) is returned.

        Raises:
            ValueError: if every enabled method fails; the last failure is
                chained as ``__cause__``.
        """
        # Check if any authentication method is enabled
        if not (self.config.security.enable_token_auth or
                self.config.security.enable_jwt_auth or
                self.config.security.enable_oauth_auth):
            self.logger.debug("All authentication methods are disabled")
            # Return anonymous context when no authentication is enabled
            return AuthContext(
                token_id="anonymous",
                user_id="anonymous",
                roles=["anonymous"],
                permissions=["read"],
                security_level=SecurityLevel.PUBLIC,
                client_ip=auth_info.get("client_ip", "unknown"),
                session_id="anonymous_session"
            )

        # Try authentication methods in order of preference
        last_error = None

        # 1. Try Token authentication first (most common)
        if self.config.security.enable_token_auth:
            try:
                return await self.auth_provider.authenticate_token(auth_info)
            except Exception as e:
                self.logger.debug(f"Token authentication failed: {e}")
                last_error = e

        # 2. Try JWT authentication
        if self.config.security.enable_jwt_auth:
            try:
                return await self.auth_provider.authenticate_jwt(auth_info)
            except Exception as e:
                self.logger.debug(f"JWT authentication failed: {e}")
                last_error = e

        # 3. Try OAuth authentication
        if self.config.security.enable_oauth_auth:
            try:
                return await self.auth_provider.authenticate_oauth(auth_info)
            except Exception as e:
                self.logger.debug(f"OAuth authentication failed: {e}")
                last_error = e

        # All enabled authentication methods failed
        error_message = f"Authentication failed: {str(last_error)}" if last_error else "No authentication method succeeded"
        self.logger.warning(f"Authentication failed for client {auth_info.get('client_ip', 'unknown')}: {error_message}")
        raise ValueError(error_message) from last_error

    async def authorize_resource_access(
        self, auth_context: AuthContext, resource_uri: str
    ) -> bool:
        """Return whether ``auth_context`` may read ``resource_uri``"""
        return await self.authz_provider.check_permission(
            auth_context, resource_uri, "read"
        )

    async def validate_sql_security(
        self, sql: str, auth_context: AuthContext
    ) -> ValidationResult:
        """Validate SQL query security (delegates to SQLSecurityValidator)"""
        return await self.sql_validator.validate(sql, auth_context)

    async def apply_data_masking(
        self, data: list[dict[str, Any]], auth_context: AuthContext
    ) -> list[dict[str, Any]]:
        """Apply data masking to result rows (delegates to DataMaskingProcessor)"""
        return await self.masking_processor.process(data, auth_context)

    # OAuth-specific methods
    def get_oauth_authorization_url(self) -> tuple[str, str]:
        """Get OAuth authorization URL

        Returns:
            Tuple of (authorization_url, state)

        Raises:
            ValueError: if OAuth is not enabled
        """
        if not self.auth_provider.oauth_provider:
            raise ValueError("OAuth is not enabled")

        return self.auth_provider.oauth_provider.get_authorization_url()

    async def handle_oauth_callback(self, code: str, state: str) -> AuthContext:
        """Handle OAuth callback

        Args:
            code: Authorization code from OAuth provider
            state: State parameter for CSRF protection

        Returns:
            AuthContext for authenticated user

        Raises:
            ValueError: if OAuth is not enabled
        """
        if not self.auth_provider.oauth_provider:
            raise ValueError("OAuth is not enabled")

        return await self.auth_provider.oauth_provider.handle_callback(code, state)

    def get_oauth_provider_info(self) -> dict[str, Any]:
        """Get OAuth provider information

        Returns:
            OAuth provider information, or ``{"enabled": False}`` when OAuth is off
        """
        if not self.auth_provider.oauth_provider:
            return {"enabled": False}

        return self.auth_provider.oauth_provider.get_provider_info()

    # Token management methods
    async def create_token(
        self,
        token_id: str,
        expires_hours: Optional[int] = None,
        description: str = "",
        custom_token: Optional[str] = None,
        database_config: Optional[DatabaseConfig] = None
    ) -> str:
        """Create a new API access token

        Args:
            token_id: Unique token identifier for audit and management
            expires_hours: Token expiration in hours (None for no expiration)
            description: Token description for management purposes
            custom_token: Custom token string (if None, generates random token)
            database_config: Optional database configuration for this token

        Returns:
            Generated token string

        Raises:
            ValueError: if token authentication (and so the token manager) is not enabled
        """
        if not self.auth_provider.token_manager:
            raise ValueError("Token manager not initialized")

        return await self.auth_provider.token_manager.create_token(
            token_id=token_id,
            expires_hours=expires_hours,
            description=description,
            custom_token=custom_token,
            database_config=database_config
        )

    async def revoke_token(self, token_id: str) -> bool:
        """Revoke a token by token ID

        Args:
            token_id: Token ID to revoke

        Returns:
            True if token was revoked successfully
        """
        if not self.auth_provider.token_manager:
            raise ValueError("Token manager not initialized")

        return await self.auth_provider.token_manager.revoke_token(token_id)

    async def list_tokens(self) -> list[dict[str, Any]]:
        """List all tokens (without sensitive data)

        Returns:
            List of token information
        """
        if not self.auth_provider.token_manager:
            raise ValueError("Token manager not initialized")

        return await self.auth_provider.token_manager.list_tokens()

    async def cleanup_expired_tokens(self) -> int:
        """Remove expired tokens and return count

        Returns:
            Number of expired tokens removed (0 when token auth is disabled)
        """
        if not self.auth_provider.token_manager:
            return 0

        return await self.auth_provider.token_manager.cleanup_expired_tokens()

    def get_token_stats(self) -> dict[str, Any]:
        """Get token statistics

        Returns:
            Token statistics dictionary, or an ``{"error": ...}`` dict when
            token auth is disabled
        """
        if not self.auth_provider.token_manager:
            return {"error": "Token manager not initialized"}

        return self.auth_provider.token_manager.get_token_stats()

    async def _validate_token_database_config(self, token: str, token_info) -> None:
        """Validate database configuration for token immediately during authentication

        This ensures database connectivity issues are caught at authentication time,
        not during query execution, providing better user experience.

        Args:
            token: Raw authentication token
            token_info: TokenInfo object from token validation

        Raises:
            ValueError: If database configuration is invalid or connection fails
                (the underlying error is chained as ``__cause__``)
        """
        try:
            if not self.connection_manager:
                self.logger.warning("Connection manager not available for immediate database validation")
                return

            # Configure and test database connection for this token
            success, config_source = await self.connection_manager.configure_for_token(token)

            if success:
                self.logger.info(f"Database configuration validated successfully for token {token_info.token_id} (source: {config_source})")
            else:
                raise ValueError("Database configuration validation failed")

        except Exception as e:
            error_msg = f"Database configuration validation failed for token {token_info.token_id}: {str(e)}"
            self.logger.error(error_msg)
            raise ValueError(error_msg) from e


class AuthenticationProvider:
    """Authentication provider

    Creates the Token, JWT and OAuth backends that are enabled in
    ``config.security`` and dispatches authentication requests to them.
    """

    def __init__(self, config, security_manager=None):
        self.config = config
        self.logger = get_logger(__name__)
        self.session_cache = {}
        self.jwt_manager = None
        self.oauth_provider = None
        self.token_manager = None
        # Back-reference used to validate token-bound database configuration.
        self.security_manager = security_manager

        # Initialize authentication providers based on individual switches
        auth_methods_enabled = []

        # Initialize Token manager if enabled
        if config.security.enable_token_auth:
            self._initialize_token_manager()
            auth_methods_enabled.append("Token")

        # Initialize JWT manager if enabled
        if config.security.enable_jwt_auth:
            self._initialize_jwt_manager()
            auth_methods_enabled.append("JWT")

        # Initialize OAuth provider if enabled.
        # NOTE(review): the legacy ``oauth_enabled`` flag also creates the provider,
        # but authenticate_oauth() only honours ``enable_oauth_auth`` — confirm intent.
        if config.security.enable_oauth_auth or (hasattr(config.security, 'oauth_enabled') and config.security.oauth_enabled):
            self._initialize_oauth_provider()
            auth_methods_enabled.append("OAuth")

        if auth_methods_enabled:
            self.logger.info(f"Authentication enabled with methods: {', '.join(auth_methods_enabled)}")
        else:
            self.logger.info("All authentication methods are disabled - anonymous access allowed")

    def _initialize_jwt_manager(self):
        """Initialize JWT manager (imported lazily so JWT deps are optional)"""
        try:
            from ..auth.jwt_manager import JWTManager
            self.jwt_manager = JWTManager(self.config)
            self.logger.info("JWT manager initialized")
        except ImportError as e:
            self.logger.error(f"Failed to import JWT manager: {e}")
            raise
        except Exception as e:
            self.logger.error(f"Failed to initialize JWT manager: {e}")
            raise

    def _initialize_token_manager(self):
        """Initialize Token manager (imported lazily)"""
        try:
            from ..auth.token_manager import TokenManager
            self.token_manager = TokenManager(self.config)
            self.logger.info("Token manager initialized")
        except ImportError as e:
            self.logger.error(f"Failed to import Token manager: {e}")
            raise
        except Exception as e:
            self.logger.error(f"Failed to initialize Token manager: {e}")
            raise

    def _initialize_oauth_provider(self):
        """Initialize OAuth provider (imported lazily)"""
        try:
            from ..auth.oauth_provider import OAuthAuthenticationProvider
            self.oauth_provider = OAuthAuthenticationProvider(self.config)
            self.logger.info("OAuth provider initialized")
        except ImportError as e:
            self.logger.error(f"Failed to import OAuth provider: {e}")
            raise
        except Exception as e:
            self.logger.error(f"Failed to initialize OAuth provider: {e}")
            raise

    async def initialize(self):
        """Initialize authentication provider asynchronously

        Raises:
            RuntimeError: if the JWT manager or OAuth provider reports failure
        """
        if self.jwt_manager:
            success = await self.jwt_manager.initialize()
            if not success:
                raise RuntimeError("Failed to initialize JWT manager")
            self.logger.info("JWT authentication provider initialized successfully")

        if self.token_manager:
            # Token manager doesn't need async initialization, just log success
            self.logger.info("Token authentication provider initialized successfully")

        if self.oauth_provider:
            success = await self.oauth_provider.initialize()
            if not success:
                raise RuntimeError("Failed to initialize OAuth provider")
            self.logger.info("OAuth authentication provider initialized successfully")

    async def shutdown(self):
        """Shutdown authentication provider"""
        if self.jwt_manager:
            await self.jwt_manager.shutdown()
            self.logger.info("JWT authentication provider shutdown completed")

        if self.token_manager:
            # Token manager doesn't need async shutdown, just log
            self.logger.info("Token authentication provider shutdown completed")

        if self.oauth_provider:
            await self.oauth_provider.shutdown()
            self.logger.info("OAuth authentication provider shutdown completed")

    async def authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
        """Perform token authentication"""
        if not self.config.security.enable_token_auth:
            raise ValueError("Token authentication is not enabled")
        return await self._authenticate_token(auth_info)

    async def authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
        """Perform JWT authentication"""
        if not self.config.security.enable_jwt_auth:
            raise ValueError("JWT authentication is not enabled")
        return await self._authenticate_jwt(auth_info)

    async def authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
        """Perform OAuth authentication"""
        if not self.config.security.enable_oauth_auth:
            raise ValueError("OAuth authentication is not enabled")
        return await self._authenticate_oauth(auth_info)

    async def _authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
        """JWT authentication.

        Requires a token in ``auth_info["token"]`` or a ``Bearer`` authorization
        header, then delegates verification to AuthMiddleware.
        """
        if not self.jwt_manager:
            raise ValueError("JWT manager not initialized")

        token = auth_info.get("token")
        if not token:
            # Try to extract from Authorization header
            authorization = auth_info.get("authorization")
            if authorization and authorization.startswith('Bearer '):
                token = authorization[7:]

        # The extracted token is only checked for presence here; the middleware
        # re-reads auth_info itself.
        if not token:
            raise ValueError("Missing JWT token")

        try:
            # Use JWT middleware for authentication
            from ..auth.auth_middleware import AuthMiddleware
            middleware = AuthMiddleware(self.jwt_manager)
            return await middleware.authenticate_request(auth_info)
        except Exception as e:
            self.logger.error(f"JWT authentication failed: {e}")
            raise ValueError(f"JWT authentication failed: {str(e)}") from e

    async def _authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
        """OAuth authentication via an access token or an authorization code + state"""
        if not self.oauth_provider:
            raise ValueError("OAuth provider not initialized")

        # Handle different OAuth authentication scenarios
        if "access_token" in auth_info:
            # Direct OAuth access token authentication
            return await self.oauth_provider.authenticate_with_token(auth_info["access_token"])
        elif "code" in auth_info and "state" in auth_info:
            # OAuth callback authentication
            return await self.oauth_provider.handle_callback(auth_info["code"], auth_info["state"])
        else:
            raise ValueError("OAuth authentication requires either access_token or code+state")

    async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
        """Token authentication.

        The token comes from ``auth_info["token"]`` or from a ``Bearer``/``Token``
        authorization header. It is validated by the TokenManager. When a
        security manager is attached, the token's database configuration is also
        checked immediately.

        Raises:
            ValueError: on a missing or invalid token or a failed database check
        """
        if not self.token_manager:
            raise ValueError("Token manager not initialized")

        token = auth_info.get("token")
        if not token:
            # Try to extract from Authorization header
            authorization = auth_info.get("authorization")
            if authorization and authorization.startswith('Bearer '):
                token = authorization[7:]
            elif authorization and authorization.startswith('Token '):
                token = authorization[6:]

        if not token:
            raise ValueError("Missing authentication token")

        try:
            # Validate token using TokenManager
            validation_result = await self.token_manager.validate_token(token)

            if not validation_result.is_valid:
                raise ValueError(f"Token validation failed: {validation_result.error_message}")

            token_info = validation_result.token_info

            # Immediately validate database configuration for this token
            if self.security_manager:
                await self.security_manager._validate_token_database_config(token, token_info)

            return AuthContext(
                token_id=token_info.token_id,
                user_id=token_info.token_id,  # Use token_id as user_id for token auth
                roles=["token_user"],  # Default role for token users
                permissions=["read", "write"],  # Default permissions for token users
                security_level=SecurityLevel.INTERNAL,
                client_ip=auth_info.get("client_ip", "unknown"),
                session_id=auth_info.get("session_id", f"session_{token_info.token_id}"),
                login_time=datetime.utcnow(),
                last_activity=token_info.last_used,
                token=token  # Store raw token for token-bound database configuration
            )

        except Exception as e:
            self.logger.error(f"Token authentication failed: {e}")
            raise ValueError(f"Token authentication failed: {str(e)}") from e

    async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
        """Basic authentication (username password)"""
        username = auth_info.get("username")
        password = auth_info.get("password")

        if not username or not password:
            raise ValueError("Missing username or password")

        # Validate username password (simplified implementation)
        user_info = await self._validate_credentials(username, password)

        return AuthContext(
            user_id=user_info["user_id"],
            roles=user_info["roles"],
            permissions=user_info["permissions"],
            session_id=auth_info.get("session_id", "default"),
            login_time=datetime.utcnow(),
            security_level=SecurityLevel(user_info.get("security_level", "internal")),
        )

    async def _validate_token(self, token: str) -> dict[str, Any]:
        """Validate token validity against a hard-coded test table.

        NOTE(review): test-only fixture with static tokens; no caller is visible
        in this module chunk — make sure it is unreachable in production.
        """
        # Simplified implementation for testing, should parse JWT or query authentication service in practice
        valid_tokens = {
            "valid_token_123": {
                "user_id": "test_user",
                "roles": ["data_analyst"],
                "permissions": ["read_data"],
                "security_level": SecurityLevel.INTERNAL,
            },
            "admin_token_456": {
                "user_id": "admin_user",
                "roles": ["data_admin"],
                "permissions": ["admin"],
                "security_level": SecurityLevel.SECRET,
            }
        }

        if token in valid_tokens:
            return valid_tokens[token]
        else:
            raise ValueError("Invalid token")

    async def _validate_credentials(
        self, username: str, password: str
    ) -> dict[str, Any]:
        """Validate user credentials against a hard-coded test user table.

        Returns the user record without the password.

        NOTE(review): plaintext demo credentials compared with ``==``; only
        _authenticate_basic (which has no caller in this chunk) uses it — confirm
        this path cannot be enabled in production.
        """
        # Simplified implementation for testing, should query user database in practice
        valid_users = {
            "admin": {
                "password": "admin123",
                "user_id": "admin_user",
                "roles": ["data_admin"],
                "permissions": ["admin", "read_data", "write_data"],
                "security_level": SecurityLevel.SECRET,
            },
            "analyst": {
                "password": "analyst123",
                "user_id": "analyst_user",
                "roles": ["data_analyst"],
                "permissions": ["read_data"],
                "security_level": SecurityLevel.INTERNAL,
            }
        }

        if username in valid_users and valid_users[username]["password"] == password:
            user_info = valid_users[username].copy()
            del user_info["password"]  # Remove password from returned info
            return user_info
        else:
            raise ValueError("Incorrect username or password")


class AuthorizationProvider:
    """Authorization provider combining security-level, role and user permission checks"""

    def __init__(self, config):
        self.config = config
        self.logger = get_logger(__name__)
        self.permission_cache = {}

        # Load sensitive tables configuration
        self.sensitive_tables = self._load_sensitive_tables()

    def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
        """Load the table -> SecurityLevel map (defaults overlaid with dict-config values)"""
        default_tables = {
            "user_info": SecurityLevel.CONFIDENTIAL,
            "payment_records": SecurityLevel.SECRET,
            "employee_data": SecurityLevel.CONFIDENTIAL,
            "public_reports": SecurityLevel.PUBLIC,
        }

        if hasattr(self.config, 'get'):
            config_tables = self.config.get("sensitive_tables", {})
            # Convert string values to SecurityLevel enum
            for table_name, level in config_tables.items():
                if isinstance(level, str):
                    try:
                        default_tables[table_name] = SecurityLevel(level.lower())
                    except ValueError:
                        default_tables[table_name] = SecurityLevel.INTERNAL
                else:
                    default_tables[table_name] = level
            return default_tables
        else:
            return default_tables

    async def check_permission(
        self, auth_context: AuthContext, resource_uri: str, action: str
    ) -> bool:
        """Return whether ``auth_context`` may perform ``action`` on ``resource_uri``.

        The security-level check is mandatory. Beyond it, either a role-based or
        a user-based grant is sufficient.
        """
        # Parse resource information
        resource_info = self._parse_resource_uri(resource_uri)

        # First check security level - this is mandatory
        if not await self._check_security_level_permission(auth_context, resource_info):
            return False

        # Then check role-based permissions
        if await self._check_role_permission(auth_context, resource_info, action):
            return True

        # Finally check user-based permissions
        if await self._check_user_permission(auth_context, resource_info, action):
            return True

        return False

    def _parse_resource_uri(self, uri: str) -> dict[str, str]:
        """Parse a resource URI into type/name/schema.

        The URI is split on "/", so e.g. "doris://table/users" yields
        type="table" and name="users".
        """
        parts = uri.split("/")
        if len(parts) >= 3:
            return {
                "type": parts[2],  # table, view, etc.
                "name": parts[3] if len(parts) > 3 else "",
                "schema": parts[4] if len(parts) > 4 else "default",
            }
        return {"type": "unknown", "name": "", "schema": "default"}

    async def _check_role_permission(
        self, auth_context: AuthContext, resource_info: dict[str, str], action: str
    ) -> bool:
        """Check role-based permissions"""
        # Role permission mapping
        role_permissions = {
            "data_analyst": {"table": ["read"], "view": ["read"]},
            "data_admin": {
                "table": ["read", "write", "admin"],
                "view": ["read", "write", "admin"],
            },
        }

        for role in auth_context.roles:
            role_perms = role_permissions.get(role, {})
            resource_perms = role_perms.get(resource_info["type"], [])
            if action in resource_perms:
                return True

        return False

    async def _check_user_permission(
        self, auth_context: AuthContext, resource_info: dict[str, str], action: str
    ) -> bool:
        """Check user-based permissions ("admin" grants everything, "read_data" grants read)"""
        if "admin" in auth_context.permissions:
            return True

        if action == "read" and "read_data" in auth_context.permissions:
            return True

        return False

    async def _check_security_level_permission(
        self, auth_context: AuthContext, resource_info: dict[str, str]
    ) -> bool:
        """Return whether the user's security level is at least the resource's"""
        # Get resource security level
        resource_security_level = self._get_resource_security_level(resource_info)

        # Check if user security level is sufficient
        security_hierarchy = {
            SecurityLevel.PUBLIC: 0,
            SecurityLevel.INTERNAL: 1,
            SecurityLevel.CONFIDENTIAL: 2,
            SecurityLevel.SECRET: 3,
        }

        user_level = security_hierarchy.get(auth_context.security_level, 0)
        resource_level = security_hierarchy.get(resource_security_level, 0)

        # User must have higher or equal security level to access resource
        return user_level >= resource_level

    def _get_resource_security_level(
        self, resource_info: dict[str, str]
    ) -> SecurityLevel:
        """Get resource security level (unknown tables default to INTERNAL)"""
        table_name = resource_info.get("name", "")

        # Use the loaded sensitive tables
        sensitive_tables = self.sensitive_tables

        # Convert string values to SecurityLevel enum if needed
        security_level = sensitive_tables.get(table_name, SecurityLevel.INTERNAL)
        if isinstance(security_level, str):
            try:
                security_level = SecurityLevel(security_level.lower())
            except ValueError:
                security_level = SecurityLevel.INTERNAL

        return security_level


class SQLSecurityValidator:
    """SQL security validator (heuristic, token- and regex-based checks)"""

    def __init__(self, config):
        self.config = config
        self.logger = get_logger(__name__)

        # Handle DorisConfig object or dictionary configuration
        if hasattr(config, 'get'):
            # Dictionary configuration
            self.blocked_keywords = set(config.get("blocked_keywords", []))
            self.max_query_complexity = config.get("max_query_complexity", 100)
            self.enable_security_check = config.get("enable_security_check", True)
        elif hasattr(config, 'security'):
            # DorisConfig object with security attribute - unified source from config
            self.blocked_keywords = set(config.security.blocked_keywords)
            self.max_query_complexity = config.security.max_query_complexity
            self.enable_security_check = getattr(config.security, 'enable_security_check', True)
        else:
            # Fallback to default if no configuration available
            self.blocked_keywords = set([
                "DROP", "CREATE", "ALTER", "TRUNCATE",
                "DELETE", "INSERT", "UPDATE",
                "GRANT", "REVOKE", "EXEC", "EXECUTE", "SHUTDOWN", "KILL"
            ])
            self.max_query_complexity = 100
            self.enable_security_check = True

    async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
        """Validate SQL query security.

        Checks run in this order, and the first failure is returned:
        blocked keywords, SQL-injection patterns, query complexity, table access.
        Every statement in ``sql`` is checked, not only the first one.
        """
        # If security check is disabled, always return valid
        if not self.enable_security_check:
            self.logger.debug("SQL security check is disabled, allowing all queries")
            return ValidationResult(is_valid=True)

        try:
            # sqlparse returns one Statement per ';'-separated statement. Checking
            # only the first would let e.g. "SELECT 1; GRANT ..." through.
            statements = [stmt for stmt in sqlparse.parse(sql) if str(stmt).strip()]
            if not statements:
                return ValidationResult(
                    is_valid=False,
                    error_message="SQL parsing error: empty SQL statement",
                    risk_level="high",
                )

            # Check blocked operations first (more specific)
            for parsed in statements:
                keyword_result = await self._check_blocked_keywords(parsed)
                if not keyword_result.is_valid:
                    return keyword_result

            # Check SQL injection risks (patterns run on the raw SQL text)
            injection_result = await self._check_sql_injection(sql, statements[0])
            if not injection_result.is_valid:
                return injection_result

            # Check query complexity
            for parsed in statements:
                complexity_result = await self._check_query_complexity(parsed)
                if not complexity_result.is_valid:
                    return complexity_result

            # Check table access permissions
            for parsed in statements:
                table_result = await self._check_table_access(parsed, auth_context)
                if not table_result.is_valid:
                    return table_result

            return ValidationResult(is_valid=True)

        except Exception as e:
            self.logger.error(f"SQL security validation failed: {e}")
            return ValidationResult(
                is_valid=False,
                error_message=f"SQL parsing error: {str(e)}",
                risk_level="high",
            )

    async def _check_sql_injection(
        self, sql: str, parsed: Statement
    ) -> ValidationResult:
        """Check SQL injection risks with regex heuristics on the raw text.

        These patterns are deliberately broad. The first one rejects any text
        with two of the listed keywords separated by whitespace (e.g.
        "INSERT ... SELECT"). Several others are unanchored substring matches
        (e.g. "exec" also matches inside "executive").
        """
        # Check common SQL injection patterns
        injection_patterns = [
            r"(?i)(?<![A-Za-z0-9_])(union|select|insert|update|delete|drop|create|alter)(?![A-Za-z0-9_])\s+[\s\S]*?\s+(?<![A-Za-z0-9_])(union|select|insert|update|delete|drop|create|alter)(?![A-Za-z0-9_])",
            r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
            r"(\s|^)(or|and)\s+['\"].*['\"]",
            r";\s*(drop|delete|truncate|alter|create)",
            r"(exec|execute|sp_|xp_)",
            r"(script|javascript|vbscript)",
            r"(char|ascii|substring|concat)\s*\(",
        ]

        sql_lower = sql.lower()
        for pattern in injection_patterns:
            if re.search(pattern, sql_lower, re.IGNORECASE):
                return ValidationResult(
                    is_valid=False,
                    error_message="Potential SQL injection risk detected",
                    risk_level="high",
                )

        # Check suspicious quotes and comments
        if self._has_suspicious_quotes_or_comments(sql):
            return ValidationResult(
                is_valid=False,
                error_message="Suspicious quote or comment pattern detected",
                risk_level="medium",
            )

        return ValidationResult(is_valid=True)

    def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
        """Return True for an odd number of ' or " characters, or any SQL comment marker"""
        # Check unmatched quotes
        single_quotes = sql.count("'")
        double_quotes = sql.count('"')

        if single_quotes % 2 != 0 or double_quotes % 2 != 0:
            return True

        # Check SQL comments
        if "--" in sql or "/*" in sql:
            return True

        return False

    async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
        """Reject a statement that contains a blocked keyword.

        Keyword and name tokens are compared word-by-word. Every other token
        (literals, comments, ...) is compared by substring, which is broad and
        can also flag harmless text such as the literal 'UPDATED'.
        """
        # Compare case-insensitively even if the configuration lists lower-case keywords.
        blocked_keywords = {keyword.upper() for keyword in self.blocked_keywords}
        blocked_operations = []

        # Check all tokens in the parsed statement
        for token in parsed.flatten():
            # Check if token is a keyword (including DML/DDL) or name that matches blocked operations
            if (token.ttype is Keyword or token.ttype is Name or
                    (token.ttype and str(token.ttype).startswith('Token.Keyword'))):
                # sqlparse emits some multi-word keywords as a single token
                # (e.g. "CREATE OR REPLACE"), so test the whole value and each word.
                words = token.value.upper().split()
                normalized = " ".join(words)
                if normalized in blocked_keywords:
                    blocked_operations.append(normalized)
                blocked_operations.extend(word for word in words if word in blocked_keywords)
            # Also check for DDL/DML keywords in token values
            elif hasattr(token, 'value') and token.value:
                token_value = token.value.upper().strip()
                for blocked_keyword in blocked_keywords:
                    if blocked_keyword in token_value:
                        blocked_operations.append(blocked_keyword)

        if blocked_operations:
            return ValidationResult(
                is_valid=False,
                error_message=f"Contains blocked operations: {', '.join(set(blocked_operations))}",
                risk_level="high",
                blocked_operations=list(set(blocked_operations)),
            )

        return ValidationResult(is_valid=True)

    async def _check_query_complexity(self, parsed: Statement) -> ValidationResult:
        """Reject a statement whose keyword-based complexity score exceeds the limit"""
        complexity_score = 0

        for token in parsed.flatten():
            if token.ttype is not Keyword:
                continue
            # sqlparse emits "LEFT JOIN", "UNION ALL", "GROUP BY", ... as single
            # tokens whose inner whitespace can be any run of spaces/newlines;
            # normalize it so these forms are scored as well.
            keyword = " ".join(token.value.upper().split())
            if keyword.endswith("JOIN") or keyword in ("INNER", "LEFT", "RIGHT", "FULL"):
                complexity_score += 10
            elif keyword in ("UNION", "UNION ALL", "INTERSECT", "EXCEPT"):
                complexity_score += 15
            elif keyword in ("GROUP BY", "ORDER BY", "HAVING"):
                complexity_score += 5
            elif keyword in ("SUBQUERY", "EXISTS", "IN"):
                complexity_score += 8

        if complexity_score > self.max_query_complexity:
            return ValidationResult(
                is_valid=False,
                error_message=f"Query complexity too high (score: {complexity_score}, limit: {self.max_query_complexity})",
                risk_level="medium",
            )

        return ValidationResult(is_valid=True)

    async def _check_table_access(
        self, parsed: Statement, auth_context: AuthContext
    ) -> ValidationResult:
        """Check table access permissions against a small built-in deny list"""
        # Extract table names from query
        tables = self._extract_table_names(parsed)

        # Check access permissions for each table
        unauthorized_tables = []
        for table in tables:
            # Should call authorization provider to check permissions
            # Simplified implementation, assume some tables require special permissions.
            # NOTE(review): this checks the role "admin", while the role maps in this
            # module use "data_admin" — confirm which role is intended.
            if (
                table.lower() in ["sensitive_data", "admin_logs"]
                and "admin" not in auth_context.roles
            ):
                unauthorized_tables.append(table)

        if unauthorized_tables:
            return ValidationResult(
                is_valid=False,
                error_message=f"No access to tables: {', '.join(unauthorized_tables)}",
                risk_level="high",
            )

        return ValidationResult(is_valid=True)

    def _extract_table_names(self, parsed: Statement) -> list[str]:
        """Extract table names referenced after FROM and JOIN keywords.

        This is a token-level heuristic, not a SQL parser. It handles
        comma-separated table lists, AS aliases, qualified names
        (``db.table`` -> ``table``) and backtick-quoted identifiers. It can
        over-report (e.g. column names inside ``FROM (SELECT a, b ...)``), which
        is harmless for the deny-list check above.
        """
        tables = []
        tokens = [token for token in parsed.flatten() if not token.is_whitespace]

        index = 0
        while index < len(tokens):
            token = tokens[index]
            index += 1
            if token.ttype is not Keyword:
                continue
            keyword = " ".join(token.value.upper().split())
            if keyword != "FROM" and not keyword.endswith("JOIN"):
                continue

            # Scan the table list that follows until the next clause keyword
            # (WHERE, ON, GROUP BY, ...). "AS" only introduces an alias.
            expecting_table = True
            while index < len(tokens):
                current = tokens[index]
                if current.ttype is Keyword and current.value.upper() != "AS":
                    break
                if current.ttype is Name and expecting_table:
                    table_name, index = self._read_qualified_name(tokens, index)
                    tables.append(table_name)
                    expecting_table = False
                    continue
                if current.value == ",":
                    expecting_table = True
                index += 1

        return tables

    @staticmethod
    def _read_qualified_name(tokens: list, start: int) -> tuple[str, int]:
        """Read a possibly dotted identifier (``db.table``) starting at ``start``.

        Returns the last name component with surrounding backticks removed,
        and the index of the first token after the identifier.
        """
        name = tokens[start].value
        index = start + 1
        while (index + 1 < len(tokens)
               and tokens[index].value == "."
               and tokens[index + 1].ttype is Name):
            name = tokens[index + 1].value
            index += 2
        return name.strip("`"), index


class DataMaskingProcessor:
    """Data masking processor"""

    def __init__(self, config):
        self.config = config
        self.logger = get_logger(__name__)
        self.masking_algorithms = self._init_masking_algorithms()
        self.masking_rules = self._load_masking_rules()

    def _load_masking_rules(self) -> list[MaskingRule]:
        """Load data masking rules"""
        default_rules = [
            MaskingRule(
                column_pattern=r".*phone.*|.*mobile.*",
                algorithm="phone_mask",
                parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*email.*",
                algorithm="email_mask",
                parameters={"mask_char": "*"},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*id_card.*|.*identity.*",
                algorithm="id_mask",
                parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
                security_level=SecurityLevel.CONFIDENTIAL,
            ),
        ]

        # Load custom rules from configuration
        if hasattr(self.config, 'get'):
            custom_rules = self.config.get("masking_rules", [])
            for rule_config in custom_rules:
                if isinstance(rule_config, dict):
                    # Convert string security level to enum
                    if 'security_level' in rule_config and isinstance(rule_config['security_level'], str):
                        try:
                            rule_config['security_level'] = SecurityLevel(rule_config['security_level'].lower())
                        except ValueError:
                            rule_config['security_level'] = SecurityLevel.INTERNAL
                    default_rules.append(MaskingRule(**rule_config))
                elif isinstance(rule_config, MaskingRule):
                    default_rules.append(rule_config)

        return default_rules

    def _init_masking_algorithms(self) -> dict[str, callable]:
        """Initialize masking algorithms"""
        return {
            "phone_mask": self._mask_phone,
            "email_mask": self._mask_email,
            "id_mask": self._mask_id_card,
            "name_mask": self._mask_name,
            "partial_mask": self._mask_partial,
        }

    async def process(
        self, data: list[dict[str, Any]], auth_context: AuthContext
    ) -> list[dict[str, Any]]:
        """Process data masking"""
        if not data:
            return data

        # Get applicable masking rules
        applicable_rules = self._get_applicable_rules(auth_context)

        masked_data = []
        for row in data:
            masked_row = {}
            for column, value in row.items():
                masked_value = await self._apply_masking_rules(
                    column, value, applicable_rules
                )
                masked_row[column] = masked_value
            masked_data.append(masked_row)

        return masked_data

    def _get_applicable_rules(self, auth_context: AuthContext) -> list[MaskingRule]:
        """Get applicable masking rules"""
        applicable_rules = []

        for rule in self.masking_rules:
            # Decide whether to apply masking rules based on user security level
            if self._should_apply_rule(rule, auth_context):
                applicable_rules.append(rule)

        return applicable_rules

    def _should_apply_rule(self, rule: MaskingRule, auth_context: AuthContext) -> bool:
        """Determine whether masking rule should be applied"""
        # Admin users can see original data
        if "admin" in auth_context.roles:
            return False

        # Decide based on security level
        security_hierarchy = {
            SecurityLevel.PUBLIC: 0,
            SecurityLevel.INTERNAL: 1,
            SecurityLevel.CONFIDENTIAL: 2,
            SecurityLevel.SECRET: 3,
        }

        user_level = security_hierarchy.get(auth_context.security_level, 0)
        rule_level = security_hierarchy.get(rule.security_level, 0)

        # Apply masking if user level is less than or equal to rule level
        return user_level <= rule_level

    async def _apply_masking_rules(
        self, column: str, value: Any, rules: list[MaskingRule]
    ) -> Any:
        """Apply masking rules"""
        if value is None:
            return value

        for rule in rules:
            if re.match(rule.column_pattern, column, re.IGNORECASE):
                algorithm = self.masking_algorithms.get(rule.algorithm)
                if algorithm:
                    return algorithm(str(value), rule.parameters)

        return value

    def _mask_phone(self, value: str, params: dict[str, Any]) -> str:
        """Phone number masking"""
        if len(value) < 7:
            return value

        mask_char = params.get("mask_char", "*")
        keep_prefix = params.get("keep_prefix", 3)
        keep_suffix = params.get("keep_suffix", 4)

        if len(value) <= keep_prefix + keep_suffix:
            return mask_char * len(value)

        prefix = value[:keep_prefix]
        suffix = value[-keep_suffix:]
        middle_length = len(value) - keep_prefix - keep_suffix

        return prefix + mask_char * middle_length + suffix

    def _mask_email(self, value: str, params: dict[str, Any]) -> str:
        """Email masking"""
        if "@" not in value:
            return value

        mask_char = params.get("mask_char", "*")
        local, domain = value.split("@", 1)
_get_applicable_rules(self, auth_context: AuthContext) -> list[MaskingRule]: """Get applicable masking rules""" applicable_rules = [] for rule in self.masking_rules: # Decide whether to apply masking rules based on user security level if self._should_apply_rule(rule, auth_context): applicable_rules.append(rule) return applicable_rules def _should_apply_rule(self, rule: MaskingRule, auth_context: AuthContext) -> bool: """Determine whether masking rule should be applied""" # Admin users can see original data if "admin" in auth_context.roles: return False # Decide based on security level security_hierarchy = { SecurityLevel.PUBLIC: 0, SecurityLevel.INTERNAL: 1, SecurityLevel.CONFIDENTIAL: 2, SecurityLevel.SECRET: 3, } user_level = security_hierarchy.get(auth_context.security_level, 0) rule_level = security_hierarchy.get(rule.security_level, 0) # Apply masking if user level is less than or equal to rule level return user_level <= rule_level async def _apply_masking_rules( self, column: str, value: Any, rules: list[MaskingRule] ) -> Any: """Apply masking rules""" if value is None: return value for rule in rules: if re.match(rule.column_pattern, column, re.IGNORECASE): algorithm = self.masking_algorithms.get(rule.algorithm) if algorithm: return algorithm(str(value), rule.parameters) return value def _mask_phone(self, value: str, params: dict[str, Any]) -> str: """Phone number masking""" if len(value) < 7: return value mask_char = params.get("mask_char", "*") keep_prefix = params.get("keep_prefix", 3) keep_suffix = params.get("keep_suffix", 4) if len(value) <= keep_prefix + keep_suffix: return mask_char * len(value) prefix = value[:keep_prefix] suffix = value[-keep_suffix:] middle_length = len(value) - keep_prefix - keep_suffix return prefix + mask_char * middle_length + suffix def _mask_email(self, value: str, params: dict[str, Any]) -> str: """Email masking""" if "@" not in value: return value mask_char = params.get("mask_char", "*") local, domain = value.split("@", 1) 
if len(local) <= 2: masked_local = mask_char * len(local) else: masked_local = local[0] + mask_char * (len(local) - 2) + local[-1] return f"{masked_local}@{domain}" def _mask_id_card(self, value: str, params: dict[str, Any]) -> str: """ID card number masking""" if len(value) < 10: return value mask_char = params.get("mask_char", "*") keep_prefix = params.get("keep_prefix", 6) keep_suffix = params.get("keep_suffix", 4) if len(value) <= keep_prefix + keep_suffix: return mask_char * len(value) prefix = value[:keep_prefix] suffix = value[-keep_suffix:] middle_length = len(value) - keep_prefix - keep_suffix return prefix + mask_char * middle_length + suffix def _mask_name(self, value: str, params: dict[str, Any]) -> str: """Name masking""" if len(value) <= 1: return value mask_char = params.get("mask_char", "*") if len(value) == 2: return value[0] + mask_char else: return value[0] + mask_char * (len(value) - 2) + value[-1] def _mask_partial(self, value: str, params: dict[str, Any]) -> str: """Partial masking""" mask_char = params.get("mask_char", "*") mask_ratio = params.get("mask_ratio", 0.5) mask_length = int(len(value) * mask_ratio) start_pos = (len(value) - mask_length) // 2 result = list(value) for i in range(start_pos, start_pos + mask_length): if i < len(result): result[i] = mask_char return "".join(result)

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/apache/doris-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.