# code_validator.py (11.4 kB)
"""
Code validation utilities for AI-generated fixes
Provides basic syntax checking and security validation
"""
import ast
import re
import logging
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
# Logger for this module; named after the module so callers can configure it.
logger = logging.getLogger(name=__name__)
@dataclass()
class ValidationResult:
    """Outcome of validating a single piece of code.

    Attributes are listed in constructor order, so instances may be built
    positionally or by keyword.
    """

    valid: bool  # overall pass/fail flag
    errors: List[str]  # problems that caused (or would cause) a failure
    warnings: List[str]  # advisory findings
    language: str  # language label the code was validated as
class CodeValidator:
"""Validates AI-generated code for basic syntax and security"""
# Dangerous patterns to check for
DANGEROUS_PATTERNS = {
'python': [
r'exec\s*\(',
r'eval\s*\(',
r'__import__\s*\(',
r'subprocess\.',
r'os\.system',
r'os\.popen',
r'open\s*\([^)]*["\']w["\']', # File writes
r'rm\s+-rf',
r'DELETE\s+FROM', # SQL injection
r'DROP\s+TABLE',
],
'javascript': [
r'eval\s*\(',
r'Function\s*\(',
r'document\.write',
r'innerHTML\s*=',
r'outerHTML\s*=',
r'setTimeout\s*\([^)]*string',
r'setInterval\s*\([^)]*string',
],
'typescript': [
r'eval\s*\(',
r'Function\s*\(',
r'document\.write',
r'innerHTML\s*=',
r'outerHTML\s*=',
],
'java': [
r'Runtime\.getRuntime\(\)\.exec',
r'ProcessBuilder',
r'System\.exit',
r'Class\.forName',
]
}
def __init__(self):
self.max_file_size = 50000 # 50KB max
self.max_lines = 1000
async def validate_code(self, code: str, language: str = "python",
filename: Optional[str] = None) -> ValidationResult:
"""
Validate code for syntax and basic security
Args:
code: The code to validate
language: Programming language (python, javascript, typescript, java)
filename: Optional filename for context
Returns:
ValidationResult with validation status and any errors/warnings
"""
errors = []
warnings = []
try:
# Basic size checks
if len(code) > self.max_file_size:
errors.append(f"Code too large: {len(code)} bytes (max: {self.max_file_size})")
lines = code.split('\n')
if len(lines) > self.max_lines:
errors.append(f"Too many lines: {len(lines)} (max: {self.max_lines})")
# Language-specific validation
if language.lower() == "python":
result = await self._validate_python(code)
errors.extend(result.errors)
warnings.extend(result.warnings)
elif language.lower() in ["javascript", "js"]:
result = await self._validate_javascript(code)
errors.extend(result.errors)
warnings.extend(result.warnings)
elif language.lower() in ["typescript", "ts"]:
result = await self._validate_typescript(code)
errors.extend(result.errors)
warnings.extend(result.warnings)
else:
# Generic validation for other languages
result = await self._validate_generic(code, language)
errors.extend(result.errors)
warnings.extend(result.warnings)
# Security pattern checking
security_issues = self._check_security_patterns(code, language)
if security_issues:
errors.extend([f"Security issue: {issue}" for issue in security_issues])
# Check for common issues
common_issues = self._check_common_issues(code)
warnings.extend(common_issues)
return ValidationResult(
valid=len(errors) == 0,
errors=errors,
warnings=warnings,
language=language
)
except Exception as e:
logger.error(f"Code validation failed: {e}")
return ValidationResult(
valid=False,
errors=[f"Validation error: {str(e)}"],
warnings=[],
language=language
)
async def _validate_python(self, code: str) -> ValidationResult:
"""Validate Python code using AST"""
errors = []
warnings = []
try:
# Parse with AST
tree = ast.parse(code)
# Check for dangerous constructs
for node in ast.walk(tree):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
if node.func.id in ['exec', 'eval']:
errors.append(f"Dangerous function call: {node.func.id}")
elif node.func.id == 'open':
# Check if it's a write operation
if len(node.args) > 1:
if isinstance(node.args[1], ast.Constant):
if 'w' in str(node.args[1].value):
warnings.append("File write operation detected")
elif isinstance(node, ast.Import):
for alias in node.names:
if alias.name in ['subprocess', 'os']:
warnings.append(f"Potentially dangerous import: {alias.name}")
elif isinstance(node, ast.ImportFrom):
if node.module in ['subprocess', 'os']:
warnings.append(f"Potentially dangerous import from: {node.module}")
except SyntaxError as e:
errors.append(f"Python syntax error: {e.msg} at line {e.lineno}")
except Exception as e:
errors.append(f"Python validation error: {str(e)}")
return ValidationResult(
valid=len(errors) == 0,
errors=errors,
warnings=warnings,
language="python"
)
async def _validate_javascript(self, code: str) -> ValidationResult:
"""Basic JavaScript validation using regex patterns"""
errors = []
warnings = []
# Basic syntax checks
if not self._check_balanced_braces(code):
errors.append("Unbalanced braces in JavaScript code")
# Check for common syntax issues
if re.search(r'function\s+\w+\s*\([^)]*\)\s*{[^}]*$', code, re.MULTILINE):
errors.append("Unclosed function definition")
# Check for dangerous patterns
if re.search(r'eval\s*\(', code):
errors.append("Use of eval() is dangerous")
if re.search(r'innerHTML\s*=', code):
warnings.append("Direct innerHTML assignment can be unsafe")
return ValidationResult(
valid=len(errors) == 0,
errors=errors,
warnings=warnings,
language="javascript"
)
async def _validate_typescript(self, code: str) -> ValidationResult:
"""Basic TypeScript validation (similar to JavaScript)"""
# For now, use JavaScript validation as base
result = await self._validate_javascript(code)
result.language = "typescript"
# Additional TypeScript-specific checks could be added here
return result
async def _validate_generic(self, code: str, language: str) -> ValidationResult:
"""Generic validation for unsupported languages"""
errors = []
warnings = []
# Basic checks
if not code.strip():
errors.append("Empty code provided")
# Check for balanced braces (works for most C-style languages)
if language.lower() in ['java', 'c', 'cpp', 'c++', 'csharp', 'c#']:
if not self._check_balanced_braces(code):
errors.append(f"Unbalanced braces in {language} code")
return ValidationResult(
valid=len(errors) == 0,
errors=errors,
warnings=warnings,
language=language
)
def _check_security_patterns(self, code: str, language: str) -> List[str]:
"""Check for dangerous security patterns"""
issues = []
patterns = self.DANGEROUS_PATTERNS.get(language.lower(), [])
for pattern in patterns:
if re.search(pattern, code, re.IGNORECASE):
issues.append(f"Dangerous pattern detected: {pattern}")
return issues
def _check_common_issues(self, code: str) -> List[str]:
"""Check for common coding issues"""
warnings = []
# Check for hardcoded credentials
if re.search(r'password\s*=\s*["\'][^"\']+["\']', code, re.IGNORECASE):
warnings.append("Possible hardcoded password detected")
if re.search(r'api[_-]?key\s*=\s*["\'][^"\']+["\']', code, re.IGNORECASE):
warnings.append("Possible hardcoded API key detected")
# Check for TODO/FIXME comments
if re.search(r'(TODO|FIXME|HACK)', code, re.IGNORECASE):
warnings.append("Code contains TODO/FIXME comments")
return warnings
def _check_balanced_braces(self, code: str) -> bool:
"""Check if braces are balanced"""
stack = []
pairs = {'(': ')', '[': ']', '{': '}'}
for char in code:
if char in pairs:
stack.append(char)
elif char in pairs.values():
if not stack:
return False
if pairs[stack.pop()] != char:
return False
return len(stack) == 0
# Shared module-level validator; validate_code_changes() delegates to it.
validator: CodeValidator = CodeValidator()
async def validate_code_changes(changes: List[Dict[str, Any]]) -> Dict[str, ValidationResult]:
    """
    Validate a batch of proposed file changes.

    Args:
        changes: Dicts with 'path' and 'content' keys and an optional
            'language'. When 'language' is missing or empty it is inferred
            from the path's file extension.

    Returns:
        Dict mapping each change's path ('unknown' when absent) to its
        ValidationResult. Changes are validated one after another; a later
        change with the same path overwrites the earlier result.
    """
    results: Dict[str, ValidationResult] = {}
    for change in changes:
        path = change.get('path', 'unknown')
        language = change.get('language') or _infer_language_from_path(path)
        results[path] = await validator.validate_code(change.get('content', ''), language, path)
    return results
