# tests/test_code_validator.py
"""
Unit tests for the code security validator.
These tests verify that the AST-based security validation correctly:
- Blocks dangerous code patterns
- Allows safe code patterns
- Validates HTTP methods
"""
import pytest
from fctr_okta_mcp.security.code_validator import (
validate_generated_code,
validate_http_method,
is_code_safe,
SecurityValidationResult,
)
class TestBlockedPatterns:
    """Dangerous constructs that the validator is expected to reject."""

    @staticmethod
    def _mentions(violations, keyword):
        """Return True when any violation message contains *keyword* (case-insensitive)."""
        return any(keyword in message.lower() for message in violations)

    def test_blocks_os_system(self):
        """A call to os.system() must be rejected."""
        outcome = validate_generated_code('os.system("rm -rf /")')
        assert not outcome.is_valid
        # Either reporting channel counts as evidence of the rejection.
        assert len(outcome.violations) > 0 or len(outcome.blocked_patterns) > 0

    def test_blocks_subprocess(self):
        """Use of the subprocess module must be rejected."""
        outcome = validate_generated_code('subprocess.run(["ls", "-la"])')
        assert not outcome.is_valid
        assert self._mentions(outcome.violations, "subprocess")

    def test_blocks_exec(self):
        """A call to exec() must be rejected."""
        outcome = validate_generated_code('exec("print(1)")')
        assert not outcome.is_valid
        assert self._mentions(outcome.violations, "exec")

    def test_blocks_eval(self):
        """A call to eval() must be rejected."""
        outcome = validate_generated_code('eval("1 + 1")')
        assert not outcome.is_valid
        assert self._mentions(outcome.violations, "eval")

    def test_blocks_open(self):
        """A call to open() (file access) must be rejected."""
        outcome = validate_generated_code('open("/etc/passwd", "r")')
        assert not outcome.is_valid
        assert self._mentions(outcome.violations, "open")

    def test_blocks_import_os(self):
        """A plain ``import os`` must be rejected."""
        outcome = validate_generated_code('import os')
        assert not outcome.is_valid
        assert self._mentions(outcome.violations, "os")

    def test_blocks_from_import_subprocess(self):
        """A ``from subprocess import ...`` statement must be rejected."""
        outcome = validate_generated_code('from subprocess import run')
        assert not outcome.is_valid
        assert self._mentions(outcome.violations, "subprocess")

    def test_blocks_dunder_import(self):
        """A dynamic import through __import__() must be rejected."""
        snippet = '__import__("os")'
        assert not validate_generated_code(snippet).is_valid

    def test_blocks_builtins_access(self):
        """Reaching into __builtins__ must be rejected."""
        snippet = '__builtins__["eval"]("1+1")'
        assert not validate_generated_code(snippet).is_valid

    def test_blocks_compile(self):
        """A call to compile() must be rejected."""
        snippet = 'compile("print(1)", "<string>", "exec")'
        assert not validate_generated_code(snippet).is_valid
class TestAllowedPatterns:
    """Tests for allowed safe patterns.

    Every snippet is a complete ``execute_query`` coroutine. The snippet text
    starts at column 0 inside the triple-quoted string and the function body
    is indented, so the text is parseable Python; an unindented body would be
    a syntax error and could never be reported as valid.
    """

    def test_allows_async_def(self):
        """Should allow async function definitions."""
        code = '''
async def execute_query(client):
    return []
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_allows_await_make_request(self):
        """Should allow await client.make_request()."""
        code = '''
async def execute_query(client):
    result = await client.make_request("/api/v1/users", "GET", {})
    return result
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_allows_asyncio_gather(self):
        """Should allow asyncio.gather() for concurrent operations."""
        # ``user_ids`` is deliberately undefined in the snippet: the code is
        # only handed to the validator here, never executed by this test.
        code = '''
async def execute_query(client):
    tasks = [client.make_request(f"/api/v1/users/{uid}", "GET", {}) for uid in user_ids]
    results = await asyncio.gather(*tasks)
    return results
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_allows_list_comprehensions(self):
        """Should allow list comprehensions."""
        code = '''
async def execute_query(client):
    users = [u for u in data if u.get("status") == "ACTIVE"]
    return users
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_allows_dict_operations(self):
        """Should allow dict operations."""
        code = '''
async def execute_query(client):
    data = {"key": "value"}
    data.update({"another": "item"})
    return list(data.items())
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_allows_json_operations(self):
        """Should allow json.loads/dumps."""
        code = '''
async def execute_query(client):
    data = json.loads('{"key": "value"}')
    output = json.dumps(data)
    return data
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_allows_datetime_operations(self):
        """Should allow datetime operations."""
        code = '''
async def execute_query(client):
    now = datetime.now(timezone.utc)
    week_ago = now - timedelta(days=7)
    return week_ago.isoformat()
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_allows_safe_builtins(self):
        """Should allow safe builtin functions."""
        code = '''
