# ai_script_analyzer.py (22.5 kB)
"""
AI Script Analyzer
Parses Python scripts generated by AI coding assistants using AST to extract:
- Import statements and their usage
- Class instantiations and method calls
- Function calls with parameters
- Attribute access patterns
- Variable type tracking
"""
import ast
import logging
from pathlib import Path
from typing import Dict, List, Set, Any, Optional, Tuple
from dataclasses import dataclass, field
# Module-level logger; AIScriptAnalyzer.analyze_script reports analysis failures through it.
logger = logging.getLogger(__name__)
@dataclass
class ImportInfo:
    """Information about an import statement.

    One instance is recorded per imported name, so ``from m import a, b``
    produces two entries.
    """
    module: str  # `import x` -> "x"; `from m import n` -> "m" ("" for `from . import n`)
    name: str  # the imported name (same as ``module`` for a plain `import x`)
    alias: Optional[str] = None  # the `as` alias, or None when none was given
    is_from_import: bool = False  # True for `from ... import ...` statements
    line_number: int = 0  # line of the import statement (0 if unavailable)
@dataclass
class MethodCall:
    """Information about a method call of the form ``obj.method(...)``."""
    object_name: str  # receiver as written, possibly dotted (e.g. "self.client")
    method_name: str  # the attribute being called
    args: List[str]  # string representations of the positional arguments
    kwargs: Dict[str, str]  # keyword name -> string representation (``**mapping`` unpacking omitted)
    line_number: int  # line of the call expression
    object_type: Optional[str] = None  # Inferred class type; may be a placeholder such as "method_result"
@dataclass
class AttributeAccess:
    """Information about an attribute access of the form ``obj.attr``."""
    object_name: str  # object as written, possibly dotted (e.g. "result.data")
    attribute_name: str  # the attribute being read
    line_number: int  # line of the attribute expression
    object_type: Optional[str] = None  # Inferred class type; None when unknown
@dataclass
class FunctionCall:
    """Information about a call to a bare name, e.g. ``func(...)``."""
    function_name: str  # the name as written at the call site
    args: List[str]  # string representations of the positional arguments
    kwargs: Dict[str, str]  # keyword name -> string representation (``**mapping`` unpacking omitted)
    line_number: int  # line of the call expression
    full_name: Optional[str] = None  # Module.function_name resolved via imports; None if not imported
@dataclass
class ClassInstantiation:
    """Information about a class instantiation, e.g. ``var = ClassName(...)``."""
    variable_name: str  # assigned variable, or a synthetic "<classname_instance>" when not assigned
    class_name: str  # class name as written at the call site
    args: List[str]  # string representations of the positional arguments
    kwargs: Dict[str, str]  # keyword name -> string representation (``**mapping`` unpacking omitted)
    line_number: int  # line of the instantiation
    full_class_name: Optional[str] = None  # Module.ClassName resolved via imports; None if not imported
@dataclass
class AnalysisResult:
    """Complete analysis results for a Python script.

    When analysis fails, every list is empty except ``errors``.
    """
    file_path: str  # path that was analyzed, as passed by the caller
    imports: List[ImportInfo] = field(default_factory=list)
    class_instantiations: List[ClassInstantiation] = field(default_factory=list)
    method_calls: List[MethodCall] = field(default_factory=list)
    attribute_accesses: List[AttributeAccess] = field(default_factory=list)
    function_calls: List[FunctionCall] = field(default_factory=list)
    # variable_name -> class_type; one entry per name, not scope-aware
    variable_types: Dict[str, str] = field(default_factory=dict)
    errors: List[str] = field(default_factory=list)  # failure messages from the analysis
