# src/fctr_okta_mcp/security/code_validator.py
"""
Security Configuration for Okta MCP Server
AST-based security validation for generated Python code.
Implements defense-in-depth with:
- Blocked patterns detection (regex)
- AST-based import/function validation
- Module whitelisting
- Method whitelisting
"""
from typing import List, Set, Optional
import re
import ast
from dataclasses import dataclass
from fctr_okta_mcp.utils.logger import get_logger
logger = get_logger(__name__)
# Public API
__all__ = [
"SecurityValidationResult",
"CodeSecurityValidator",
"validate_generated_code",
"validate_http_method",
"is_code_safe",
]
# ------------------------------------------------------------------------
# Security Validation Result Classes
# ------------------------------------------------------------------------
@dataclass
class SecurityValidationResult:
    """Outcome of one security validation run.

    Instances are plain value objects: the validator fills in all four
    fields and callers read them back.
    """

    # Whether the validated input passed the checks.
    is_valid: bool
    # Human-readable description of each problem that was found.
    violations: List[str]
    # Source text of each blocked regex pattern that matched.
    blocked_patterns: List[str]
    # Severity label: 'LOW', 'MEDIUM', 'HIGH' or 'CRITICAL'.
    risk_level: str
# ------------------------------------------------------------------------
# Core Security Patterns
# ------------------------------------------------------------------------
# Regular expressions for constructs that generated code must never contain.
# Call-style entries end in \s*\( so that whitespace between the name and the
# opening parenthesis cannot be used to dodge the match.
# The patterns are searched against the raw source text, so they also match
# inside comments and string literals.
BLOCKED_PATTERNS = [
    # Shell / process execution
    r"os\.system\s*\(",
    r"subprocess\.",
    # Dynamic code execution and dynamic imports
    r"\bexec\s*\(",
    r"\beval\s*\(",
    r"__import__\s*\(",
    # File and stdin access
    r"\bopen\s*\(",
    r"\binput\s*\(",
    r"\bfile\s*\(",        # legacy file() builtin
    r"execfile\s*\(",      # Python 2 execfile
    # Code compilation and namespace introspection
    r"\bcompile\s*\(",
    r"\bglobals\s*\(",
    r"\blocals\s*\(",
    # Dynamic attribute mutation
    r"\bsetattr\s*\(",
    r"\bdelattr\s*\(",
    # Module reloading
    r"reload\s*\(",
    # Builtins access and class-hierarchy traversal
    r"__builtins__",
    r"__class__\.__bases__",
    r"__subclasses__",
    # importlib.import_module
    r"import_module\s*\(",
    # Debugger entry
    r"breakpoint\s*\(",
    # Any dunder name followed by "(" - this also matches definitions such
    # as `def __init__(`.
    r"__.*__\s*\(",
]
# Modules regarded as safe for generated code to import.
# The original note on this set says these core modules are pre-injected in
# the subprocess, so importing them again is harmless.
ALLOWED_MODULES: Set[str] = {
    "asyncio", "json", "datetime", "time", "logging", "re",
    "typing", "collections", "itertools", "functools", "math",
}
# Root module names that generated code must never import.
# Imports are matched on the first dotted component, so blocking "os" also
# covers "os.path".
BLOCKED_MODULES: Set[str] = {
    # OS, process and filesystem access
    "os", "sys", "subprocess", "shutil",
    "signal", "pty", "fcntl", "resource",
    # Networking
    "socket", "http", "urllib", "requests",
    # Serialization formats that can execute code on load
    "pickle", "shelve", "marshal",
    # Import machinery and builtins
    "importlib", "builtins", "__builtin__",
    # Native code and concurrency primitives
    "ctypes", "multiprocessing", "threading",
}
# Names that generated code may call directly (as bare `name(...)` calls)
# without being screened for dangerous name fragments.
ALLOWED_BUILTINS: Set[str] = {
    # Size and type constructors
    "len",
    "str", "int", "float", "bool",
    "list", "dict", "tuple", "set",
    # Iteration
    "range", "enumerate", "zip", "sorted", "reversed",
    "iter", "next", "map", "filter",
    # Aggregation and logic
    "sum", "min", "max", "abs", "round",
    "any", "all",
    # Type checking
    "isinstance", "issubclass", "type",
    # Read-only attribute access
    "getattr", "hasattr",
    # Project API client class
    "OktaAPIClient",
    # Date/time classes (pre-injected)
    "timedelta", "datetime", "timezone",
}
# Attribute names that generated code may call as methods (`obj.name(...)`)
# without being screened for dangerous name fragments.
ALLOWED_PYTHON_METHODS: Set[str] = {
    # json
    "loads", "dumps", "load", "dump", "JSONDecodeError",
    # dict / list / set operations
    "items", "keys", "values", "get", "setdefault", "update", "copy",
    "append", "extend", "insert", "add", "remove", "pop", "clear",
    "index", "count", "sort", "reverse",
    # str operations
    "join", "split", "strip", "lstrip", "rstrip", "upper", "lower",
    "startswith", "endswith", "replace", "format", "encode", "decode",
    "find", "rfind", "isdigit", "isalpha", "isalnum",
    # OktaAPIClient
    "make_request",
    # Progress tracking
    "start_entity_progress",
    "update_entity_progress",
    "complete_entity_progress",
    # datetime
    "now", "utcnow", "strftime", "strptime", "isoformat", "timestamp",
    # asyncio
    "gather", "sleep", "create_task",
}
# HTTP verbs permitted against the Okta API. Access is read-only, so GET is
# the single entry; values are upper-case.
ALLOWED_HTTP_METHODS: Set[str] = {"GET"}
# ------------------------------------------------------------------------
# Security Validator Class
# ------------------------------------------------------------------------
class CodeSecurityValidator:
    """
    AST-based security validator for generated Python code.

    Validates:
    1. Blocked patterns (regex-based, searched in the raw source text)
    2. Import statements (blocklisted roots are rejected, and so is any
       module whose root is not in the allowlist)
    3. Function calls (bare names outside the builtin allowlist that contain
       a dangerous fragment such as 'exec' or 'open')
    4. Method calls (same rule, checked against the method allowlist)

    NOTE(review): the call checks look only at the name written at the call
    site. They do not follow aliases (e.g. binding a builtin to another name
    and calling that), so this validator is one layer of defense, not a
    sandbox.
    """

    # Substrings that make a non-allowlisted call name suspicious. Matching is
    # case-insensitive and by containment, so 'evaluate_users' is flagged too.
    _DANGEROUS_NAME_FRAGMENTS = ('exec', 'eval', 'compile', 'open', 'input', 'system')

    def __init__(self):
        """Compile the blocked patterns once and bind the allowlists."""
        self.blocked_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in BLOCKED_PATTERNS]
        self.allowed_modules = ALLOWED_MODULES
        self.allowed_builtins = ALLOWED_BUILTINS
        self.allowed_methods = ALLOWED_PYTHON_METHODS

    @classmethod
    def _has_dangerous_fragment(cls, name: str) -> bool:
        """Return True when *name* contains any dangerous fragment (case-insensitive)."""
        lowered = name.lower()
        return any(fragment in lowered for fragment in cls._DANGEROUS_NAME_FRAGMENTS)

    def _import_violation(self, module_name: str, level: int = 0) -> Optional[str]:
        """
        Return a violation message for an import, or None when it is allowed.

        Args:
            module_name: Dotted module path as written in the import
                ('' for ``from . import x``)
            level: Number of leading dots of a relative import (0 = absolute)
        """
        root = module_name.split('.')[0]
        if root in BLOCKED_MODULES:
            return f"Blocked module import: {module_name}"
        # Enforce the allowlist: modules such as pathlib or io are not on the
        # blocklist and trip no regex, yet give file access. Relative imports
        # resolve against a package rather than the allowlist, so they are
        # rejected as well.
        if level or root not in self.allowed_modules:
            return f"Module import not in allowlist: {'.' * level}{module_name}"
        return None

    def validate_python_code(self, code: str) -> SecurityValidationResult:
        """
        Validate Python code for security compliance.

        The code is parsed with ``ast``; the raw text is then scanned with the
        blocked regex patterns, and the tree is walked to check imports and
        call names.

        Args:
            code: Python code string to validate

        Returns:
            SecurityValidationResult with validation status and details.
            Errors are never raised: a syntax error is reported as a HIGH
            violation, and any other exception during validation is reported
            as a violation with at least MEDIUM risk.
        """
        violations: List[str] = []
        blocked_patterns: List[str] = []
        risk_level = 'LOW'

        try:
            # Parse first so that unparseable input is reported as a syntax error.
            tree = ast.parse(code)

            # Regex scan over the raw text (comments and strings included).
            for pattern in self.blocked_patterns:
                if pattern.search(code):
                    violations.append(f"Blocked pattern detected: {pattern.pattern}")
                    blocked_patterns.append(pattern.pattern)
                    risk_level = 'HIGH'

            for node in ast.walk(tree):
                if isinstance(node, ast.Import):
                    for alias in node.names:
                        problem = self._import_violation(alias.name)
                        if problem is not None:
                            violations.append(problem)
                            risk_level = 'HIGH'

                elif isinstance(node, ast.ImportFrom):
                    problem = self._import_violation(node.module or '', node.level)
                    if problem is not None:
                        violations.append(problem)
                        risk_level = 'HIGH'

                elif isinstance(node, ast.Call):
                    func = node.func
                    if isinstance(func, ast.Name):
                        name, allowed, label = func.id, self.allowed_builtins, "function"
                    elif isinstance(func, ast.Attribute):
                        name, allowed, label = func.attr, self.allowed_methods, "method"
                    else:
                        # Calls through subscripts, lambdas, etc. carry no name to check.
                        continue
                    # User-defined helpers (fetch_data, process_users, ...) are
                    # allowed; only names containing a dangerous fragment are not.
                    if name not in allowed and self._has_dangerous_fragment(name):
                        violations.append(f"Unauthorized {label} call: {name}")
                        risk_level = 'HIGH'

        except SyntaxError as e:
            violations.append(f"Syntax error in code: {e}")
            risk_level = 'HIGH'
        except Exception as e:
            violations.append(f"Code validation error: {e}")
            # Raise to MEDIUM, but never downgrade a HIGH finding that was
            # recorded before the failure.
            if risk_level == 'LOW':
                risk_level = 'MEDIUM'

        is_valid = not violations
        return SecurityValidationResult(is_valid, violations, blocked_patterns, risk_level)

    def validate_http_method(self, method: str) -> SecurityValidationResult:
        """
        Validate HTTP method is allowed (GET only for read-only operations).

        Args:
            method: HTTP verb, compared case-insensitively

        Returns:
            A LOW-risk valid result for an allowed verb; otherwise a HIGH-risk
            invalid result. Non-string input (e.g. None) is rejected rather
            than raising.
        """
        if isinstance(method, str) and method.upper() in ALLOWED_HTTP_METHODS:
            return SecurityValidationResult(
                is_valid=True,
                violations=[],
                blocked_patterns=[],
                risk_level='LOW'
            )
        return SecurityValidationResult(
            is_valid=False,
            violations=[f"HTTP method '{method}' not allowed. Only GET is permitted."],
            blocked_patterns=[],
            risk_level='HIGH'
        )
# ------------------------------------------------------------------------
# Global Validator Instance and Public Functions
# ------------------------------------------------------------------------
# Shared module-level validator instance used by the public wrapper functions
# in this module. Building it once at import time means the blocked regex
# patterns are compiled a single time rather than on every call.
_security_validator = CodeSecurityValidator()
def validate_generated_code(code: str) -> SecurityValidationResult:
    """
    Run the shared module-level validator over generated Python source.

    Args:
        code: Python source text to check

    Returns:
        SecurityValidationResult reporting whether the code passed and, if
        not, the violations that were found
    """
    outcome = _security_validator.validate_python_code(code)
    return outcome
def validate_http_method(method: str) -> SecurityValidationResult:
    """
    Check an HTTP verb using the shared module-level validator.

    Args:
        method: HTTP verb to check

    Returns:
        SecurityValidationResult produced by the validator's
        validate_http_method
    """
    check = _security_validator.validate_http_method
    return check(method)
def is_code_safe(code: str) -> bool:
    """
    Boolean convenience wrapper kept for backward compatibility.

    Prefer validate_generated_code() when the individual violations are
    needed.

    Args:
        code: Python source text to check

    Returns:
        True if the code passes security validation
    """
    return validate_generated_code(code).is_valid