Skip to main content
Glama
security.py21.6 kB
""" 安全模块 提供输入验证、权限控制、数据保护和安全审计功能 """ import os import re import hashlib import hmac import secrets import json import time import ipaddress from typing import Dict, Any, Optional, List, Union, Callable, Set from dataclasses import dataclass, asdict from enum import Enum from pathlib import Path import logging from datetime import datetime, timedelta from .compatibility import get_system_info class SecurityLevel(Enum): """安全级别""" LOW = "low" MEDIUM = "medium" HIGH = "high" CRITICAL = "critical" class ThreatType(Enum): """威胁类型""" PATH_TRAVERSAL = "path_traversal" INJECTION = "injection" XSS = "xss" CSRF = "csrf" DOS = "dos" BRUTE_FORCE = "brute_force" UNAUTHORIZED_ACCESS = "unauthorized_access" DATA_EXFILTRATION = "data_exfiltration" @dataclass class SecurityEvent: """安全事件""" event_type: str threat_type: ThreatType description: str severity: SecurityLevel timestamp: float source_ip: Optional[str] = None user_id: Optional[str] = None details: Dict[str, Any] = None blocked: bool = False action_taken: str = "" def __post_init__(self): if self.details is None: self.details = {} if self.timestamp == 0: self.timestamp = time.time() @dataclass class SecurityPolicy: """安全策略""" name: str enabled: bool = True level: SecurityLevel = SecurityLevel.MEDIUM max_file_size: int = 100 * 1024 * 1024 # 100MB allowed_extensions: Set[str] = None blocked_extensions: Set[str] = None max_path_length: int = 4096 rate_limit_requests: int = 1000 rate_limit_window: int = 60 # seconds allowed_ips: Set[str] = None blocked_ips: Set[str] = None require_authentication: bool = False audit_logging: bool = True def __post_init__(self): if self.allowed_extensions is None: self.allowed_extensions = {'.py', '.js', '.md', '.json', '.yaml', '.yml', '.txt'} if self.blocked_extensions is None: self.blocked_extensions = {'.exe', '.bat', '.cmd', '.sh', '.ps1', '.scr', '.vbs', '.jsf'} if self.allowed_ips is None: self.allowed_ips = set() if self.blocked_ips is None: self.blocked_ips = set() class InputValidator: """输入验证器""" def __init__(self, policy: SecurityPolicy): self.policy = policy self.logger = logging.getLogger(__name__) def validate_path(self, path: str, base_path: Optional[str] = None) -> bool: """验证路径安全""" if not path or len(path) > self.policy.max_path_length: return False # 检查路径遍历攻击 if '..' in path: return False # 检查非法字符 illegal_chars = ['<', '>', ':', '"', '|', '?', '*'] for char in illegal_chars: if char in path: return False # 如果指定了基础路径,检查是否在基础路径内 if base_path: try: full_path = Path(base_path) / path base = Path(base_path).resolve() resolved = full_path.resolve() if not str(resolved).startswith(str(base)): return False except (ValueError, OSError): return False return True def validate_file_extension(self, filename: str) -> bool: """验证文件扩展名""" ext = Path(filename).suffix.lower() # 检查阻止的扩展名 if ext in self.policy.blocked_extensions: return False # 如果有允许的扩展名列表,检查是否在允许列表中 if self.policy.allowed_extensions and ext not in self.policy.allowed_extensions: return False return True def validate_file_size(self, size: int) -> bool: """验证文件大小""" return size <= self.policy.max_file_size def validate_input_text(self, text: str, max_length: int = 10000) -> bool: """验证输入文本""" if not text or len(text) > max_length: return False # 检查潜在的注入攻击 dangerous_patterns = [ r'<script[^>]*>.*?</script>', # XSS r'javascript:', # JavaScript协议 r'on\w+\s*=', # 事件处理器 r'\b(exec|eval|system|shell|open)\s*\(', # 代码执行 r'\$\(.+?\)', # 命令替换 ] for pattern in dangerous_patterns: if re.search(pattern, text, re.IGNORECASE): return False return True def sanitize_input(self, text: str) -> str: """清理输入文本""" if not text: return "" # 移除潜在的HTML标签 text = re.sub(r'<[^>]+>', '', text) # 移除危险字符 dangerous_chars = ['\x00', '\r', '\n', '\t', '<', '>', '&', '"', "'", '\\'] for char in dangerous_chars: text = text.replace(char, '') return text.strip() class AccessController: """访问控制器""" def __init__(self, policy: SecurityPolicy): self.policy = policy self.access_log: List[Dict[str, Any]] = [] self.failed_attempts: Dict[str, int] = {} self.blocked_ips: Set[str] = set(policy.blocked_ips) self.logger = logging.getLogger(__name__) def check_ip_access(self, ip_address: str) -> bool: """检查IP访问权限""" # 检查是否在阻止列表中 if ip_address in self.blocked_ips: return False # 如果有允许列表,检查是否在允许列表中 if self.policy.allowed_ips and ip_address not in self.policy.allowed_ips: return False return True def check_rate_limit(self, ip_address: str) -> bool: """检查速率限制""" current_time = time.time() window_start = current_time - self.policy.rate_limit_window # 清理过期的尝试记录 self.access_log = [ entry for entry in self.access_log if entry.get('timestamp', 0) > window_start ] # 计算当前窗口内的请求数 recent_requests = sum( 1 for entry in self.access_log if entry.get('ip') == ip_address ) return recent_requests < self.policy.rate_limit_requests def record_access(self, ip_address: str, user_id: Optional[str] = None, success: bool = True): """记录访问""" self.access_log.append({ 'ip': ip_address, 'user_id': user_id, 'timestamp': time.time(), 'success': success }) if not success: self.failed_attempts[ip_address] = self.failed_attempts.get(ip_address, 0) + 1 # 如果失败次数过多,阻止IP if self.failed_attempts[ip_address] > 5: self.blocked_ips.add(ip_address) self.logger.warning(f"IP {ip_address} 因多次失败尝试被阻止") def check_authentication(self, token: str) -> bool: """检查身份验证""" if not self.policy.require_authentication: return True # 简单的令牌验证(实际应用中应使用更安全的方法) if not token: return False # 这里可以实现JWT验证或其他令牌验证逻辑 # 为了演示,我们使用简单的哈希检查 expected_hash = hashlib.sha256(b"secret_token").hexdigest() return hmac.compare_digest( hashlib.sha256(token.encode()).hexdigest(), expected_hash ) class DataProtector: """数据保护器""" def __init__(self, policy: SecurityPolicy): self.policy = policy self.encryption_key = secrets.token_bytes(32) self.logger = logging.getLogger(__name__) def encrypt_data(self, data: str) -> str: """加密数据""" try: import cryptography.fernet fernet = cryptography.fernet.Fernet( hashlib.sha256(self.encryption_key).digest() ) return fernet.encrypt(data.encode()).decode() except ImportError: # 如果没有cryptography包,使用简单的编码 self.logger.warning("cryptography包未安装,使用简单编码代替") import base64 return base64.b64encode(data.encode()).decode() def decrypt_data(self, encrypted_data: str) -> str: """解密数据""" try: import cryptography.fernet fernet = cryptography.fernet.Fernet( hashlib.sha256(self.encryption_key).digest() ) return fernet.decrypt(encrypted_data.encode()).decode() except ImportError: # 如果没有cryptography包,使用简单的解码 import base64 return base64.b64decode(encrypted_data.encode()).decode() def hash_data(self, data: str, salt: Optional[str] = None) -> str: """哈希数据""" if salt is None: salt = secrets.token_hex(16) return hashlib.pbkdf2_hmac( 'sha256', data.encode(), salt.encode(), 100000 ).hex() def generate_token(self, expires_in: int = 3600) -> str: """生成令牌""" payload = { 'exp': time.time() + expires_in, 'iat': time.time(), 'jti': secrets.token_urlsafe(32) } # 简单的令牌生成(实际应用中应使用JWT) token_data = json.dumps(payload).encode() signature = hmac.new( self.encryption_key, token_data, hashlib.sha256 ).hexdigest() return f"{signature}.{secrets.token_urlsafe(32)}" class SecurityAuditor: """安全审计器""" def __init__(self, policy: SecurityPolicy): self.policy = policy self.security_events: List[SecurityEvent] = [] self.logger = logging.getLogger(__name__) def log_security_event(self, event: SecurityEvent) -> None: """记录安全事件""" self.security_events.append(event) if self.policy.audit_logging: log_level = { SecurityLevel.LOW: logging.INFO, SecurityLevel.MEDIUM: logging.WARNING, SecurityLevel.HIGH: logging.ERROR, SecurityLevel.CRITICAL: logging.CRITICAL }.get(event.severity, logging.INFO) self.logger.log( log_level, f"安全事件: {event.event_type} - {event.description}", extra={ 'threat_type': event.threat_type.value, 'source_ip': event.source_ip, 'user_id': event.user_id, 'details': event.details } ) def detect_threats(self, data: Dict[str, Any]) -> List[SecurityEvent]: """检测威胁""" events = [] # 检测路径遍历 if 'path' in data: path = data['path'] if '..' in path or path.startswith('/'): events.append(SecurityEvent( event_type="path_traversal_attempt", threat_type=ThreatType.PATH_TRAVERSAL, description=f"检测到路径遍历尝试: {path}", severity=SecurityLevel.HIGH, source_ip=data.get('source_ip'), user_id=data.get('user_id'), details={'path': path}, blocked=True, action_taken="blocked" )) # 检测注入攻击 if 'input' in data: input_text = data['input'] injection_patterns = [ r'\b(exec|eval|system|shell)\s*\(', r';\s*(rm|del|format)\s+', r'<script[^>]*>', r'javascript:', ] for pattern in injection_patterns: if re.search(pattern, input_text, re.IGNORECASE): events.append(SecurityEvent( event_type="injection_attempt", threat_type=ThreatType.INJECTION, description=f"检测到注入尝试: {input_text[:50]}...", severity=SecurityLevel.CRITICAL, source_ip=data.get('source_ip'), user_id=data.get('user_id'), details={'input': input_text, 'pattern': pattern}, blocked=True, action_taken="blocked" )) return events def generate_security_report(self) -> Dict[str, Any]: """生成安全报告""" # 按威胁类型分组 threat_counts = {} for event in self.security_events: threat_type = event.threat_type.value threat_counts[threat_type] = threat_counts.get(threat_type, 0) + 1 # 按严重程度分组 severity_counts = {} for event in self.security_events: severity = event.severity.value severity_counts[severity] = severity_counts.get(severity, 0) + 1 # 最近事件 recent_events = sorted( self.security_events, key=lambda e: e.timestamp, reverse=True )[:10] return { 'total_events': len(self.security_events), 'threat_distribution': threat_counts, 'severity_distribution': severity_counts, 'recent_events': [ { 'event_type': e.event_type, 'threat_type': e.threat_type.value, 'description': e.description, 'severity': e.severity.value, 'timestamp': e.timestamp, 'blocked': e.blocked } for e in recent_events ], 'generated_at': datetime.now().isoformat() } class SecurityManager: """安全管理器""" def __init__(self, policy: Optional[SecurityPolicy] = None): self.policy = policy or SecurityPolicy("default") self.validator = InputValidator(self.policy) self.access_controller = AccessController(self.policy) self.data_protector = DataProtector(self.policy) self.auditor = SecurityAuditor(self.policy) self.logger = logging.getLogger(__name__) def validate_request(self, request_data: Dict[str, Any]) -> List[SecurityEvent]: """验证请求""" events = [] # 检测威胁 events.extend(self.auditor.detect_threats(request_data)) # 验证路径 if 'path' in request_data: if not self.validator.validate_path( request_data['path'], request_data.get('base_path') ): events.append(SecurityEvent( event_type="invalid_path", threat_type=ThreatType.PATH_TRAVERSAL, description="无效的路径格式", severity=SecurityLevel.MEDIUM, source_ip=request_data.get('source_ip'), user_id=request_data.get('user_id'), details={'path': request_data['path']}, blocked=True, action_taken="blocked" )) # 验证文件 if 'filename' in request_data: if not self.validator.validate_file_extension(request_data['filename']): events.append(SecurityEvent( event_type="invalid_file_type", threat_type=ThreatType.DATA_EXFILTRATION, description="不允许的文件类型", severity=SecurityLevel.MEDIUM, source_ip=request_data.get('source_ip'), user_id=request_data.get('user_id'), details={'filename': request_data['filename']}, blocked=True, action_taken="blocked" )) # 验证文件大小 if 'file_size' in request_data: if not self.validator.validate_file_size(request_data['file_size']): events.append(SecurityEvent( event_type="file_too_large", threat_type=ThreatType.DOS, description="文件大小超出限制", severity=SecurityLevel.MEDIUM, source_ip=request_data.get('source_ip'), user_id=request_data.get('user_id'), details={'file_size': request_data['file_size']}, blocked=True, action_taken="blocked" )) # 记录事件 for event in events: self.auditor.log_security_event(event) return events def check_access(self, ip_address: str, user_id: Optional[str] = None, token: Optional[str] = None) -> bool: """检查访问权限""" # 检查IP访问权限 if not self.access_controller.check_ip_access(ip_address): self.auditor.log_security_event(SecurityEvent( event_type="ip_blocked", threat_type=ThreatType.UNAUTHORIZED_ACCESS, description=f"IP访问被拒绝: {ip_address}", severity=SecurityLevel.MEDIUM, source_ip=ip_address, user_id=user_id, blocked=True, action_taken="blocked" )) return False # 检查速率限制 if not self.access_controller.check_rate_limit(ip_address): self.auditor.log_security_event(SecurityEvent( event_type="rate_limit_exceeded", threat_type=ThreatType.DOS, description=f"超过速率限制: {ip_address}", severity=SecurityLevel.HIGH, source_ip=ip_address, user_id=user_id, blocked=True, action_taken="blocked" )) return False # 检查身份验证 if not self.access_controller.check_authentication(token): self.auditor.log_security_event(SecurityEvent( event_type="authentication_failed", threat_type=ThreatType.UNAUTHORIZED_ACCESS, description="身份验证失败", severity=SecurityLevel.HIGH, source_ip=ip_address, user_id=user_id, blocked=True, action_taken="blocked" )) return False # 记录成功访问 self.access_controller.record_access(ip_address, user_id, True) return True def protect_data(self, data: str, operation: str = "encrypt") -> str: """保护数据""" if operation == "encrypt": return self.data_protector.encrypt_data(data) elif operation == "decrypt": return self.data_protector.decrypt_data(data) elif operation == "hash": return self.data_protector.hash_data(data) else: raise ValueError(f"不支持的操作: {operation}") def generate_token(self, expires_in: int = 3600) -> str: """生成安全令牌""" return self.data_protector.generate_token(expires_in) def get_security_status(self) -> Dict[str, Any]: """获取安全状态""" return { 'policy': asdict(self.policy), 'security_report': self.auditor.generate_security_report(), 'blocked_ips': list(self.access_controller.blocked_ips), 'system_info': { 'platform': get_system_info().platform.value, 'current_time': datetime.now().isoformat() } } # 全局安全管理器实例 global_security_manager = SecurityManager() def validate_request(request_data: Dict[str, Any]) -> List[SecurityEvent]: """验证请求便捷函数""" return global_security_manager.validate_request(request_data) def check_access(ip_address: str, user_id: Optional[str] = None, token: Optional[str] = None) -> bool: """检查访问权限便捷函数""" return global_security_manager.check_access(ip_address, user_id, token) def protect_data(data: str, operation: str = "encrypt") -> str: """保护数据便捷函数""" return global_security_manager.protect_data(data, operation) def generate_security_token(expires_in: int = 3600) -> str: """生成安全令牌便捷函数""" return global_security_manager.generate_token(expires_in) def get_security_status() -> Dict[str, Any]: """获取安全状态便捷函数""" return global_security_manager.get_security_status()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/kscz0000/Zhiwen-Assistant-MCP'

If you have feedback or need assistance with the MCP directory API, please join our Discord server