class AIScriptAnalyzer:
    """Analyzes AI-generated Python scripts for validation against knowledge graph.

    An instance can be reused for several scripts: all per-script state is
    reset at the start of each ``analyze_script`` call. Type tracking is not
    scope-aware: ``variable_types`` keeps a single type per variable name
    (whichever was recorded last while walking the tree).
    """

    # Substrings of a resolved (module-qualified) name that suggest the callee
    # is a class even when its local name is not capitalized.
    _CLASS_NAME_PATTERNS = (
        'Model', 'Provider', 'Client', 'Agent', 'Manager', 'Handler',
        'Builder', 'Factory', 'Service', 'Controller', 'Processor',
    )

    def __init__(self):
        self.import_map: Dict[str, str] = {}  # alias -> actual_module_name
        self.variable_types: Dict[str, str] = {}  # variable_name -> class_type
        # var_name -> (start_line, end_line, type) for names bound by `async with`
        self.context_manager_vars: Dict[str, Tuple[int, int, str]] = {}
        # id() of Call nodes already handled through a parent node (assignment
        # or with-statement), so the generic Call handler skips them.
        self.processed_calls: Set[int] = set()
        # id() of Attribute nodes that are the callee of a recorded method call,
        # so they are not reported a second time as plain attribute accesses.
        self.method_call_attributes: Set[int] = set()

    def _reset_state(self):
        """Clear all per-script state before analyzing a new script."""
        self.import_map.clear()
        self.variable_types.clear()
        self.context_manager_vars.clear()
        self.processed_calls = set()
        self.method_call_attributes = set()

    def analyze_script(self, script_path: str) -> AnalysisResult:
        """Analyze a Python script and extract all relevant information.

        Args:
            script_path: Path to the Python source file (read as UTF-8).

        Returns:
            The populated AnalysisResult. Any failure (unreadable file, syntax
            error, ...) is logged and returned as an otherwise empty result
            whose ``errors`` list holds the message; nothing is raised.
        """
        try:
            with open(script_path, 'r', encoding='utf-8') as f:
                content = f.read()
            tree = ast.parse(content)
            result = AnalysisResult(file_path=script_path)
            self._reset_state()

            # First pass: collect imports so later passes can resolve aliases.
            for node in ast.walk(tree):
                if isinstance(node, (ast.Import, ast.ImportFrom)):
                    self._extract_imports(node, result)

            # Second pass: analyze usage patterns. The skip-sets rely on a
            # parent (Assign/With/Call) being visited before its children,
            # which holds for CPython's breadth-first ast.walk.
            for node in ast.walk(tree):
                self._analyze_node(node, result)

            # Set inferred types on method calls and attribute accesses
            self._infer_object_types(result)
            result.variable_types = self.variable_types.copy()
            return result
        except Exception as e:
            # Top-level boundary: report the failure in the result instead of raising.
            error_msg = f"Failed to analyze script {script_path}: {str(e)}"
            logger.error(error_msg)
            result = AnalysisResult(file_path=script_path)
            result.errors.append(error_msg)
            return result

    def _extract_imports(self, node: ast.AST, result: AnalysisResult):
        """Record an import statement and map each bound alias to its full name.

        ``import a.b as x`` maps ``x -> a.b``; ``from m import n as y`` maps
        ``y -> m.n``; ``from . import n`` (no module) maps ``n -> n``.
        """
        line_num = getattr(node, 'lineno', 0)
        if isinstance(node, ast.Import):
            for alias in node.names:
                import_name = alias.name
                alias_name = alias.asname or import_name
                result.imports.append(ImportInfo(
                    module=import_name,
                    name=import_name,
                    alias=alias.asname,
                    is_from_import=False,
                    line_number=line_num
                ))
                self.import_map[alias_name] = import_name
        elif isinstance(node, ast.ImportFrom):
            module = node.module or ""
            for alias in node.names:
                import_name = alias.name
                alias_name = alias.asname or import_name
                result.imports.append(ImportInfo(
                    module=module,
                    name=import_name,
                    alias=alias.asname,
                    is_from_import=True,
                    line_number=line_num
                ))
                # Map alias to full module.name
                self.import_map[alias_name] = f"{module}.{import_name}" if module else import_name

    def _analyze_node(self, node: ast.AST, result: AnalysisResult):
        """Dispatch a single AST node to the matching extractor."""
        if isinstance(node, ast.Assign):
            self._handle_assign(node, result)
        elif isinstance(node, ast.AsyncWith):
            self._handle_async_with(node, result)
        elif isinstance(node, ast.With):
            self._handle_with(node, result)
        elif isinstance(node, ast.Call):
            self._handle_call(node, result)
        elif isinstance(node, ast.Attribute):
            # Attribute access outside a call; callees of recorded method
            # calls were already reported as method calls.
            if id(node) not in self.method_call_attributes:
                self._extract_attribute_access(node, result)

    def _handle_assign(self, node: ast.Assign, result: AnalysisResult):
        """Handle ``name = <call>`` assignments and track the variable's type."""
        if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
            return
        call = node.value
        if not isinstance(call, ast.Call):
            return
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
            if self._is_likely_class_instantiation(func_name, self._resolve_full_name(func_name)):
                self._extract_class_instantiation(node, result)
            else:
                # A plain function call: record it as one instead of typing
                # the variable as an instance of a "class" named after it.
                self._extract_function_call(call, result)
            # Mark this call as processed to avoid duplicate processing
            self.processed_calls.add(id(call))
        elif isinstance(call.func, ast.Attribute):
            # Method call - track the variable assignment for type inference
            self._track_method_result_assignment(call, node.targets[0].id)
            self._record_method_call(call, result)

    def _handle_call(self, node: ast.Call, result: AnalysisResult):
        """Handle a call not already processed via an assignment or with-statement."""
        if id(node) in self.processed_calls:
            return
        if isinstance(node.func, ast.Attribute):
            self._record_method_call(node, result)
        elif isinstance(node.func, ast.Name):
            func_name = node.func.id
            full_name = self._resolve_full_name(func_name)
            if self._is_likely_class_instantiation(func_name, full_name):
                self._extract_nested_class_instantiation(node, result)
            else:
                self._extract_function_call(node, result)

    def _record_method_call(self, call: ast.Call, result: AnalysisResult):
        """Record a method call once and mark its call node and callee as consumed."""
        self._extract_method_call(call, result)
        self.processed_calls.add(id(call))
        self.method_call_attributes.add(id(call.func))

    def _call_arguments(self, call: ast.Call) -> Tuple[List[str], Dict[str, str]]:
        """Return string forms of a call's positional and keyword arguments.

        ``**mapping`` unpacking (a keyword without a name) is omitted.
        """
        args = [self._get_arg_representation(arg) for arg in call.args]
        kwargs = {
            kw.arg: self._get_arg_representation(kw.value)
            for kw in call.keywords if kw.arg
        }
        return args, kwargs

    def _extract_class_instantiation(self, node: ast.Assign, result: AnalysisResult):
        """Record ``var = ClassName(...)`` and remember ``var``'s class for later lookups."""
        target = node.targets[0]
        call = node.value
        if not (isinstance(target, ast.Name) and isinstance(call, ast.Call)):
            return
        class_name = self._get_name_from_call(call.func)
        if not class_name:
            return
        args, kwargs = self._call_arguments(call)
        # Resolve full class name using import map (None if not imported)
        full_class_name = self._resolve_full_name(class_name)
        result.class_instantiations.append(ClassInstantiation(
            variable_name=target.id,
            class_name=class_name,
            args=args,
            kwargs=kwargs,
            line_number=getattr(node, 'lineno', 0),
            full_class_name=full_class_name
        ))
        # Track variable type for later method call analysis
        self.variable_types[target.id] = full_class_name or class_name

    def _extract_method_call(self, node: ast.Call, result: AnalysisResult):
        """Record ``obj.method(...)``; skipped when the receiver has no simple name."""
        if not isinstance(node.func, ast.Attribute):
            return
        obj_name = self._get_name_from_node(node.func.value)
        method_name = node.func.attr
        if not (obj_name and method_name):
            return
        args, kwargs = self._call_arguments(node)
        result.method_calls.append(MethodCall(
            object_name=obj_name,
            method_name=method_name,
            args=args,
            kwargs=kwargs,
            line_number=getattr(node, 'lineno', 0),
            object_type=self.variable_types.get(obj_name)
        ))

    def _extract_function_call(self, node: ast.Call, result: AnalysisResult):
        """Record a call to a bare name, resolving its full name through imports."""
        if not isinstance(node.func, ast.Name):
            return
        func_name = node.func.id
        args, kwargs = self._call_arguments(node)
        result.function_calls.append(FunctionCall(
            function_name=func_name,
            args=args,
            kwargs=kwargs,
            line_number=getattr(node, 'lineno', 0),
            full_name=self._resolve_full_name(func_name)
        ))

    def _extract_attribute_access(self, node: ast.Attribute, result: AnalysisResult):
        """Record ``obj.attr``; skipped when the object has no simple name."""
        obj_name = self._get_name_from_node(node.value)
        attr_name = node.attr
        if obj_name and attr_name:
            result.attribute_accesses.append(AttributeAccess(
                object_name=obj_name,
                attribute_name=attr_name,
                line_number=getattr(node, 'lineno', 0),
                object_type=self.variable_types.get(obj_name)
            ))

    def _infer_object_types(self, result: AnalysisResult):
        """Fill in object types for recorded method calls and attribute accesses.

        A type scoped to an enclosing ``async with`` block takes precedence,
        because inside that block the variable holds the context manager's
        result. Otherwise, an unset type falls back to the type recorded for
        the variable name.
        """
        for usage in (*result.method_calls, *result.attribute_accesses):
            scoped_type = self._get_context_aware_type(usage.object_name, usage.line_number)
            if scoped_type:
                usage.object_type = scoped_type
            elif not usage.object_type:
                usage.object_type = self.variable_types.get(usage.object_name)

    def _get_context_aware_type(self, var_name: str, line_number: int) -> Optional[str]:
        """Return the type bound to ``var_name`` by an ``async with`` covering ``line_number``."""
        scope = self.context_manager_vars.get(var_name)
        if scope is None:
            return None
        start_line, end_line, var_type = scope
        return var_type if start_line <= line_number <= end_line else None

    def _get_name_from_call(self, node: ast.AST) -> Optional[str]:
        """Get the name from a call node (for class instantiation)."""
        return self._get_name_from_node(node)

    def _get_name_from_node(self, node: ast.AST) -> Optional[str]:
        """Return the dotted name of a Name/Attribute chain, or None for anything else."""
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            value_name = self._get_name_from_node(node.value)
            if value_name:
                return f"{value_name}.{node.attr}"
        return None

    def _get_arg_representation(self, node: ast.AST) -> str:
        """Return a short string representation of an argument expression."""
        if isinstance(node, ast.Constant):
            return repr(node.value)
        elif isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            return self._get_name_from_node(node) or "<?>"
        elif isinstance(node, ast.Call):
            func_name = self._get_name_from_call(node.func)
            return f"{func_name}(...)" if func_name else "call(...)"
        else:
            return f"<{type(node).__name__}>"

    def _is_likely_class_instantiation(self, func_name: str, full_name: Optional[str]) -> bool:
        """Heuristically decide whether a call to ``func_name`` instantiates a class.

        True for capitalized names, or when the import-resolved name contains
        one of ``_CLASS_NAME_PATTERNS``.
        """
        if func_name and func_name[0].isupper():
            return True
        if full_name:
            return any(pattern in full_name for pattern in self._CLASS_NAME_PATTERNS)
        return False

    def _extract_nested_class_instantiation(self, node: ast.Call, result: AnalysisResult):
        """Record a class instantiation that is not directly assigned (e.g. passed as an argument)."""
        if not isinstance(node.func, ast.Name):
            return
        class_name = node.func.id
        args, kwargs = self._call_arguments(node)
        result.class_instantiations.append(ClassInstantiation(
            # Synthetic variable name since this isn't assigned to a variable
            variable_name=f"<{class_name.lower()}_instance>",
            class_name=class_name,
            args=args,
            kwargs=kwargs,
            line_number=getattr(node, 'lineno', 0),
            full_class_name=self._resolve_full_name(class_name)
        ))

    def _track_method_result_assignment(self, call_node: ast.Call, var_name: str):
        """Mark ``var_name`` as holding the result of a method call.

        Return types are not looked up, so a generic placeholder type is used.
        """
        if isinstance(call_node.func, ast.Attribute):
            self.variable_types[var_name] = "method_result"

    def _record_with_item(self, item: ast.withitem, result: AnalysisResult) -> Optional[str]:
        """Record the method call of a ``with`` item bound to a plain name.

        Returns the bound variable name, or None (recording nothing) when the
        item has no plain ``as name`` target or its context expression is not
        a method call.
        """
        if not isinstance(item.optional_vars, ast.Name):
            return None
        call = item.context_expr
        if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)):
            return None
        self._record_method_call(call, result)
        return item.optional_vars.id

    def _handle_async_with(self, node: ast.AsyncWith, result: AnalysisResult):
        """Handle ``async with`` items and track context manager variables.

        Each ``async with obj.method(...) as name`` item has its method call
        recorded. When ``obj``'s recorded type mentions pydantic_ai and the
        method is ``run_stream``, ``name`` is typed as
        ``pydantic_ai.StreamedRunResult`` for the lines of this block only.
        """
        for item in node.items:
            var_name = self._record_with_item(item, result)
            if var_name is None:
                continue
            call = item.context_expr
            obj_name = self._get_name_from_node(call.func.value)
            obj_type = self.variable_types.get(obj_name) if obj_name else None
            if obj_type and 'pydantic_ai' in obj_type and call.func.attr == 'run_stream':
                start_line = getattr(node, 'lineno', 0)
                # end_lineno may be missing or None; fall back to a rough estimate
                end_line = getattr(node, 'end_lineno', None) or start_line + 50
                # run_stream's actual return type, not a generic placeholder
                self.context_manager_vars[var_name] = (
                    start_line, end_line, "pydantic_ai.StreamedRunResult"
                )

    def _handle_with(self, node: ast.With, result: AnalysisResult):
        """Handle ``with`` items: record method calls and give bound names a placeholder type."""
        for item in node.items:
            var_name = self._record_with_item(item, result)
            if var_name is not None:
                self.variable_types[var_name] = "context_manager_result"

    def _resolve_full_name(self, name: str) -> Optional[str]:
        """Resolve a name to its full module.name using import map; None if unknown."""
        # Check if it's a direct import mapping
        if name in self.import_map:
            return self.import_map[name]
        # Check if it's a dotted name with first part in import map
        parts = name.split('.')
        if len(parts) > 1 and parts[0] in self.import_map:
            base_module = self.import_map[parts[0]]
            return f"{base_module}.{'.'.join(parts[1:])}"
        return None
def analyze_ai_script(script_path: str) -> AnalysisResult:
    """Analyze one AI-generated script using a freshly created analyzer.

    Thin wrapper around ``AIScriptAnalyzer().analyze_script``; the returned
    result and its error reporting are exactly that method's.
    """
    return AIScriptAnalyzer().analyze_script(script_path)
if __name__ == "__main__":
    # Example usage: python ai_script_analyzer.py <script_path>
    import sys

    if len(sys.argv) != 2:
        # Usage errors go to stderr so they don't mix with normal output.
        print("Usage: python ai_script_analyzer.py <script_path>", file=sys.stderr)
        sys.exit(1)

    script_path = sys.argv[1]
    result = analyze_ai_script(script_path)

    print(f"Analysis Results for: {result.file_path}")
    print(f"Imports: {len(result.imports)}")
    print(f"Class Instantiations: {len(result.class_instantiations)}")
    print(f"Method Calls: {len(result.method_calls)}")
    print(f"Function Calls: {len(result.function_calls)}")
    print(f"Attribute Accesses: {len(result.attribute_accesses)}")

    if result.errors:
        # Report failures on stderr and signal them through the exit status.
        print(f"Errors: {result.errors}", file=sys.stderr)
        sys.exit(1)