async def execute_query(client):
    items = [1, 2, 3, 4, 5]
    return {
        "length": len(items),
        "sum": sum(items),
        "min": min(items),
        "max": max(items),
        "sorted": sorted(items, reverse=True),
    }
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_allows_string_methods(self):
        """Should allow string methods."""
        code = '''
async def execute_query(client):
    name = " John Doe "
    return name.strip().lower().replace(" ", "_")
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"
class TestHTTPMethodValidation:
    """Checks which HTTP verbs validate_http_method() accepts or rejects."""

    def test_allows_get(self):
        """GET is accepted."""
        assert validate_http_method("GET").is_valid

    def test_allows_get_lowercase(self):
        """GET spelled in lowercase is accepted as well."""
        assert validate_http_method("get").is_valid

    def test_blocks_post(self):
        """POST is rejected and the first violation names the method."""
        verdict = validate_http_method("POST")
        assert not verdict.is_valid
        first_violation = verdict.violations[0]
        assert "POST" in first_violation

    def test_blocks_put(self):
        """PUT is rejected."""
        assert not validate_http_method("PUT").is_valid

    def test_blocks_delete(self):
        """DELETE is rejected."""
        assert not validate_http_method("DELETE").is_valid

    def test_blocks_patch(self):
        """PATCH is rejected."""
        assert not validate_http_method("PATCH").is_valid
class TestIsCodeSafe:
    """Tests for the is_code_safe() convenience function."""

    def test_returns_true_for_safe_code(self):
        """Should return True for safe code."""
        # The function body must be indented: an unindented ``return`` makes
        # the snippet unparseable, so it could not be judged safe.
        code = '''
async def execute_query(client):
    return await client.make_request("/api/v1/users", "GET", {})
'''
        assert is_code_safe(code)

    def test_returns_false_for_dangerous_code(self):
        """Should return False for dangerous code."""
        code = 'import os; os.system("echo pwned")'
        assert not is_code_safe(code)
class TestEdgeCases:
    """Tests for edge cases and tricky patterns."""

    def test_handles_syntax_error(self):
        """Should handle code with syntax errors."""
        code = 'def broken('  # Invalid syntax
        result = validate_generated_code(code)
        assert not result.is_valid
        assert any("syntax" in v.lower() for v in result.violations)

    def test_handles_empty_code(self):
        """Should handle empty code."""
        code = ''
        result = validate_generated_code(code)
        assert result.is_valid  # Empty code is technically valid

    def test_handles_whitespace_only(self):
        """Should handle whitespace-only code."""
        code = ' \n\t \n'
        result = validate_generated_code(code)
        assert result.is_valid

    def test_handles_comments_only(self):
        """Should handle comments-only code."""
        code = '# This is a comment\n# Another comment'
        result = validate_generated_code(code)
        assert result.is_valid

    def test_blocks_obfuscated_exec(self):
        """Should block obfuscated exec() attempts."""
        code = 'exec ("print(1)")'  # Extra space
        result = validate_generated_code(code)
        assert not result.is_valid

    def test_blocks_nested_dangerous_import(self):
        """Should block dangerous import inside function."""
        # The body is indented so the snippet is syntactically valid. With an
        # unindented body the snippet would be malformed, and a rejection
        # would say nothing about whether nested imports are detected.
        code = '''
async def execute_query(client):
    import os
    return os.getcwd()
'''
        result = validate_generated_code(code)
        assert not result.is_valid
class TestRealWorldPatterns:
    """Tests for patterns typical in generated Okta queries.

    The snippets begin at column 0 inside their triple-quoted strings and use
    normal nested indentation, so each one is parseable Python. They are only
    passed to the validator; nothing in these tests executes them.
    """

    def test_typical_user_query(self):
        """Should allow typical user listing query."""
        code = '''
async def execute_query(client):
    response = await client.make_request(
        endpoint="/api/v1/users",
        method="GET",
        params={"search": 'status eq "ACTIVE"'},
        entity_label="users"
    )
    if response.get("status") == "success":
        users = response.get("data", [])
        return [
            {
                "id": u.get("id"),
                "email": u.get("profile", {}).get("email"),
                "status": u.get("status")
            }
            for u in users
        ]
    return []
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_typical_concurrent_query(self):
        """Should allow typical concurrent API calls."""
        code = '''
async def execute_query(client):
    # Get all users first
    users_response = await client.make_request("/api/v1/users", "GET", {})
    users = users_response.get("data", [])
    # Get groups for each user concurrently
    tasks = [
        client.make_request(f"/api/v1/users/{u['id']}/groups", "GET", {})
        for u in users[:10] # Limit to first 10
    ]
    results = await asyncio.gather(*tasks)
    all_data = []
    for user, groups_response in zip(users[:10], results):
        groups = groups_response.get("data", [])
        all_data.append({
            "user_id": user.get("id"),
            "email": user.get("profile", {}).get("email"),
            "groups": [g.get("profile", {}).get("name") for g in groups]
        })
    return all_data
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"

    def test_typical_log_query(self):
        """Should allow typical log query with date filtering."""
        code = '''
async def execute_query(client):
    since = (datetime.now(timezone.utc) - timedelta(hours=24)).isoformat()
    response = await client.make_request(
        endpoint="/api/v1/logs",
        method="GET",
        params={
            "since": since,
            "filter": 'eventType eq "user.authentication.auth" and outcome.result eq "FAILURE"',
            "limit": 100
        }
    )
    events = response.get("data", [])
    return [
        {
            "time": e.get("published"),
            "actor": e.get("actor", {}).get("alternateId"),
            "outcome": e.get("outcome", {}).get("result")
        }
        for e in events
    ]
'''
        result = validate_generated_code(code)
        assert result.is_valid, f"Violations: {result.violations}"