"""
安全模块
提供输入验证、权限控制、数据保护和安全审计功能
"""
import os
import re
import hashlib
import hmac
import secrets
import json
import time
import ipaddress
from typing import Dict, Any, Optional, List, Union, Callable, Set
from dataclasses import dataclass, asdict
from enum import Enum
from pathlib import Path
import logging
from datetime import datetime, timedelta
from .compatibility import get_system_info
class SecurityLevel(Enum):
    """Security severity level, declared from least to most severe.

    SecurityAuditor.log_security_event maps these, in order, to the logging
    levels INFO, WARNING, ERROR and CRITICAL.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class ThreatType(Enum):
    """Category of threat attached to a SecurityEvent."""

    PATH_TRAVERSAL = "path_traversal"  # escaping an allowed directory
    INJECTION = "injection"  # code / command / script injection
    XSS = "xss"  # cross-site scripting
    CSRF = "csrf"  # cross-site request forgery
    DOS = "dos"  # denial of service (e.g. rate limit, oversized file)
    BRUTE_FORCE = "brute_force"  # repeated credential guessing
    UNAUTHORIZED_ACCESS = "unauthorized_access"  # blocked IP / failed auth
    DATA_EXFILTRATION = "data_exfiltration"  # disallowed data / file types
@dataclass
class SecurityEvent:
    """Security event: one audited occurrence.

    ``timestamp`` defaults to 0 and ``details`` to None; ``__post_init__``
    replaces them with the current ``time.time()`` and a fresh empty dict.
    """
    event_type: str  # short tag, e.g. "injection_attempt"
    threat_type: "ThreatType"
    description: str
    severity: "SecurityLevel"
    # Defaults to 0 so callers may omit it (every constructor call in this
    # module does); __post_init__ fills in the current time. Without a
    # default, omitting it raises TypeError.
    timestamp: float = 0.0
    source_ip: Optional[str] = None
    user_id: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
    blocked: bool = False
    action_taken: str = ""

    def __post_init__(self):
        # Fresh dict per instance (avoids a shared mutable default).
        if self.details is None:
            self.details = {}
        if self.timestamp == 0:
            self.timestamp = time.time()
@dataclass
class SecurityPolicy:
    """Security policy: a named bundle of limits and switches.

    SecurityManager hands one instance to the validator, access controller,
    data protector and auditor. The four set-valued fields default to
    ``None`` and are replaced in ``__post_init__`` with fresh sets, so no
    mutable default is shared between instances.
    """
    name: str
    enabled: bool = True
    level: SecurityLevel = SecurityLevel.MEDIUM
    max_file_size: int = 100 * 1024 * 1024  # 100 MiB
    allowed_extensions: Set[str] = None  # None -> default allow-list
    blocked_extensions: Set[str] = None  # None -> default deny-list
    max_path_length: int = 4096  # maximum len() of a path string
    rate_limit_requests: int = 1000  # max recorded requests per IP per window
    rate_limit_window: int = 60  # window length in seconds
    allowed_ips: Set[str] = None  # None -> empty set (no allow-list)
    blocked_ips: Set[str] = None  # None -> empty set
    require_authentication: bool = False
    audit_logging: bool = True

    def __post_init__(self):
        # (field name, factory) pairs, applied in declaration order.
        fallbacks = (
            ('allowed_extensions',
             lambda: {'.py', '.js', '.md', '.json', '.yaml', '.yml', '.txt'}),
            # NOTE(review): '.jsf' may be a typo (e.g. for '.jse' or '.wsf') — confirm.
            ('blocked_extensions',
             lambda: {'.exe', '.bat', '.cmd', '.sh', '.ps1', '.scr', '.vbs', '.jsf'}),
            ('allowed_ips', set),
            ('blocked_ips', set),
        )
        for field_name, factory in fallbacks:
            if getattr(self, field_name) is None:
                setattr(self, field_name, factory())
class InputValidator:
    """Input validator.

    Checks paths, file names, file sizes and free-form text against the
    limits configured in a SecurityPolicy. Also offers a best-effort text
    sanitizer.
    """

    # Characters never accepted in a path. '\x00' is included because an
    # embedded NUL byte can truncate the path at the OS level.
    _ILLEGAL_PATH_CHARS = ('<', '>', ':', '"', '|', '?', '*', '\x00')

    # Patterns marking text as a likely injection attempt, compiled once.
    # DOTALL lets '.*?' span newlines, so a multi-line <script>...</script>
    # block is caught too.
    _DANGEROUS_TEXT_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE | re.DOTALL)
        for pattern in (
            r'<script[^>]*>.*?</script>',  # XSS
            r'javascript:',  # javascript: URL scheme
            r'on\w+\s*=',  # inline event handler, e.g. onclick=
            r'\b(exec|eval|system|shell|open)\s*\(',  # code-execution calls
            r'\$\(.+?\)',  # shell command substitution
        )
    )

    # Deletion table for sanitize_input(): NUL, CR, LF, TAB, < > & " ' \
    _SANITIZE_TABLE = str.maketrans('', '', '\x00\r\n\t<>&"\'\\')

    def __init__(self, policy: "SecurityPolicy"):
        self.policy = policy
        self.logger = logging.getLogger(__name__)

    def validate_path(self, path: str, base_path: Optional[str] = None) -> bool:
        """Return True if ``path`` looks safe to use.

        The following paths are rejected:

        - empty paths
        - paths longer than ``policy.max_path_length``
        - any path containing '..'
        - paths containing a character from ``_ILLEGAL_PATH_CHARS``

        When ``base_path`` is given, the resolved path must also lie inside
        the resolved base directory.
        """
        if not path or len(path) > self.policy.max_path_length:
            return False
        # Blanket rejection of '..' (also rejects harmless names like 'a..b').
        if '..' in path:
            return False
        if any(char in path for char in self._ILLEGAL_PATH_CHARS):
            return False
        if base_path:
            try:
                base = Path(base_path).resolve()
                resolved = (Path(base_path) / path).resolve()
                # relative_to() compares whole path components and raises
                # ValueError when `resolved` is outside `base`. A string-prefix
                # test would wrongly accept a sibling such as '/data/app_evil'
                # for base '/data/app'. Absolute paths or symlinks can point
                # to such a sibling.
                resolved.relative_to(base)
            except (ValueError, OSError, RuntimeError):
                # RuntimeError: symlink loop during resolve() on older Pythons.
                return False
        return True

    def validate_file_extension(self, filename: str) -> bool:
        """Return True if the (lower-cased) suffix of ``filename`` is permitted.

        The deny-list always wins. A non-empty allow-list additionally
        requires the suffix to be listed in it.
        """
        ext = Path(filename).suffix.lower()
        if ext in self.policy.blocked_extensions:
            return False
        if self.policy.allowed_extensions and ext not in self.policy.allowed_extensions:
            return False
        return True

    def validate_file_size(self, size: int) -> bool:
        """Return True if ``size`` does not exceed ``policy.max_file_size``."""
        return size <= self.policy.max_file_size

    def validate_input_text(self, text: str, max_length: int = 10000) -> bool:
        """Return True if ``text`` is non-empty, at most ``max_length`` long,
        and matches none of the dangerous injection patterns."""
        if not text or len(text) > max_length:
            return False
        return not any(p.search(text) for p in self._DANGEROUS_TEXT_PATTERNS)

    def sanitize_input(self, text: str) -> str:
        """Strip HTML-like tags and dangerous characters from ``text``.

        Returns "" for empty/None input. The result is stripped of
        surrounding whitespace.
        """
        if not text:
            return ""
        text = re.sub(r'<[^>]+>', '', text)
        return text.translate(self._SANITIZE_TABLE).strip()
class AccessController:
    """Access controller.

    Provides:

    - IP allow/deny lists
    - sliding-window rate limiting over recorded accesses
    - per-IP failed-attempt counting with automatic blocking
    - a token check
    """

    def __init__(self, policy: SecurityPolicy):
        self.policy = policy
        # Records of the form {'ip', 'user_id', 'timestamp', 'success'}.
        self.access_log: List[Dict[str, Any]] = []
        # Failed-attempt counter per IP; never reset in this class.
        self.failed_attempts: Dict[str, int] = {}
        # Copied so runtime blocks do not mutate the policy's own set.
        self.blocked_ips: Set[str] = set(policy.blocked_ips)
        self.logger = logging.getLogger(__name__)

    def check_ip_access(self, ip_address: str) -> bool:
        """Return False for blocked IPs, or for IPs missing from a non-empty allow-list."""
        if ip_address in self.blocked_ips:
            return False
        allowlist = self.policy.allowed_ips
        # An empty allow-list means "no restriction".
        return not allowlist or ip_address in allowlist

    def check_rate_limit(self, ip_address: str) -> bool:
        """Prune records older than the window, then return True if this IP
        has fewer than ``policy.rate_limit_requests`` records left."""
        cutoff = time.time() - self.policy.rate_limit_window
        self.access_log = [
            record for record in self.access_log
            if record.get('timestamp', 0) > cutoff
        ]
        hits = len([record for record in self.access_log if record.get('ip') == ip_address])
        return hits < self.policy.rate_limit_requests

    def record_access(self, ip_address: str, user_id: Optional[str] = None, success: bool = True):
        """Append an access record. On failure, bump the IP's counter and block
        the IP once it exceeds 5 failures."""
        self.access_log.append({
            'ip': ip_address,
            'user_id': user_id,
            'timestamp': time.time(),
            'success': success,
        })
        if success:
            return
        failures = self.failed_attempts.get(ip_address, 0) + 1
        self.failed_attempts[ip_address] = failures
        if failures > 5:
            self.blocked_ips.add(ip_address)
            self.logger.warning(f"IP {ip_address} 因多次失败尝试被阻止")

    def check_authentication(self, token: str) -> bool:
        """Return True when authentication is disabled. Otherwise return True
        only if SHA-256(token) equals SHA-256(b"secret_token"), compared in
        constant time."""
        if not self.policy.require_authentication:
            return True
        if not token:
            return False
        # NOTE(review): hard-coded shared secret (demo only) — should come from
        # configuration or be replaced by real token verification (e.g. JWT).
        expected_hash = hashlib.sha256(b"secret_token").hexdigest()
        presented_hash = hashlib.sha256(token.encode()).hexdigest()
        return hmac.compare_digest(presented_hash, expected_hash)
class DataProtector:
    """Data protector: symmetric encryption, salted hashing and token generation."""

    def __init__(self, policy: "SecurityPolicy"):
        self.policy = policy
        # Random key generated per instance and never persisted. Ciphertext
        # and token signatures from one instance cannot be checked by another.
        self.encryption_key = secrets.token_bytes(32)
        self.logger = logging.getLogger(__name__)

    def _derive_fernet_key(self) -> bytes:
        """Return a Fernet-compatible key derived from ``encryption_key``.

        Fernet expects 32 bytes encoded as URL-safe base64 (44 characters).
        Passing the raw 32-byte digest makes ``Fernet()`` raise ValueError.
        """
        import base64
        return base64.urlsafe_b64encode(hashlib.sha256(self.encryption_key).digest())

    def _get_fernet(self):
        """Return a Fernet instance, or None if ``cryptography`` is not installed."""
        try:
            from cryptography.fernet import Fernet
        except ImportError:
            return None
        return Fernet(self._derive_fernet_key())

    def encrypt_data(self, data: str) -> str:
        """Encrypt ``data`` and return the result as text.

        Uses Fernet when ``cryptography`` is available. Otherwise it falls back
        to plain base64, which is only an encoding and provides NO
        confidentiality.
        """
        fernet = self._get_fernet()
        if fernet is None:
            self.logger.warning("cryptography包未安装,使用简单编码代替")
            import base64
            return base64.b64encode(data.encode()).decode()
        return fernet.encrypt(data.encode()).decode()

    def decrypt_data(self, encrypted_data: str) -> str:
        """Reverse encrypt_data(), using the same Fernet / base64 choice."""
        fernet = self._get_fernet()
        if fernet is None:
            self.logger.warning("cryptography包未安装,使用简单解码代替")
            import base64
            return base64.b64decode(encrypted_data.encode()).decode()
        return fernet.decrypt(encrypted_data.encode()).decode()

    def hash_data(self, data: str, salt: Optional[str] = None) -> str:
        """Return the hex PBKDF2-HMAC-SHA256 digest (100000 iterations) of ``data``.

        NOTE(review): when ``salt`` is None a random salt is generated but not
        returned, so the hash cannot be recomputed or verified later. Pass an
        explicit salt (and store it) when verification is needed.
        """
        if salt is None:
            salt = secrets.token_hex(16)
        return hashlib.pbkdf2_hmac('sha256', data.encode(), salt.encode(), 100000).hex()

    def generate_token(self, expires_in: int = 3600) -> str:
        """Return an opaque token: "<hex HMAC-SHA256 of a payload>.<random part>".

        NOTE(review): the signed payload (exp/iat/jti) is not embedded in the
        token, so its expiry cannot be checked from the token itself. A real
        implementation should use a standard format such as JWT.
        """
        now = time.time()  # single clock read so exp - iat == expires_in
        payload = {
            'exp': now + expires_in,
            'iat': now,
            'jti': secrets.token_urlsafe(32),
        }
        token_data = json.dumps(payload).encode()
        signature = hmac.new(self.encryption_key, token_data, hashlib.sha256).hexdigest()
        return f"{signature}.{secrets.token_urlsafe(32)}"
class SecurityAuditor:
    """Security auditor.

    Responsibilities:

    - keep an in-memory list of every SecurityEvent passed to it
    - optionally mirror each event to the module logger
    - run simple pattern-based threat detection on request dicts
    - summarize the recorded events in a report
    """

    # Patterns whose presence in data['input'] is reported as an injection.
    # Each matching pattern yields its own event.
    _INJECTION_PATTERNS = (
        r'\b(exec|eval|system|shell)\s*\(',  # code-execution calls
        r';\s*(rm|del|format)\s+',  # chained destructive commands
        r'<script[^>]*>',  # opening <script> tag
        r'javascript:',  # javascript: URL scheme
    )

    def __init__(self, policy: "SecurityPolicy"):
        self.policy = policy
        # Grows without bound for the lifetime of the auditor.
        self.security_events: List["SecurityEvent"] = []
        self.logger = logging.getLogger(__name__)

    def log_security_event(self, event: "SecurityEvent") -> None:
        """Store ``event``. If ``policy.audit_logging`` is on, also log it.

        The log level is chosen from the event's severity:
        LOW -> INFO, MEDIUM -> WARNING, HIGH -> ERROR, CRITICAL -> CRITICAL,
        anything else -> INFO.
        """
        self.security_events.append(event)
        if self.policy.audit_logging:
            log_level = {
                SecurityLevel.LOW: logging.INFO,
                SecurityLevel.MEDIUM: logging.WARNING,
                SecurityLevel.HIGH: logging.ERROR,
                SecurityLevel.CRITICAL: logging.CRITICAL
            }.get(event.severity, logging.INFO)
            self.logger.log(
                log_level,
                f"安全事件: {event.event_type} - {event.description}",
                extra={
                    'threat_type': event.threat_type.value,
                    'source_ip': event.source_ip,
                    'user_id': event.user_id,
                    'details': event.details
                }
            )

    def detect_threats(self, data: Dict[str, Any]) -> List["SecurityEvent"]:
        """Return blocked events for suspicious 'path' / 'input' values in ``data``.

        Does not record the events. Each event is stamped explicitly with
        ``time.time()``, so construction does not depend on SecurityEvent
        supplying a default timestamp.
        """
        events = []
        source_ip = data.get('source_ip')
        user_id = data.get('user_id')

        if 'path' in data:
            path = data['path']
            # Parent-directory segments or a POSIX absolute path.
            if '..' in path or path.startswith('/'):
                events.append(SecurityEvent(
                    event_type="path_traversal_attempt",
                    threat_type=ThreatType.PATH_TRAVERSAL,
                    description=f"检测到路径遍历尝试: {path}",
                    severity=SecurityLevel.HIGH,
                    timestamp=time.time(),
                    source_ip=source_ip,
                    user_id=user_id,
                    details={'path': path},
                    blocked=True,
                    action_taken="blocked"
                ))

        if 'input' in data:
            input_text = data['input']
            for pattern in self._INJECTION_PATTERNS:
                if re.search(pattern, input_text, re.IGNORECASE):
                    events.append(SecurityEvent(
                        event_type="injection_attempt",
                        threat_type=ThreatType.INJECTION,
                        description=f"检测到注入尝试: {input_text[:50]}...",
                        severity=SecurityLevel.CRITICAL,
                        timestamp=time.time(),
                        source_ip=source_ip,
                        user_id=user_id,
                        details={'input': input_text, 'pattern': pattern},
                        blocked=True,
                        action_taken="blocked"
                    ))

        return events

    def generate_security_report(self) -> Dict[str, Any]:
        """Summarize recorded events.

        The report contains:

        - the total count
        - per-threat-type counts
        - per-severity counts
        - the 10 most recent events (newest first)
        - a generation time
        """
        threat_counts: Dict[str, int] = {}
        severity_counts: Dict[str, int] = {}
        for event in self.security_events:
            threat = event.threat_type.value
            threat_counts[threat] = threat_counts.get(threat, 0) + 1
        for event in self.security_events:
            severity = event.severity.value
            severity_counts[severity] = severity_counts.get(severity, 0) + 1

        recent_events = sorted(
            self.security_events,
            key=lambda e: e.timestamp,
            reverse=True
        )[:10]

        return {
            'total_events': len(self.security_events),
            'threat_distribution': threat_counts,
            'severity_distribution': severity_counts,
            'recent_events': [
                {
                    'event_type': e.event_type,
                    'threat_type': e.threat_type.value,
                    'description': e.description,
                    'severity': e.severity.value,
                    'timestamp': e.timestamp,
                    'blocked': e.blocked
                }
                for e in recent_events
            ],
            'generated_at': datetime.now().isoformat()
        }
class SecurityManager:
    """Security manager.

    Facade that builds an InputValidator, AccessController, DataProtector and
    SecurityAuditor around one SecurityPolicy. It exposes request validation,
    access checks, data protection and status reporting.
    """

    def __init__(self, policy: Optional[SecurityPolicy] = None):
        # A falsy/None policy is replaced by a fresh policy named "default".
        self.policy = policy or SecurityPolicy("default")
        self.validator = InputValidator(self.policy)
        self.access_controller = AccessController(self.policy)
        self.data_protector = DataProtector(self.policy)
        self.auditor = SecurityAuditor(self.policy)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _blocked_event(
        event_type: str,
        threat_type: ThreatType,
        description: str,
        severity: SecurityLevel,
        source_ip: Optional[str],
        user_id: Optional[str],
        details: Optional[Dict[str, Any]] = None,
    ) -> SecurityEvent:
        """Build a SecurityEvent marked blocked and stamped with the current time.

        ``timestamp`` is passed explicitly so construction never depends on
        SecurityEvent supplying a default for it.
        """
        return SecurityEvent(
            event_type=event_type,
            threat_type=threat_type,
            description=description,
            severity=severity,
            timestamp=time.time(),
            source_ip=source_ip,
            user_id=user_id,
            details=details,
            blocked=True,
            action_taken="blocked",
        )

    def validate_request(self, request_data: Dict[str, Any]) -> List[SecurityEvent]:
        """Validate a request dict and return every security event it triggers.

        Steps:

        1. Run the auditor's pattern-based threat detection.
        2. Check the optional 'path' key (against 'base_path').
        3. Check the optional 'filename' and 'file_size' keys.

        Every resulting event is logged through the auditor.
        """
        events: List[SecurityEvent] = list(self.auditor.detect_threats(request_data))
        source_ip = request_data.get('source_ip')
        user_id = request_data.get('user_id')

        if 'path' in request_data and not self.validator.validate_path(
            request_data['path'], request_data.get('base_path')
        ):
            events.append(self._blocked_event(
                "invalid_path", ThreatType.PATH_TRAVERSAL, "无效的路径格式",
                SecurityLevel.MEDIUM, source_ip, user_id,
                {'path': request_data['path']},
            ))

        if 'filename' in request_data and not self.validator.validate_file_extension(
            request_data['filename']
        ):
            events.append(self._blocked_event(
                "invalid_file_type", ThreatType.DATA_EXFILTRATION, "不允许的文件类型",
                SecurityLevel.MEDIUM, source_ip, user_id,
                {'filename': request_data['filename']},
            ))

        if 'file_size' in request_data and not self.validator.validate_file_size(
            request_data['file_size']
        ):
            events.append(self._blocked_event(
                "file_too_large", ThreatType.DOS, "文件大小超出限制",
                SecurityLevel.MEDIUM, source_ip, user_id,
                {'file_size': request_data['file_size']},
            ))

        for event in events:
            self.auditor.log_security_event(event)
        return events

    def check_access(self, ip_address: str, user_id: Optional[str] = None, token: Optional[str] = None) -> bool:
        """Return True if ``ip_address`` passes the IP lists, the rate limit and
        authentication, in that order.

        Each denial is logged as a security event. An authentication failure
        is also recorded as a failed access, so the AccessController can block
        IPs after repeated failures. A successful check is recorded as an
        access, which counts toward the rate limit.
        """
        if not self.access_controller.check_ip_access(ip_address):
            self.auditor.log_security_event(self._blocked_event(
                "ip_blocked", ThreatType.UNAUTHORIZED_ACCESS,
                f"IP访问被拒绝: {ip_address}", SecurityLevel.MEDIUM,
                ip_address, user_id,
            ))
            return False

        if not self.access_controller.check_rate_limit(ip_address):
            self.auditor.log_security_event(self._blocked_event(
                "rate_limit_exceeded", ThreatType.DOS,
                f"超过速率限制: {ip_address}", SecurityLevel.HIGH,
                ip_address, user_id,
            ))
            return False

        if not self.access_controller.check_authentication(token):
            self.auditor.log_security_event(self._blocked_event(
                "authentication_failed", ThreatType.UNAUTHORIZED_ACCESS,
                "身份验证失败", SecurityLevel.HIGH,
                ip_address, user_id,
            ))
            # Report the failure. Without this, record_access(..., False) is
            # never called and the repeated-failure IP blocking can't trigger.
            self.access_controller.record_access(ip_address, user_id, False)
            return False

        self.access_controller.record_access(ip_address, user_id, True)
        return True

    def protect_data(self, data: str, operation: str = "encrypt") -> str:
        """Apply ``operation`` ("encrypt", "decrypt" or "hash") to ``data``.

        Raises:
            ValueError: for any other operation name.
        """
        if operation == "encrypt":
            return self.data_protector.encrypt_data(data)
        elif operation == "decrypt":
            return self.data_protector.decrypt_data(data)
        elif operation == "hash":
            return self.data_protector.hash_data(data)
        else:
            raise ValueError(f"不支持的操作: {operation}")

    def generate_token(self, expires_in: int = 3600) -> str:
        """Generate a security token via the DataProtector."""
        return self.data_protector.generate_token(expires_in)

    def get_security_status(self) -> Dict[str, Any]:
        """Return the policy (as a dict), an audit report, blocked IPs and system info."""
        return {
            'policy': asdict(self.policy),
            'security_report': self.auditor.generate_security_report(),
            'blocked_ips': list(self.access_controller.blocked_ips),
            'system_info': {
                'platform': get_system_info().platform.value,
                'current_time': datetime.now().isoformat()
            }
        }
# Module-wide SecurityManager using the default policy, created at import
# time; the module-level convenience functions delegate to it.
global_security_manager = SecurityManager(policy=None)
def validate_request(request_data: Dict[str, Any]) -> List[SecurityEvent]:
    """Validate ``request_data`` with the module-wide SecurityManager.

    Returns the list of security events the request triggered.
    """
    manager = global_security_manager
    return manager.validate_request(request_data)
def check_access(ip_address: str, user_id: Optional[str] = None, token: Optional[str] = None) -> bool:
    """Check access through the module-wide SecurityManager; True means allowed."""
    return global_security_manager.check_access(ip_address, user_id=user_id, token=token)
def protect_data(data: str, operation: str = "encrypt") -> str:
    """Encrypt, decrypt or hash ``data`` via the module-wide SecurityManager.

    Unsupported ``operation`` names raise ValueError.
    """
    return global_security_manager.protect_data(data, operation=operation)
def generate_security_token(expires_in: int = 3600) -> str:
    """Generate a token through the module-wide SecurityManager."""
    return global_security_manager.generate_token(expires_in=expires_in)
def get_security_status() -> Dict[str, Any]:
    """Return the module-wide SecurityManager's status snapshot."""
    manager = global_security_manager
    return manager.get_security_status()