# test_exploit_generator.py (18.7 kB)
"""Corrected tests for exploit generator module with actual interfaces."""
import os
import sys
from unittest.mock import Mock
# Add the src directory to the path to import modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from adversary_mcp_server.credentials import SecurityConfig
from adversary_mcp_server.scanner.exploit_generator import (
ExploitContext,
ExploitGenerationError,
ExploitGenerator,
ExploitPrompt,
LLMExploitGenerator,
SafetyFilter,
TemplateEngine,
)
from adversary_mcp_server.scanner.types import Category, Severity, ThreatMatch
class TestExploitContext:
    """Tests for the ExploitContext holder of threat, source and extra context."""

    @staticmethod
    def _threat(description, file_path):
        """Build an INJECTION/HIGH ``ThreatMatch`` at line 1 for ``test_rule``."""
        return ThreatMatch(
            rule_id="test_rule",
            rule_name="Test Rule",
            description=description,
            category=Category.INJECTION,
            severity=Severity.HIGH,
            file_path=file_path,
            line_number=1,
        )

    def test_exploit_context_initialization(self):
        """All three positional arguments are stored on the context."""
        match = self._threat("Test", "test.py")
        extra = {"key": "value"}

        ctx = ExploitContext(match, "test code", extra)

        assert ctx.threat_match == match
        assert ctx.source_code == "test code"
        assert ctx.additional_context == {"key": "value"}

    def test_exploit_context_defaults(self):
        """Omitting additional context yields an empty dict."""
        match = self._threat("Test description", "/test/path")

        ctx = ExploitContext(match, source_code="test code")

        assert ctx.threat_match == match
        assert ctx.source_code == "test code"
        assert ctx.additional_context == {}
class TestSafetyFilterCorrected:
    """Test SafetyFilter against its actual implementation.

    Locals are named ``safety_filter`` (not ``filter``) so the ``filter``
    builtin is never shadowed inside these tests.
    """

    def test_safety_filter_initialization(self):
        """SafetyFilter exposes a non-empty list of dangerous patterns."""
        safety_filter = SafetyFilter()
        assert hasattr(safety_filter, "dangerous_patterns")
        assert hasattr(safety_filter, "compiled_patterns")
        assert isinstance(safety_filter.dangerous_patterns, list)
        assert len(safety_filter.dangerous_patterns) > 0

    def test_filter_dangerous_commands(self):
        """``is_safe`` rejects a destructive command and accepts harmless code."""
        safety_filter = SafetyFilter()
        # Destructive shell command must be flagged.
        assert not safety_filter.is_safe("rm -rf /")
        # Plain print statement must pass.
        assert safety_filter.is_safe("print('hello world')")

    def test_filter_safe_replacements(self):
        """``sanitize_exploit`` rewrites a dangerous command into a safe notice."""
        safety_filter = SafetyFilter()
        sanitized = safety_filter.sanitize_exploit("rm -rf /")
        assert safety_filter.is_safe(sanitized)
        assert "This would delete files" in sanitized

    def test_is_safe_exploit(self):
        """SQL injection payloads are allowed; destructive commands are not."""
        safety_filter = SafetyFilter()
        # A classic SQL injection payload is considered safe to show.
        assert safety_filter.is_safe("' OR '1'='1' --")
        # A destructive shell command is not.
        assert not safety_filter.is_safe("rm -rf /")
