"""
安全模块测试用例
"""
import pytest
import os
import tempfile
from pathlib import Path
# 添加源码路径
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
from utils.security import (
SecurityLevel,
ThreatType,
SecurityEvent,
SecurityPolicy,
SecurityManager,
InputValidator,
AccessController,
DataProtector,
SecurityAuditor
)
class TestSecurityPolicy:
    """Tests for SecurityPolicy construction and its default values."""

    _MIB = 1024 * 1024  # bytes per mebibyte, used to express size limits

    def test_security_policy_creation(self):
        """A policy built from just a name picks up the default settings."""
        default_policy = SecurityPolicy("test")
        assert default_policy.name == "test"
        assert default_policy.enabled is True
        assert default_policy.level == SecurityLevel.MEDIUM
        assert default_policy.max_file_size == 100 * self._MIB
        # The default extension lists allow '.py' and block '.exe'.
        assert '.py' in default_policy.allowed_extensions
        assert '.exe' in default_policy.blocked_extensions

    def test_security_policy_custom_values(self):
        """Explicit keyword arguments override the defaults."""
        overrides = dict(
            name="custom",
            enabled=False,
            level=SecurityLevel.HIGH,
            max_file_size=50 * self._MIB,
        )
        custom_policy = SecurityPolicy(**overrides)
        assert custom_policy.enabled is False
        assert custom_policy.level == SecurityLevel.HIGH
        assert custom_policy.max_file_size == 50 * self._MIB
class TestInputValidator:
    """Tests for InputValidator built on a default SecurityPolicy.

    The table-driven tests put the offending input into each assertion
    message. A failure then names the exact case that broke, rather than
    only the shared loop line.
    """

    def setup_method(self):
        """Create a fresh default policy and validator for every test."""
        self.policy = SecurityPolicy("test")
        self.validator = InputValidator(self.policy)

    def test_validate_path_valid(self):
        """Relative paths without traversal segments are accepted."""
        valid_paths = [
            "src/main.py",
            "docs/readme.md",
            "config/templates",
            "src/utils/__init__.py",
        ]
        for path in valid_paths:
            assert self.validator.validate_path(path) is True, (
                f"expected path to be valid: {path!r}"
            )

    def test_validate_path_invalid(self):
        """Traversal paths (POSIX and Windows style), empty and None are rejected."""
        invalid_paths = [
            "../../../etc/passwd",
            "..\\..\\windows\\system32\\config",
            "",
            None,
        ]
        for path in invalid_paths:
            assert self.validator.validate_path(path) is False, (
                f"expected path to be rejected: {path!r}"
            )

    def test_validate_file_extension_allowed(self):
        """Filenames with allowed extensions pass extension validation."""
        allowed_files = [
            "main.py",
            "config.json",
            "readme.md",
            "script.js",
        ]
        for filename in allowed_files:
            assert self.validator.validate_file_extension(filename) is True, (
                f"expected extension to be allowed: {filename!r}"
            )

    def test_validate_file_extension_blocked(self):
        """Filenames with blocked extensions fail extension validation."""
        blocked_files = [
            "malware.exe",
            "script.bat",
            "auto.cmd",
            "install.sh",
        ]
        for filename in blocked_files:
            assert self.validator.validate_file_extension(filename) is False, (
                f"expected extension to be blocked: {filename!r}"
            )

    def test_validate_file_size(self):
        """Sizes under the default limit pass; 200 MiB exceeds it and fails."""
        # Valid sizes.
        assert self.validator.validate_file_size(1024) is True
        assert self.validator.validate_file_size(10 * 1024 * 1024) is True
        # Invalid size (above the limit).
        assert self.validator.validate_file_size(200 * 1024 * 1024) is False

    def test_validate_input_text_valid(self):
        """Plain prose is accepted as input text."""
        valid_texts = [
            "Hello, world!",
            "This is a test message.",
            "Python is awesome",
        ]
        for text in valid_texts:
            assert self.validator.validate_input_text(text) is True, (
                f"expected text to be valid: {text!r}"
            )

    def test_validate_input_text_invalid(self):
        """Script tags, javascript: URLs, inline handlers, eval() and "" are rejected."""
        invalid_texts = [
            "<script>alert('xss')</script>",
            "javascript:void(0)",
            "onclick=alert('xss')",
            "eval('malicious code')",
            "",
        ]
        for text in invalid_texts:
            assert self.validator.validate_input_text(text) is False, (
                f"expected text to be rejected: {text!r}"
            )

    def test_sanitize_input(self):
        """Sanitizing strips the <script> tag but keeps the benign text."""
        dirty_input = "<script>alert('xss')</script>Hello, world!"
        clean_input = self.validator.sanitize_input(dirty_input)
        assert "<script>" not in clean_input
        assert "Hello, world!" in clean_input