# Lowercase file extension (without the dot) -> language name passed to the
# validator. Extensions not listed here map to 'unknown'.
_EXTENSION_LANGUAGES = {
    'py': 'python',
    'pyi': 'python',
    'js': 'javascript',
    'mjs': 'javascript',
    'cjs': 'javascript',
    'ts': 'typescript',
    'mts': 'typescript',
    'cts': 'typescript',
    'java': 'java',
    'c': 'c',
    'h': 'c',
    'cpp': 'cpp',
    'cc': 'cpp',
    'cxx': 'cpp',
    'hpp': 'cpp',
    'cs': 'csharp',
    'php': 'php',
    'rb': 'ruby',
    'go': 'go',
    'rs': 'rust',
}


def _infer_language_from_path(path: str) -> str:
    """Infer the programming language from a file path's extension.

    Both '/' and '\\' are treated as directory separators, and only the final
    path component is inspected. The extension is matched case-insensitively.

    Returns:
        The language name, or 'unknown' when the file has no recognised
        extension.
    """
    # Dots in directory names (e.g. "v1.2/Makefile") are not extensions.
    filename = path.replace('\\', '/').rsplit('/', 1)[-1]
    if '.' not in filename:
        return 'unknown'
    extension = filename.rsplit('.', 1)[-1].lower()
    return _EXTENSION_LANGUAGES.get(extension, 'unknown')