class TestTemplateEngineCorrected:
    """Test TemplateEngine against its actual implementation."""

    def test_template_engine_initialization(self):
        """A fresh engine has ``templates`` and ``env`` attributes."""
        engine = TemplateEngine()
        assert engine is not None
        assert hasattr(engine, "templates")
        assert hasattr(engine, "env")

    def test_template_engine_has_default_templates(self):
        """The built-in SQLi, XSS and command-injection templates are registered."""
        engine = TemplateEngine()
        assert "sql_injection" in engine.templates
        assert "xss" in engine.templates
        assert "command_injection" in engine.templates
        assert len(engine.templates) >= 3

    def test_generate_from_template(self):
        """A SQL injection threat yields a tautology payload as its first exploit."""
        engine = TemplateEngine()
        threat = ThreatMatch(
            rule_id="sql_injection_test",
            rule_name="SQL Injection",
            description="Test SQL injection",
            category=Category.INJECTION,
            severity=Severity.HIGH,
            file_path="test.py",
            line_number=10,
            code_snippet="SELECT * FROM users WHERE id = 42",
        )
        result = engine.generate_exploit(threat, "test code")
        assert isinstance(result, list)
        assert len(result) > 0
        assert "' OR '1'='1" in result[0]

    def test_add_template(self):
        """``add_template`` registers a template under the given name."""
        engine = TemplateEngine()
        custom_template = "Custom exploit: {{ threat.rule_name }}"
        engine.add_template("custom_test", custom_template)
        assert "custom_test" in engine.templates

        # Generation must still work after a custom template is registered.
        # NOTE(review): this does NOT prove the custom template is rendered.
        # generate_exploit picks templates itself, and only the return type is
        # checked. Assert on the rendered text once the selection rule is known.
        threat = ThreatMatch(
            rule_id="custom_rule",
            rule_name="Custom Rule",
            description="Custom description",
            category=Category.INJECTION,
            severity=Severity.MEDIUM,
            file_path="test.py",
            line_number=5,
        )
        result = engine.generate_exploit(threat, "test code")
        assert isinstance(result, list)

    def test_template_with_missing_context(self):
        """Generation still yields output when the source code is empty."""
        engine = TemplateEngine()
        threat = ThreatMatch(
            rule_id="incomplete_test",
            rule_name="Incomplete Test",
            description="Test with missing context",
            category=Category.INJECTION,
            severity=Severity.LOW,
            file_path="test.py",
            line_number=1,
        )
        result = engine.generate_exploit(threat, "")
        assert isinstance(result, list)
        assert len(result) > 0
class TestLLMExploitGeneratorCorrected:
    """Test LLMExploitGenerator and ExploitGenerator's LLM-related hooks."""

    def test_llm_generator_initialization(self):
        """LLMExploitGenerator stores the config it is constructed with."""
        config = SecurityConfig(enable_llm_analysis=True, severity_threshold="high")
        generator = LLMExploitGenerator(config)
        assert generator.config == config

    def test_llm_generator_without_api_key(self):
        """LLMExploitGenerator can be constructed with LLM analysis disabled.

        Despite the test name, no API key is involved: the config only sets
        ``enable_llm_analysis=False``.
        """
        config = SecurityConfig(enable_llm_analysis=False)
        generator = LLMExploitGenerator(config)
        assert generator.config == config

    def test_create_prompt(self):
        """``create_exploit_prompt`` carries the threat, source and rule name."""
        config = SecurityConfig(enable_llm_analysis=True)
        generator = LLMExploitGenerator(config)
        threat = ThreatMatch(
            rule_id="test_rule",
            rule_name="Test Injection",
            description="Test vulnerability",
            category=Category.INJECTION,
            severity=Severity.HIGH,
            file_path="test.py",
            line_number=42,
        )
        context = ExploitContext(threat, source_code="test code")
        prompt = generator.create_exploit_prompt(context)
        assert isinstance(prompt, ExploitPrompt)
        assert prompt.threat_match == threat
        assert prompt.source_code == "test code"
        assert "Test Injection" in prompt.user_prompt

    def test_generate_exploits_with_llm(self):
        """ExploitGenerator returns template exploits when ``use_llm=False``.

        NOTE(review): despite its name, this test passes ``use_llm=False``.
        It therefore exercises the template path of ``ExploitGenerator``, not
        ``LLMExploitGenerator``. The ``use_llm=True`` path is exercised in
        ``TestExploitGeneratorCorrected.test_generate_exploits_with_llm``.
        """
        mock_config = SecurityConfig(enable_llm_analysis=True)
        mock_credential_manager = Mock()
        mock_credential_manager.load_config.return_value = mock_config
        generator = ExploitGenerator(mock_credential_manager)
        threat = ThreatMatch(
            rule_id="test_rule",
            rule_name="Test Rule",
            description="Test description",
            category=Category.INJECTION,
            severity=Severity.HIGH,
            file_path="test.py",
            line_number=10,
        )
        exploits = generator.generate_exploits(threat, "test code", use_llm=False)
        assert isinstance(exploits, list)
        assert len(exploits) > 0

    def test_is_llm_available(self):
        """``is_llm_available`` is True when LLM analysis is enabled."""
        mock_config = SecurityConfig(enable_llm_analysis=True)
        mock_credential_manager = Mock()
        mock_credential_manager.load_config.return_value = mock_config
        generator = ExploitGenerator(mock_credential_manager)
        # With client-based LLM, it should always be available.
        assert generator.is_llm_available() is True

    def test_is_llm_available_disabled(self):
        """``is_llm_available`` is still True when LLM analysis is disabled."""
        mock_config = SecurityConfig(enable_llm_analysis=False)
        mock_credential_manager = Mock()
        mock_credential_manager.load_config.return_value = mock_config
        generator = ExploitGenerator(mock_credential_manager)
        # Client-based LLM is reported available even with
        # enable_llm_analysis=False.
        assert generator.is_llm_available() is True

    def test_parse_exploits(self):
        """``_parse_exploits`` extracts at least one payload per fenced block."""
        config = SecurityConfig()
        generator = LLMExploitGenerator(config)
        # Built by concatenation so that no source indentation leaks into the
        # fixture: every fence and payload line starts at column 0.
        text = (
            "\n"
            "Here are some exploits:\n"
            "```\n"
            "' OR '1'='1' --\n"
            "```\n"
            "```\n"
            "' UNION SELECT * FROM users --\n"
            "```\n"
        )
        exploits = generator._parse_exploits(text)
        assert isinstance(exploits, list)
        assert len(exploits) >= 2