class TestAccessController:
    """Tests for AccessController IP blocking, rate limiting and auth checks."""

    def setup_method(self):
        """Create a fresh default policy and controller for every test."""
        self.policy = SecurityPolicy("test")
        self.controller = AccessController(self.policy)

    def test_check_ip_access_allowed(self):
        """An IP that was never blocked is granted access."""
        assert self.controller.check_ip_access("192.168.1.100") is True

    def test_check_ip_access_blocked(self):
        """An IP placed in blocked_ips is denied access."""
        denied = "10.0.0.1"
        self.controller.blocked_ips.add(denied)
        assert self.controller.check_ip_access(denied) is False

    def test_check_rate_limit(self):
        """After rate_limit_requests recorded accesses, the next check fails."""
        client_ip = "192.168.1.100"
        ctrl = self.controller
        # A client with no recorded history is within the limit.
        assert ctrl.check_rate_limit(client_ip) is True
        # Record as many accesses as the policy's request budget.
        budget = self.policy.rate_limit_requests
        for _ in range(budget):
            ctrl.record_access(client_ip)
        # The following check must now be refused.
        assert ctrl.check_rate_limit(client_ip) is False

    def test_check_authentication_disabled(self):
        """With authentication turned off in the policy, any token is accepted."""
        self.policy.require_authentication = False
        assert self.controller.check_authentication("any_token") is True

    def test_record_access_success(self):
        """A successful access leaves no entry in failed_attempts."""
        client_ip = "192.168.1.100"
        self.controller.record_access(client_ip, success=True)
        assert client_ip not in self.controller.failed_attempts

    def test_record_access_failure(self):
        """Repeated failed accesses get the IP added to blocked_ips."""
        client_ip = "192.168.1.100"
        # NOTE(review): 6 is assumed to exceed the controller's lockout
        # threshold; confirm against AccessController.
        failures = 6
        for _ in range(failures):
            self.controller.record_access(client_ip, success=False)
        assert client_ip in self.controller.blocked_ips
class TestDataProtector:
    """Tests for DataProtector encryption, hashing and token generation."""

    def setup_method(self):
        """Create a fresh default policy and protector for every test."""
        self.policy = SecurityPolicy("test")
        self.protector = DataProtector(self.policy)

    def test_encrypt_decrypt_data(self):
        """Encrypting changes the data and decrypting restores it exactly."""
        plaintext = "This is sensitive information"
        ciphertext = self.protector.encrypt_data(plaintext)
        assert ciphertext != plaintext
        assert self.protector.decrypt_data(ciphertext) == plaintext

    def test_hash_data(self):
        """Hashing yields a 64-character digest that is stable across calls."""
        secret = "password123"
        first_digest = self.protector.hash_data(secret)
        assert first_digest != secret
        assert len(first_digest) == 64  # SHA256 hex length
        # The same input must hash to the same digest.
        assert self.protector.hash_data(secret) == first_digest

    def test_generate_token(self):
        """Tokens are non-empty strings, and two successive tokens differ."""
        token = self.protector.generate_token()
        assert isinstance(token, str) and len(token) > 0
        assert self.protector.generate_token() != token
