test_security.py•20.9 kB
"""Security tests for claude-memory MCP server
Tests path traversal prevention, SQL injection prevention, input validation,
audit logging, and access control.
Target: 100% coverage for security-critical code
"""
import pytest
import sqlite3
from pathlib import Path
@pytest.mark.security
class TestPathSecurity:
"""Test path traversal and file access security"""
async def test_claudemd_path_whitelist_etc_passwd(self, memory_server, tmp_path):
"""Test CLAUDE.md update blocks /etc/passwd"""
malicious_path = "/etc/passwd"
result = await memory_server._update_claudemd(
malicious_path,
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
assert "not in allowed directories" in result.get("error", "").lower() or \
"only CLAUDE.md" in result.get("error", "")
async def test_claudemd_path_whitelist_windows_system32(self, memory_server):
"""Test CLAUDE.md update blocks Windows system files"""
malicious_path = "C:\\Windows\\System32\\config\\SAM"
result = await memory_server._update_claudemd(
malicious_path,
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
async def test_claudemd_filename_restriction_python_file(self, memory_server, tmp_path):
"""Test only CLAUDE.md files can be updated, not arbitrary files"""
malicious_file = tmp_path / "malicious.py"
result = await memory_server._update_claudemd(
str(malicious_file),
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
assert "only CLAUDE.md" in result.get("error", "").lower() or \
"not in allowed" in result.get("error", "").lower()
async def test_claudemd_filename_restriction_shell_script(self, memory_server, tmp_path):
"""Test shell script files are blocked"""
malicious_file = tmp_path / "evil.sh"
result = await memory_server._update_claudemd(
str(malicious_file),
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
async def test_path_traversal_prevention_relative(self, memory_server, tmp_path):
"""Test relative path traversal (../..) is blocked"""
traversal_path = tmp_path / ".." / ".." / "etc" / "CLAUDE.md"
result = await memory_server._update_claudemd(
str(traversal_path),
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
async def test_path_traversal_prevention_absolute(self, memory_server):
"""Test absolute path traversal is blocked"""
traversal_path = "/../../../../etc/passwd"
result = await memory_server._update_claudemd(
traversal_path,
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
async def test_symlink_attack_prevention(self, memory_server, tmp_path):
"""Test symlink pointing to sensitive files is blocked"""
sensitive_file = tmp_path / "sensitive.txt"
sensitive_file.write_text("sensitive data")
symlink = tmp_path / "CLAUDE.md"
try:
symlink.symlink_to(sensitive_file)
except OSError:
pytest.skip("Cannot create symlinks on this system")
# Should reject symlink
result = await memory_server._update_claudemd(
str(symlink),
project_path=None,
min_confidence=0.7
)
# Either blocks symlink or safely handles it
assert "success" in result
async def test_null_byte_injection(self, memory_server, malicious_inputs):
"""Test null byte injection in file paths"""
null_byte_path = malicious_inputs["null_byte"]
result = await memory_server._update_claudemd(
null_byte_path,
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
async def test_whitelist_allows_temp_directory(self, memory_server, tmp_path):
"""Test whitelisted temp directory CLAUDE.md is allowed"""
valid_file = tmp_path / "CLAUDE.md"
valid_file.write_text("# Test")
result = await memory_server._update_claudemd(
str(valid_file),
project_path=str(tmp_path),
min_confidence=0.7
)
# Should succeed (or fail gracefully without security error)
assert "success" in result
async def test_directory_creation_security(self, memory_server, tmp_path):
"""Test directory creation doesn't escape whitelist"""
nested_path = tmp_path / "../../../etc/CLAUDE.md"
result = await memory_server._update_claudemd(
str(nested_path),
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
@pytest.mark.security
class TestSQLInjection:
"""Test SQL injection prevention"""
async def test_search_sql_injection_drop_table(self, memory_server):
"""Test search query with DROP TABLE injection"""
malicious_query = "'; DROP TABLE episodes; --"
result = await memory_server._search_episodes(
malicious_query,
limit=10
)
# Should not crash
assert "success" in result
# Verify table still exists
with sqlite3.connect(memory_server.db_path) as conn:
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='episodes'"
)
assert cursor.fetchone() is not None
async def test_search_sql_injection_union_select(self, memory_server):
"""Test search with UNION SELECT injection"""
malicious_query = "' UNION SELECT NULL, password, NULL FROM users; --"
result = await memory_server._search_episodes(
malicious_query,
limit=10
)
assert "success" in result or "error" in result
# Should not expose database structure
async def test_add_episode_sql_injection_content(self, memory_server):
"""Test episode addition with SQL injection in content"""
malicious_content = "'; DROP TABLE episodes; --"
result = await memory_server._add_episode(
"SQL Injection Test",
malicious_content,
"security-test"
)
assert result["success"] is True
# Verify table still exists and data is safely stored
with sqlite3.connect(memory_server.db_path) as conn:
cursor = conn.execute("SELECT COUNT(*) FROM episodes")
count = cursor.fetchone()[0]
assert count > 0
# Verify malicious content is stored as-is (not executed)
cursor = conn.execute(
"SELECT content FROM episodes WHERE name = ?",
("SQL Injection Test",)
)
row = cursor.fetchone()
assert row is not None
assert malicious_content in row[0]
async def test_add_episode_sql_injection_name(self, memory_server):
"""Test episode addition with SQL injection in name"""
malicious_name = "Test'; DROP TABLE tool_logs; --"
result = await memory_server._add_episode(
malicious_name,
"Regular content",
"security-test"
)
assert result["success"] is True
# Verify tool_logs table still exists
with sqlite3.connect(memory_server.db_path) as conn:
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='tool_logs'"
)
assert cursor.fetchone() is not None
async def test_search_sql_injection_comment(self, memory_server):
"""Test SQL injection with comment syntax"""
malicious_query = "test' -- "
result = await memory_server._search_episodes(
malicious_query,
limit=10
)
assert "success" in result or "error" in result
async def test_parameterized_query_verification(self, memory_server):
"""Test that parameterized queries are used (positive test)"""
# Add episode with special chars that would break string concat
special_chars = "O'Reilly's \"Advanced\" SQL & NoSQL"
result = await memory_server._add_episode(
special_chars,
"Test content",
"test"
)
assert result["success"] is True
# Should be able to search for it
search_result = await memory_server._search_episodes("O'Reilly", limit=10)
assert search_result.get("success", False) or search_result.get("count", 0) >= 0
async def test_batch_sql_injection(self, memory_server):
"""Test multiple SQL injection attempts in sequence"""
injections = [
"'; DELETE FROM episodes WHERE '1'='1",
"1' OR '1'='1",
"'; UPDATE episodes SET content='hacked' WHERE '1'='1",
"admin'--",
"1' UNION ALL SELECT NULL,NULL,NULL,NULL,NULL--"
]
for injection in injections:
result = await memory_server._search_episodes(injection, limit=10)
assert "success" in result or "error" in result
# Verify database integrity
with sqlite3.connect(memory_server.db_path) as conn:
tables = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table'"
).fetchall()
assert len(tables) >= 3 # episodes, tool_logs, audit_log
@pytest.mark.security
class TestAuditLogging:
"""Test audit trail for sensitive operations"""
async def test_claudemd_update_success_logged(self, memory_server, tmp_path):
"""Test successful CLAUDE.md updates are logged"""
file_path = tmp_path / "CLAUDE.md"
result = await memory_server._update_claudemd(
str(file_path),
project_path=str(tmp_path),
min_confidence=0.7
)
# Check audit log exists
with sqlite3.connect(memory_server.db_path) as conn:
cursor = conn.execute("""
SELECT action, target_type, target_path, success
FROM audit_log
WHERE action = 'update_claudemd'
ORDER BY timestamp DESC
LIMIT 1
""")
log_entry = cursor.fetchone()
if log_entry:
assert log_entry[0] == "update_claudemd"
assert log_entry[1] == "file"
assert str(file_path) in str(log_entry[2])
async def test_failed_operations_logged(self, memory_server):
"""Test failed security attempts are logged"""
malicious_path = "/etc/passwd"
await memory_server._update_claudemd(
malicious_path,
project_path=None,
min_confidence=0.7
)
# Verify logged as failed
with sqlite3.connect(memory_server.db_path) as conn:
cursor = conn.execute("""
SELECT action, success
FROM audit_log
WHERE target_path LIKE '%passwd%'
ORDER BY timestamp DESC
LIMIT 1
""")
log_entry = cursor.fetchone()
if log_entry:
assert log_entry[1] == 0 # False/failed
async def test_audit_log_completeness(self, memory_server, tmp_path):
"""Test audit log captures all required fields"""
file_path = tmp_path / "CLAUDE.md"
await memory_server._update_claudemd(
str(file_path),
project_path=str(tmp_path),
min_confidence=0.7
)
with sqlite3.connect(memory_server.db_path) as conn:
cursor = conn.execute("""
SELECT timestamp, action, target_type, target_path, details, success
FROM audit_log
ORDER BY timestamp DESC
LIMIT 1
""")
log_entry = cursor.fetchone()
if log_entry:
# Verify all fields present
assert log_entry[0] is not None # timestamp
assert log_entry[1] is not None # action
assert log_entry[2] is not None # target_type
assert log_entry[3] is not None # target_path
async def test_audit_log_tampering_prevention(self, memory_server):
"""Test audit log cannot be easily tampered with"""
# Try to directly modify audit log
with sqlite3.connect(memory_server.db_path) as conn:
initial_count = conn.execute("SELECT COUNT(*) FROM audit_log").fetchone()[0]
# Create an audit entry
await memory_server._update_claudemd(
"/invalid/path",
project_path=None,
min_confidence=0.7
)
final_count = conn.execute("SELECT COUNT(*) FROM audit_log").fetchone()[0]
# Count should increase (entry was added)
assert final_count >= initial_count
async def test_audit_trail_chronological_order(self, memory_server, tmp_path):
"""Test audit entries maintain chronological order"""
# Create multiple operations
for i in range(3):
await memory_server._update_claudemd(
f"/tmp/test_{i}/CLAUDE.md",
project_path=None,
min_confidence=0.7
)
with sqlite3.connect(memory_server.db_path) as conn:
cursor = conn.execute("""
SELECT timestamp FROM audit_log
ORDER BY timestamp DESC
LIMIT 3
""")
timestamps = [row[0] for row in cursor.fetchall()]
# Should be in descending order
for i in range(len(timestamps) - 1):
assert timestamps[i] >= timestamps[i + 1]
@pytest.mark.security
class TestInputValidation:
"""Test input validation and sanitization"""
async def test_episode_name_length_limit(self, memory_server, malicious_inputs):
"""Test extremely long episode names are handled"""
long_name = malicious_inputs["long_input"]
result = await memory_server._add_episode(
long_name,
"Content",
"test"
)
# Should either accept (truncate) or reject gracefully
assert "success" in result or "error" in result
async def test_episode_content_length_limit(self, memory_server):
"""Test extremely long content is handled"""
long_content = "A" * 1000000 # 1MB
result = await memory_server._add_episode(
"Test",
long_content,
"test"
)
assert "success" in result or "error" in result
async def test_special_characters_handling(self, memory_server):
"""Test special characters in input are handled safely"""
special_chars = "!@#$%^&*(){}[]|\\:;\"'<>,.?/~`"
result = await memory_server._add_episode(
f"Special {special_chars}",
f"Content with {special_chars}",
"test"
)
assert result["success"] is True
async def test_unicode_handling(self, memory_server):
"""Test unicode characters are handled correctly"""
unicode_text = "测试 テスト тест 🚀 😀"
result = await memory_server._add_episode(
f"Unicode {unicode_text}",
unicode_text,
"test"
)
assert result["success"] is True
# Verify can retrieve
search_result = await memory_server._search_episodes(unicode_text[:4], 10)
assert search_result.get("success", False) or search_result.get("count", 0) >= 0
async def test_empty_input_handling(self, memory_server):
"""Test empty strings are validated"""
result = await memory_server._add_episode("", "Content", "test")
assert result["success"] is False or "error" in result
result = await memory_server._add_episode("Name", "", "test")
assert result["success"] is False or "error" in result
async def test_null_input_handling(self, memory_server):
"""Test null/None input validation"""
# This may raise TypeError or return error dict
try:
result = await memory_server._add_episode(None, "Content", "test")
assert result["success"] is False or "error" in result
except (TypeError, AttributeError):
pass # Acceptable to raise exception for None
async def test_whitespace_only_input(self, memory_server):
"""Test whitespace-only input is rejected"""
result = await memory_server._add_episode(" ", "Content", "test")
# Should either trim and reject, or reject as-is
assert "success" in result or "error" in result
async def test_xss_script_tag_sanitization(self, memory_server, malicious_inputs):
"""Test XSS attempts are safely stored (not executed)"""
xss_payload = malicious_inputs["xss"]
result = await memory_server._add_episode(
"XSS Test",
xss_payload,
"security-test"
)
assert result["success"] is True
# Verify stored safely
search_result = await memory_server._search_episodes("XSS Test", 10)
if search_result.get("success"):
# Content should be stored as-is, not interpreted
assert xss_payload in str(search_result)
@pytest.mark.security
class TestAccessControl:
"""Test file system access control and permissions"""
async def test_database_file_permissions(self, memory_server):
"""Test database file has appropriate permissions"""
db_path = Path(memory_server.db_path)
if db_path.exists():
stat_info = db_path.stat()
# On Unix, check file is not world-readable (would be 0o600 or 0o640)
# On Windows, this may not apply
mode = stat_info.st_mode & 0o777
# Just verify file exists and is accessible
assert db_path.exists()
async def test_no_world_writable_files_created(self, memory_server, tmp_path):
"""Test created files are not world-writable"""
test_file = tmp_path / "CLAUDE.md"
await memory_server._update_claudemd(
str(test_file),
project_path=str(tmp_path),
min_confidence=0.7
)
if test_file.exists():
stat_info = test_file.stat()
mode = stat_info.st_mode & 0o777
# Verify not world-writable (no 0o002 bit)
assert (mode & 0o002) == 0
async def test_directory_traversal_with_encoded_chars(self, memory_server):
"""Test URL-encoded path traversal is blocked"""
encoded_traversal = "%2e%2e%2f%2e%2e%2fetc%2fpasswd"
result = await memory_server._update_claudemd(
encoded_traversal,
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
async def test_backslash_path_traversal(self, memory_server):
"""Test Windows-style backslash traversal is blocked"""
backslash_traversal = "..\\..\\..\\windows\\system32\\config"
result = await memory_server._update_claudemd(
backslash_traversal,
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
async def test_mixed_slash_path_traversal(self, memory_server):
"""Test mixed forward/backward slash traversal is blocked"""
mixed_traversal = "../\\..\\/../etc/passwd"
result = await memory_server._update_claudemd(
mixed_traversal,
project_path=None,
min_confidence=0.7
)
assert result["success"] is False
async def test_environment_variable_protection(self, memory_server):
"""Test environment variables are not leaked in errors"""
result = await memory_server._update_claudemd(
"/invalid/path/CLAUDE.md",
project_path=None,
min_confidence=0.7
)
error_msg = result.get("error", "")
# Should not leak system paths or environment info
assert "HOME" not in error_msg
assert "USER" not in error_msg
assert "PATH" not in error_msg
# Summary of security test coverage:
# - Path Security: 10 tests
# - SQL Injection: 8 tests
# - Audit Logging: 5 tests
# - Input Validation: 8 tests
# - Access Control: 6 tests
# Total: 37 security tests