class TestExploitGeneratorCorrected:
    """Tests for ExploitGenerator driven through a mocked credential manager."""

    @staticmethod
    def _build(**config_kwargs):
        """Create an ExploitGenerator backed by a Mock credential manager.

        The mock's ``load_config`` returns ``SecurityConfig(**config_kwargs)``.
        Returns the pair ``(generator, manager)``.
        """
        manager = Mock()
        manager.load_config.return_value = SecurityConfig(**config_kwargs)
        return ExploitGenerator(manager), manager

    def test_exploit_generator_initialization(self):
        """The generator keeps its manager and wires up its collaborators."""
        generator, manager = self._build()
        assert generator.credential_manager == manager
        for attr in ("safety_filter", "template_engine", "llm_generator"):
            assert hasattr(generator, attr)

    def test_generate_exploits_with_template(self):
        """Template-only generation yields a non-empty list of strings."""
        generator, _ = self._build()
        sqli = ThreatMatch(
            rule_id="sql_injection",
            rule_name="SQL Injection",
            description="SQL injection vulnerability",
            category=Category.INJECTION,
            severity=Severity.HIGH,
            file_path="test.py",
            line_number=1,
            code_snippet="SELECT * FROM users WHERE id = user_input",
        )
        exploits = generator.generate_exploits(sqli, "test code", use_llm=False)
        assert isinstance(exploits, list)
        assert exploits
        assert all(isinstance(item, str) for item in exploits)

    def test_generate_exploits_with_llm(self):
        """Requesting client-based LLM generation still returns a list."""
        generator, _ = self._build(enable_llm_analysis=True)
        xss = ThreatMatch(
            rule_id="xss",
            rule_name="XSS",
            description="XSS vulnerability",
            category=Category.XSS,
            severity=Severity.MEDIUM,
            file_path="test.js",
            line_number=1,
        )
        assert isinstance(
            generator.generate_exploits(xss, "test code", use_llm=True), list
        )

    def test_is_llm_available(self):
        """Client-based LLM is reported available whatever the config flag."""
        for enabled in (True, False):
            generator, _ = self._build(enable_llm_analysis=enabled)
            assert generator.is_llm_available() is True

    def test_get_exploit_metadata(self):
        """Metadata is a dict carrying category, severity and availability info."""
        generator, _ = self._build()
        cmd = ThreatMatch(
            rule_id="command_injection",
            rule_name="Command Injection",
            description="Command injection vulnerability",
            category=Category.INJECTION,
            severity=Severity.CRITICAL,
            file_path="test.py",
            line_number=1,
        )
        info = generator.get_exploit_metadata(cmd)
        assert isinstance(info, dict)
        assert "category" in info
        assert "severity" in info
        # Either key is accepted: the exact key set of the implementation varies.
        assert "llm_available" in info or "available_templates" in info

    def test_add_custom_template(self):
        """Custom templates are forwarded to the underlying template engine."""
        generator, _ = self._build()
        generator.add_custom_template("custom_test", "Custom {{ vulnerability }} exploit")
        assert "custom_test" in generator.template_engine.templates

    def test_generate_from_templates_method(self):
        """``_generate_from_templates`` returns a list for an LFI threat."""
        generator, _ = self._build()
        lfi = ThreatMatch(
            rule_id="path_traversal",
            rule_name="Path Traversal",
            description="Path traversal vulnerability",
            category=Category.LFI,
            severity=Severity.HIGH,
            file_path="test.py",
            line_number=1,
        )
        assert isinstance(generator._generate_from_templates(lfi, "test code"), list)

    def test_get_available_templates(self):
        """``_get_available_templates`` returns a list for a category."""
        generator, _ = self._build()
        assert isinstance(generator._get_available_templates(Category.INJECTION), list)

    def test_exploit_generation_error(self):
        """ExploitGenerationError is an Exception that keeps its message."""
        err = ExploitGenerationError("Test error")
        assert str(err) == "Test error"
        assert isinstance(err, Exception)

    def test_safety_filtering_integration(self):
        """Generated exploits are strings free of the destructive ``rm -rf /``."""
        generator, _ = self._build()
        cmd = ThreatMatch(
            rule_id="command_injection",
            rule_name="Command Injection",
            description="Command injection vulnerability",
            category=Category.INJECTION,
            severity=Severity.HIGH,
            file_path="test.py",
            line_number=1,
        )
        exploits = generator.generate_exploits(
            cmd, "os.system(user_input)", use_llm=False
        )
        assert isinstance(exploits, list)
        for item in exploits:
            assert isinstance(item, str)
            assert "rm -rf /" not in item

    def test_different_vulnerability_categories(self):
        """Template generation returns a list for several categories."""
        generator, _ = self._build()
        for cat in (
            Category.INJECTION,
            Category.XSS,
            Category.DESERIALIZATION,
            Category.LFI,
        ):
            threat = ThreatMatch(
                rule_id=f"test_{cat.value}",
                rule_name=f"Test {cat.value}",
                description=f"Test {cat.value} vulnerability",
                category=cat,
                severity=Severity.MEDIUM,
                file_path="test.py",
                line_number=1,
            )
            assert isinstance(
                generator.generate_exploits(threat, "test code", use_llm=False), list
            )

    def test_exploit_generation_with_context(self):
        """A SQL injection threat yields exploits for varied source contexts."""
        generator, _ = self._build()
        sqli = ThreatMatch(
            rule_id="sql_injection",
            rule_name="SQL Injection",
            description="SQL injection vulnerability",
            category=Category.INJECTION,
            severity=Severity.HIGH,
            file_path="test.py",
            line_number=1,
            code_snippet="SELECT * FROM users WHERE id = user_input",
        )
        snippets = (
            "SELECT * FROM users WHERE id = user_input",
            "query = 'SELECT * FROM products WHERE name = ' + search_term",
            "database.execute('SELECT * FROM orders WHERE customer_id = ' + customer_id)",
        )
        for snippet in snippets:
            exploits = generator.generate_exploits(sqli, snippet, use_llm=False)
            assert isinstance(exploits, list)
            assert exploits