class TestSecurityAuditor:
    """Tests for SecurityAuditor event logging, threat detection and reporting."""

    def setup_method(self):
        """Create a fresh default policy and auditor for every test."""
        self.policy = SecurityPolicy("test")
        self.auditor = SecurityAuditor(self.policy)

    @staticmethod
    def _make_event(event_type, threat_type, description, severity):
        """Build a SecurityEvent from the four fields these tests set."""
        return SecurityEvent(
            event_type=event_type,
            threat_type=threat_type,
            description=description,
            severity=severity,
        )

    def test_log_security_event(self):
        """A logged event is appended to security_events."""
        event = self._make_event(
            "test_event",
            ThreatType.PATH_TRAVERSAL,
            "Test security event",
            SecurityLevel.MEDIUM,
        )
        self.auditor.log_security_event(event)
        logged = self.auditor.security_events
        assert len(logged) == 1
        assert logged[0].event_type == "test_event"

    def test_detect_threats_path_traversal(self):
        """A traversal path produces exactly one blocked PATH_TRAVERSAL event."""
        payload = {'path': '../../../etc/passwd', 'source_ip': '192.168.1.100'}
        detected = self.auditor.detect_threats(payload)
        assert len(detected) == 1
        first = detected[0]
        assert first.threat_type == ThreatType.PATH_TRAVERSAL
        assert first.blocked is True

    def test_detect_threats_injection(self):
        """A <script> payload yields at least one INJECTION event."""
        payload = {'input': '<script>alert("xss")</script>', 'source_ip': '192.168.1.100'}
        detected = self.auditor.detect_threats(payload)
        assert len(detected) >= 1
        assert any(found.threat_type == ThreatType.INJECTION for found in detected)

    def test_generate_security_report(self):
        """The report counts logged events and exposes its summary sections."""
        first_event = self._make_event(
            "test_event1",
            ThreatType.PATH_TRAVERSAL,
            "Test event 1",
            SecurityLevel.MEDIUM,
        )
        second_event = self._make_event(
            "test_event2",
            ThreatType.INJECTION,
            "Test event 2",
            SecurityLevel.HIGH,
        )
        for event in (first_event, second_event):
            self.auditor.log_security_event(event)

        report = self.auditor.generate_security_report()
        for section in ('total_events', 'threat_distribution',
                        'severity_distribution', 'recent_events'):
            assert section in report
        assert report['total_events'] == 2
class TestSecurityManager:
    """Tests for the SecurityManager entry points."""

    def setup_method(self):
        """Create a fresh default policy and manager for every test."""
        self.policy = SecurityPolicy("test")
        self.manager = SecurityManager(self.policy)

    @staticmethod
    def _request(path, filename):
        """Build a request dict with a 1024-byte file from a fixed source IP."""
        return {
            'path': path,
            'filename': filename,
            'file_size': 1024,
            'source_ip': '192.168.1.100',
        }

    def test_validate_request_valid(self):
        """A benign request produces no security events."""
        events = self.manager.validate_request(self._request('src/main.py', 'main.py'))
        assert len(events) == 0

    def test_validate_request_invalid_path(self):
        """A traversal path produces a PATH_TRAVERSAL event."""
        events = self.manager.validate_request(
            self._request('../../../etc/passwd', 'passwd')
        )
        assert len(events) > 0
        assert any(ev.threat_type == ThreatType.PATH_TRAVERSAL for ev in events)

    def test_validate_request_invalid_file(self):
        """A blocked file type is reported as a DATA_EXFILTRATION event."""
        events = self.manager.validate_request(
            self._request('temp/malware.exe', 'malware.exe')
        )
        assert len(events) > 0
        assert any(ev.threat_type == ThreatType.DATA_EXFILTRATION for ev in events)

    def test_check_access_valid(self):
        """An IP that was never blocked passes check_access."""
        assert self.manager.check_access("192.168.1.100") is True

    def test_check_access_blocked_ip(self):
        """An IP added to the access controller's blocklist fails check_access."""
        denied = "10.0.0.1"
        self.manager.access_controller.blocked_ips.add(denied)
        assert self.manager.check_access(denied) is False

    def test_protect_data_encrypt(self):
        """protect_data round-trips through its 'encrypt' and 'decrypt' modes."""
        plaintext = "sensitive information"
        ciphertext = self.manager.protect_data(plaintext, "encrypt")
        assert ciphertext != plaintext
        assert self.manager.protect_data(ciphertext, "decrypt") == plaintext

    def test_generate_security_token(self):
        """A generated security token is a non-empty string."""
        token = self.manager.generate_security_token()
        assert isinstance(token, str) and len(token) > 0

    def test_get_security_status(self):
        """The status dict exposes the expected top-level sections."""
        status = self.manager.get_security_status()
        for section in ('policy', 'security_report', 'blocked_ips', 'system_info'):
            assert section in status
if __name__ == "__main__":
    # Run this module's tests verbosely when executed as a script. The exit
    # code from pytest.main() is passed to sys.exit so that shells and CI
    # see a non-zero status when tests fail, instead of always getting 0.
    sys.exit(pytest.main([__file__, "-v"]))