Skip to main content
Glama

CodeGraphContext

graph_builder.py58.2 kB
# src/codegraphcontext/tools/graph_builder.py import asyncio import ast import logging import os from pathlib import Path from typing import Any, Coroutine, Dict, Optional, Tuple from datetime import datetime from ..core.database import DatabaseManager from ..core.jobs import JobManager, JobStatus from ..utils.debug_log import debug_log logger = logging.getLogger(__name__) # This is for developers and testers only. It enables detailed debug logging to a file. # Set to 1 to enable, 0 to disable. debug_mode = 0 class CyclomaticComplexityVisitor(ast.NodeVisitor): """Calculates cyclomatic complexity for a given AST node.""" def __init__(self): self.complexity = 1 def visit_If(self, node): self.complexity += 1 self.generic_visit(node) def visit_For(self, node): self.complexity += 1 self.generic_visit(node) def visit_While(self, node): self.complexity += 1 self.generic_visit(node) def visit_With(self, node): self.complexity += len(node.items) self.generic_visit(node) def visit_AsyncFor(self, node): self.complexity += 1 self.generic_visit(node) def visit_AsyncWith(self, node): self.complexity += len(node.items) self.generic_visit(node) def visit_ExceptHandler(self, node): self.complexity += 1 self.generic_visit(node) def visit_BoolOp(self, node): self.complexity += len(node.values) - 1 self.generic_visit(node) def visit_ListComp(self, node): self.complexity += len(node.generators) self.generic_visit(node) def visit_SetComp(self, node): self.complexity += len(node.generators) self.generic_visit(node) def visit_DictComp(self, node): self.complexity += len(node.generators) self.generic_visit(node) def visit_GeneratorExp(self, node): self.complexity += len(node.generators) self.generic_visit(node) def visit_IfExp(self, node): self.complexity += 1 self.generic_visit(node) def visit_match_case(self, node): self.complexity += 1 self.generic_visit(node) class CodeVisitor(ast.NodeVisitor): """ The final, definitive, stateful AST visitor. It correctly maintains class-level, local-level, and module-level symbol tables to resolve complex calls. """ def __init__(self, file_path: str, imports_map: dict, is_dependency: bool = False): self.file_path = file_path self.is_dependency = is_dependency self.imports_map = imports_map self.functions, self.classes, self.variables, self.imports, self.function_calls = [], [], [], [], [] self.context_stack, self.current_context, self.current_class = [], None, None # Stateful Symbol Tables self.local_symbol_table = {} self.class_symbol_table = {} self.module_symbol_table = {} def _push_context(self, name, node_type, line_number): self.context_stack.append({"name": name, "type": node_type, "line_number": line_number, "previous_context": self.current_context, "previous_class": self.current_class}) self.current_context = name if node_type == "class": self.current_class = name def _pop_context(self): if self.context_stack: prev = self.context_stack.pop() self.current_context, self.current_class = prev["previous_context"], prev["previous_class"] def get_return_type_from_ast(self, file_path, class_name, method_name): if not file_path or not Path(file_path).exists(): return None with open(file_path, 'r', encoding='utf-8') as source_file: try: tree = ast.parse(source_file.read()) except (SyntaxError, ValueError): return None for node in ast.walk(tree): if isinstance(node, ast.ClassDef) and node.name == class_name: for method_node in node.body: if isinstance(method_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and method_node.name == method_name: # Case 1: The method has an explicit return type hint if method_node.returns: # Unparse and strip quotes to handle forward references like "'PublicKey'" return ast.unparse(method_node.returns).strip("'\"") # Create a mini symbol table for the scope of this method local_assignments = {} for body_item in method_node.body: if (isinstance(body_item, ast.Assign) and isinstance(body_item.value, ast.Call) and isinstance(body_item.value.func, ast.Name) and isinstance(body_item.targets[0], ast.Name)): variable_name = body_item.targets[0].id class_name_assigned = body_item.value.func.id local_assignments[variable_name] = class_name_assigned # Now, check the return statements from the bottom up for body_item in reversed(method_node.body): if isinstance(body_item, ast.Return): # Case 2: It returns a direct instantiation (e.g., return MyClass()) if (isinstance(body_item.value, ast.Call) and isinstance(body_item.value.func, ast.Name)): return body_item.value.func.id # Case 3: It returns a variable that was assigned earlier (e.g., return my_var) if (isinstance(body_item.value, ast.Name) and body_item.value.id in local_assignments): return local_assignments[body_item.value.id] return None def _resolve_type_from_call(self, node: ast.Call): if not isinstance(node.func, ast.Attribute): return None # Case 1: Base of the call is a simple name, e.g., `var.method()` if isinstance(node.func.value, ast.Name): obj_name = node.func.value.id method_name = node.func.attr # Check local, then class, then module scopes obj_type = (self.local_symbol_table.get(obj_name) or self.class_symbol_table.get(obj_name) or self.module_symbol_table.get(obj_name)) if obj_type: paths = self.imports_map.get(obj_type, []) if paths: return self.get_return_type_from_ast(paths[0], obj_type, method_name) # Case 2: Base of the call is another call (a chain), e.g., `var.method1().method2()` elif isinstance(node.func.value, ast.Call): intermediate_type = self._resolve_type_from_call(node.func.value) if intermediate_type: method_name = node.func.attr paths = self.imports_map.get(intermediate_type, []) if paths: return self.get_return_type_from_ast(paths[0], intermediate_type, method_name) return None def visit_ClassDef(self, node): self.class_symbol_table = {} class_data = {"name": node.name, "line_number": node.lineno, "end_line": getattr(node, 'end_lineno', None), "bases": [ast.unparse(b) for b in node.bases], "source": ast.unparse(node), "context": self.current_context, "is_dependency": self.is_dependency, "docstring": ast.get_docstring(node), "decorators": [ast.unparse(d) for d in node.decorator_list]} self.classes.append(class_data) self._push_context(node.name, "class", node.lineno) # Pre-pass to populate class symbol table from __init__ or setUp for method_node in node.body: if isinstance(method_node, ast.FunctionDef) and method_node.name in ('__init__', 'setUp'): self._handle_constructor_assignments(method_node) # Visit all children of the class now self.generic_visit(node) self._pop_context() self.class_symbol_table = {} def _handle_constructor_assignments(self, constructor_node: ast.FunctionDef): """ Infers types for class attributes assigned from constructor arguments. This fixes the `self.job_manager = job_manager` case. """ # Get a map of argument names to their type hints (as strings) arg_types = { arg.arg: ast.unparse(arg.annotation) for arg in constructor_node.args.args if arg.annotation } # Scan the body of the constructor for assignments for body_node in ast.walk(constructor_node): if isinstance(body_node, ast.Assign): # Check for assignments like `self.attr = arg` if ( isinstance(body_node.targets[0], ast.Attribute) and isinstance(body_node.targets[0].value, ast.Name) and body_node.targets[0].value.id == "self" and isinstance(body_node.value, ast.Name) ): attr_name = body_node.targets[0].attr arg_name = body_node.value.id if arg_name in arg_types: # We found a match! Infer the type and add it to the symbol table. self.class_symbol_table[attr_name] = arg_types[arg_name] debug_log(f"Inferred type for self.{attr_name}: {arg_types[arg_name]}") def visit_FunctionDef(self, node): # The class pre-pass will handle setUp/__init__, so we reset the local table here if node.name not in ('__init__', 'setUp'): self.local_symbol_table = {} func_data = {"name": node.name, "line_number": node.lineno, "end_line": getattr(node, 'end_lineno', None), "args": [arg.arg for arg in node.args.args], "source": ast.unparse(node), "context": self.current_context, "class_context": self.current_class, "is_dependency": self.is_dependency, "docstring": ast.get_docstring(node), "decorators": [ast.unparse(d) for d in node.decorator_list], "source_code": ast.unparse(node)} # Add source_code here self.functions.append(func_data) self.functions.append(func_data) self._push_context(node.name, "function", node.lineno) # This will trigger visit_Assign and visit_Call for nodes inside the function self.generic_visit(node) self._pop_context() def visit_Assign(self, node): assigned_type = None # Manual check for nested calls to ensure they are processed if isinstance(node.value, ast.Call): # Now determine the type for the assignment target if isinstance(node.value.func, ast.Name): assigned_type = node.value.func.id elif isinstance(node.value.func, ast.Attribute): assigned_type = self._resolve_type_from_call(node.value) # If the main resolver fails, it's likely a class/static method call # like `addr = P2shAddress.from_script(...)`. We apply a heuristic: # assume the method returns an instance of its own class. if not assigned_type and isinstance(node.value.func.value, ast.Name): # `node.value.func.value.id` will be 'P2shAddress' in this case class_name = node.value.func.value.id # We can add a check to be safer: does this name correspond to an import? # This check makes the heuristic much more reliable. if class_name in self.imports_map: assigned_type = class_name # Handle assignments from a different variable `var = another_var` elif isinstance(node.value, ast.Name): assigned_type = (self.local_symbol_table.get(node.value.id) or self.class_symbol_table.get(node.value.id) or self.module_symbol_table.get(node.value.id)) if assigned_type and isinstance(assigned_type, str): assigned_type = assigned_type.strip("'\"") # Part 1: Populate symbol tables correctly if assigned_type: for target in node.targets: if isinstance(target, ast.Attribute) and hasattr(target.value, 'id') and target.value.id == 'self': self.class_symbol_table[target.attr] = assigned_type elif isinstance(target, ast.Name): if self.current_context is None and self.current_class is None: # This is a top-level assignment self.module_symbol_table[target.id] = assigned_type else: # This is a local assignment self.local_symbol_table[target.id] = assigned_type # Part 2: Collect variable data for the graph for target in node.targets: # The key change is here: check for both simple names AND 'self.attribute' if isinstance(target, ast.Name): var_name = target.id elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) and target.value.id == 'self': var_name = f"self.{target.attr}" else: continue # Currently skips other types of assignments like tuple unpacking var_data = { "name": var_name, "line_number": node.lineno, "value": ast.unparse(node.value) if hasattr(ast, "unparse") else "", "context": self.current_context, "class_context": self.current_class, "is_dependency": self.is_dependency, } self.variables.append(var_data) # Now call generic_visit to ensure nested nodes (like calls) are processed self.generic_visit(node) def visit_AsyncFunctionDef(self, node): """Visit async function definitions""" self.visit_FunctionDef(node) def visit_AnnAssign(self, node): """Visit annotated assignments (type hints)""" if isinstance(node.target, ast.Name): var_data = { "name": node.target.id, "line_number": node.lineno, "value": ( ast.unparse(node.value) if node.value and hasattr(ast, "unparse") else "" ), "context": self.current_context, "class_context": self.current_class, "is_dependency": self.is_dependency, } self.variables.append(var_data) self.generic_visit(node) def visit_Import(self, node): """Visit import statements""" for name in node.names: import_data = { "name": name.name.split('.')[0], # Store the top-level package name "full_import_name": name.name, # Store the full import name "line_number": node.lineno, "alias": name.asname, "context": self.current_context, "is_dependency": self.is_dependency, } self.imports.append(import_data) self.generic_visit(node) def visit_ImportFrom(self, node): """ Visit from-import statements, now correctly capturing the relative path. """ # Create the relative path prefix (e.g., '.', '..') based on the level. prefix = '.' * node.level for alias in node.names: # If node.module is None, it's an import like `from . import name` # Determine the base module name for the 'name' property if node.module: # For 'from .module import name', base_module is 'module' # For 'from package.module import name', base_module is 'package' base_module = node.module.split('.')[0] full_import_name = f"{prefix}{node.module}.{alias.name}" else: # For 'from . import name', base_module is 'name' base_module = alias.name full_import_name = f"{prefix}{alias.name}" import_data = { "name": base_module, # Store the top-level module name "full_import_name": full_import_name, # Store the full import path "line_number": node.lineno, "alias": alias.asname, "context": self.current_context, "is_dependency": self.is_dependency, } self.imports.append(import_data) def _resolve_attribute_base_type(self, node: ast.Attribute) -> Optional[str]: """ Recursively traverses an attribute chain (e.g., self.manager.db) to find the type of the final attribute. """ base_node = node.value current_type = None # Step 1: Find the type of the initial object in the chain if isinstance(base_node, ast.Name): obj_name = base_node.id if obj_name == 'self': current_type = self.current_class else: # Check local, then class, then module scopes current_type = (self.local_symbol_table.get(obj_name) or self.class_symbol_table.get(obj_name) or self.module_symbol_table.get(obj_name)) elif isinstance(base_node, ast.Call): current_type = self._resolve_type_from_call(base_node) # You can keep your existing call resolver elif isinstance(base_node, ast.Attribute): # It's a nested attribute, recurse! e.g., self.a.b current_type = self._resolve_attribute_base_type(base_node) # Step 2: If we found the base type, now find the type of the final attribute if current_type: paths = self.imports_map.get(current_type, []) if paths: # This is a simplification; a better implementation would need to # parse the class file to find the type of the 'attr' # For now, I assume a direct method call on the found type return_type = self.get_return_type_from_ast(paths[0], current_type, node.attr) # If get_return_type_from_ast finds a return type, that's our new type. # If not, we can assume the attribute itself is of a certain type # This part is complex and may require parsing the class file for assignments. # For the current problem, just knowing the `current_type` is enough. # Let's modify the goal to return the type of the object *containing* the method. return current_type # Return the type of the object, e.g., 'JobManager' return None def visit_Call(self, node): """Visit function calls with enhanced detection""" call_name = None full_call_name = None try: full_call_name = ast.unparse(node.func) if isinstance(node.func, ast.Name): call_name = node.func.id elif isinstance(node.func, ast.Attribute): call_name = node.func.attr else: call_name = full_call_name except Exception: self.generic_visit(node) return try: call_args = [ast.unparse(arg) for arg in node.args] except Exception: call_args = [] inferred_obj_type = None if isinstance(node.func, ast.Attribute): base_obj_node = node.func.value if isinstance(base_obj_node, ast.Name): obj_name = base_obj_node.id if obj_name == 'self': # If the base is 'self', find the type of the attribute on the current class inferred_obj_type = self.class_symbol_table.get(node.func.attr) if not inferred_obj_type: # Fallback for method calls directly on self inferred_obj_type = self.current_class else: inferred_obj_type = (self.local_symbol_table.get(obj_name) or self.class_symbol_table.get(obj_name) or self.module_symbol_table.get(obj_name)) # If it's not a variable, it might be a direct call on a Class name. if not inferred_obj_type and obj_name in self.imports_map: inferred_obj_type = obj_name elif isinstance(base_obj_node, ast.Call): inferred_obj_type = self._resolve_type_from_call(base_obj_node) elif isinstance(base_obj_node, ast.Attribute): # e.g., self.job_manager # This handles nested attributes # The goal is to find the type of `self.job_manager`, which is 'JobManager' # Resolve the base of the chain, e.g., get 'self' from 'self.job_manager' base = base_obj_node while isinstance(base, ast.Attribute): base = base.value if isinstance(base, ast.Name) and base.id == 'self': # In self.X.Y... The attribute we care about is the first one, X attr_name = base_obj_node.attr inferred_obj_type = self.class_symbol_table.get(attr_name) elif isinstance(node.func, ast.Name): inferred_obj_type = (self.local_symbol_table.get(call_name) or self.class_symbol_table.get(call_name) or self.module_symbol_table.get(call_name)) # there are no CALLS relationships originating from P2pkhAddress.to_address in the graph. This is the root cause of the find_all_callees tool reporting 0 # results. # The problem is not with the find_all_callees query itself, but with the GraphBuilder's ability to correctly identify and create CALLS relationships for methods like # P2pkhAddress.to_address. # Specifically, the GraphBuilder._create_function_calls method is likely not correctly processing calls made within methods of a class, especially when those calls are to: # 1. self.method(): Internal method calls. # 2. Functions imported from other modules (e.g., h_to_b, get_network). # 3. Functions from external libraries (e.g., hashlib.sha256, b58encode). # The GraphBuilder.CodeVisitor.visit_Call method is responsible for identifying function calls. It needs to be improved to handle these cases. # Plan: # 1. Enhance `CodeVisitor.visit_Call` in `src/codegraphcontext/tools/graph_builder.py`: # * Internal Method Calls (`self.method()`): When node.func is an ast.Attribute and node.func.value.id is self, the call_name should be node.func.attr, and the resolved_path should # be the file_path of the current class. # * Imported Functions: The _create_function_calls method already has some logic for resolving imported functions using imports_map. I need to ensure this logic is robust and # correctly applied within visit_Call to set inferred_obj_type or resolved_path accurately. # * External Library Functions: For now, we might not be able to fully resolve calls to external library functions unless those libraries are also indexed. However, we should at # least capture the full_call_name and call_name for these. # inferred_obj_type = None # if isinstance(node.func, ast.Attribute): # base_obj_node = node.func.value # if isinstance(base_obj_node, ast.Name): # obj_name = base_obj_node.id # if obj_name == 'self': # # If the base is 'self', the call is to a method of the current class # inferred_obj_type = self.current_class # else: # # Try to resolve the type of the object from symbol tables # inferred_obj_type = (self.local_symbol_table.get(obj_name) or # self.class_symbol_table.get(obj_name) or # self.module_symbol_table.get(obj_name)) # # If not found in symbol tables, check if it's a class name from imports # if not inferred_obj_type and obj_name in self.imports_map: # inferred_obj_type = obj_name # elif isinstance(base_obj_node, ast.Call): # inferred_obj_type = self._resolve_type_from_call(base_obj_node) # elif isinstance(base_obj_node, ast.Attribute): # e.g., self.job_manager.method() # # Recursively resolve the type of the base attribute # inferred_obj_type = self._resolve_attribute_base_type(base_obj_node) # elif isinstance(node.func, ast.Name): # # If it's a direct function call, try to infer its type from symbol tables or imports # inferred_obj_type = (self.local_symbol_table.get(call_name) or # self.class_symbol_table.get(call_name) or # self.module_symbol_table.get(call_name)) # if not inferred_obj_type and call_name in self.imports_map: # inferred_obj_type = call_name if call_name and call_name not in __builtins__: call_data = { "name": call_name, "full_name": full_call_name, "line_number": node.lineno, "args": call_args, "inferred_obj_type": inferred_obj_type, "context": self.current_context, "class_context": self.current_class, "is_dependency": self.is_dependency, } self.function_calls.append(call_data) self.generic_visit(node) class GraphBuilder: """Module for building and managing the Neo4j code graph.""" def __init__(self, db_manager: DatabaseManager, job_manager: JobManager, loop: asyncio.AbstractEventLoop): self.db_manager = db_manager self.job_manager = job_manager self.loop = loop # Store the main event loop self.driver = self.db_manager.get_driver() self.create_schema() def create_schema(self): """Create constraints and indexes in Neo4j.""" with self.driver.session() as session: try: session.run("CREATE CONSTRAINT repository_path IF NOT EXISTS FOR (r:Repository) REQUIRE r.path IS UNIQUE") session.run("CREATE CONSTRAINT file_path IF NOT EXISTS FOR (f:File) REQUIRE f.path IS UNIQUE") session.run("CREATE CONSTRAINT directory_path IF NOT EXISTS FOR (d:Directory) REQUIRE d.path IS UNIQUE") session.run("CREATE CONSTRAINT function_unique IF NOT EXISTS FOR (f:Function) REQUIRE (f.name, f.file_path, f.line_number) IS UNIQUE") session.run("CREATE CONSTRAINT class_unique IF NOT EXISTS FOR (c:Class) REQUIRE (c.name, c.file_path, c.line_number) IS UNIQUE") session.run("CREATE CONSTRAINT variable_unique IF NOT EXISTS FOR (v:Variable) REQUIRE (v.name, v.file_path, v.line_number) IS UNIQUE") session.run("CREATE CONSTRAINT module_name IF NOT EXISTS FOR (m:Module) REQUIRE m.name IS UNIQUE") # Create a full-text search index for code search session.run(""" CREATE FULLTEXT INDEX code_search_index IF NOT EXISTS FOR (n:Function|Class|Variable) ON EACH [n.name, n.source, n.docstring] """) logger.info("Database schema verified/created successfully") except Exception as e: logger.warning(f"Schema creation warning: {e}") def _pre_scan_for_imports(self, files: list[Path]) -> dict: """Scans all files to create a map of class/function names to a LIST of their file paths.""" imports_map = {} for file_path in files: try: with open(file_path, "r", encoding="utf-8") as f: tree = ast.parse(f.read()) for node in ast.walk(tree): if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)): if node.name not in imports_map: imports_map[node.name] = [] imports_map[node.name].append(str(file_path.resolve())) except Exception as e: logger.warning(f"Pre-scan failed for {file_path}: {e}") return imports_map def add_repository_to_graph(self, repo_path: Path, is_dependency: bool = False): """Adds a repository node using its absolute path as the unique key.""" repo_name = repo_path.name repo_path_str = str(repo_path.resolve()) with self.driver.session() as session: session.run( """ MERGE (r:Repository {path: $path}) SET r.name = $name, r.is_dependency = $is_dependency """, path=repo_path_str, name=repo_name, is_dependency=is_dependency, ) def add_file_to_graph(self, file_data: Dict, repo_name: str, imports_map: dict): """Adds a file and its contents within a single, unified session.""" file_path_str = str(Path(file_data['file_path']).resolve()) file_name = Path(file_path_str).name is_dependency = file_data.get('is_dependency', False) with self.driver.session() as session: try: repo_result = session.run("MATCH (r:Repository {name: $repo_name}) RETURN r.path as path", repo_name=repo_name).single() relative_path = str(Path(file_path_str).relative_to(Path(repo_result['path']))) if repo_result else file_name except ValueError: relative_path = file_name # Create/Merge the file node session.run(""" MERGE (f:File {path: $path}) SET f.name = $name, f.relative_path = $relative_path, f.is_dependency = $is_dependency """, path=file_path_str, name=file_name, relative_path=relative_path, is_dependency=is_dependency) # Create directory structure and link it file_path_obj = Path(file_path_str) repo_path_obj = Path(repo_result['path']) relative_path_to_file = file_path_obj.relative_to(repo_path_obj) parent_path = str(repo_path_obj) parent_label = 'Repository' # Create nodes for each directory part of the path for part in relative_path_to_file.parts[:-1]: # For each directory in the path current_path = Path(parent_path) / part current_path_str = str(current_path) session.run(f""" MATCH (p:{parent_label} {{path: $parent_path}}) MERGE (d:Directory {{path: $current_path}}) SET d.name = $part MERGE (p)-[:CONTAINS]->(d) """, parent_path=parent_path, current_path=current_path_str, part=part) parent_path = current_path_str parent_label = 'Directory' # Link the last directory/repository to the file session.run(f""" MATCH (p:{parent_label} {{path: $parent_path}}) MATCH (f:File {{path: $file_path}}) MERGE (p)-[:CONTAINS]->(f) """, parent_path=parent_path, file_path=file_path_str) for item_data, label in [(file_data['functions'], 'Function'), (file_data['classes'], 'Class'), (file_data['variables'], 'Variable')]: for item in item_data: query = f""" MATCH (f:File {{path: $file_path}}) MERGE (n:{label} {{name: $name, file_path: $file_path, line_number: $line_number}}) SET n += $props MERGE (f)-[:CONTAINS]->(n) """ session.run(query, file_path=file_path_str, name=item['name'], line_number=item['line_number'], props=item) # If it's a function, create parameter nodes and relationships and calculate complexity if label == 'Function': # Calculate cyclomatic complexity try: func_tree = ast.parse(item['source_code']) complexity_visitor = CyclomaticComplexityVisitor() complexity_visitor.visit(func_tree) item['cyclomatic_complexity'] = complexity_visitor.complexity except Exception as e: logger.warning(f"Could not calculate cyclomatic complexity for {item['name']} in {file_path_str}: {e}") item['cyclomatic_complexity'] = 1 # Default to 1 on error for arg_name in item.get('args', []): session.run(""" MATCH (fn:Function {name: $func_name, file_path: $file_path, line_number: $line_number}) MERGE (p:Parameter {name: $arg_name, file_path: $file_path, function_line_number: $line_number}) MERGE (fn)-[:HAS_PARAMETER]->(p) """, func_name=item['name'], file_path=file_path_str, line_number=item['line_number'], arg_name=arg_name) for imp in file_data['imports']: set_clauses = ["m.alias = $alias"] if 'full_import_name' in imp: set_clauses.append("m.full_import_name = $full_import_name") set_clause_str = ", ".join(set_clauses) session.run(f""" MATCH (f:File {{path: $file_path}}) MERGE (m:Module {{name: $name}}) SET {set_clause_str} MERGE (f)-[:IMPORTS]->(m) """, file_path=file_path_str, **imp) for class_item in file_data.get('classes', []): if class_item.get('bases'): for base_class_name in class_item['bases']: resolved_parent_file_path = self._resolve_class_path( base_class_name, file_path_str, file_data['imports'], imports_map ) if resolved_parent_file_path: session.run(""" MATCH (child:Class {name: $child_name, file_path: $file_path}) MATCH (parent:Class {name: $parent_name, file_path: $resolved_parent_file_path}) MERGE (child)-[:INHERITS_FROM]->(parent) """, child_name=class_item['name'], file_path=file_path_str, parent_name=base_class_name, resolved_parent_file_path=resolved_parent_file_path) self._create_class_method_relationships(session, file_data) self._create_contextual_relationships(session, file_data) def _create_contextual_relationships(self, session, file_data: Dict): """Create CONTAINS relationships from functions/classes to their children.""" file_path = str(Path(file_data['file_path']).resolve()) for func in file_data.get('functions', []): if func.get('class_context'): session.run(""" MATCH (c:Class {name: $class_name, file_path: $file_path}) MATCH (fn:Function {name: $func_name, file_path: $file_path, line_number: $func_line}) MERGE (c)-[:CONTAINS]->(fn) """, class_name=func['class_context'], file_path=file_path, func_name=func['name'], func_line=func['line_number']) for var in file_data.get('variables', []): context = var.get('context') class_context = var.get('class_context') parent_line = var.get('parent_line') if class_context: session.run(""" MATCH (c:Class {name: $class_name, file_path: $file_path}) MATCH (v:Variable {name: $var_name, file_path: $file_path, line_number: $var_line}) MERGE (c)-[:CONTAINS]->(v) """, class_name=class_context, file_path=file_path, var_name=var['name'], var_line=var['line_number']) elif context and parent_line: parent_label = "Function" parent_node_data = None for class_data in file_data.get('classes', []): if class_data['name'] == context and class_data['line_number'] == parent_line: parent_label = "Class" parent_node_data = class_data break if not parent_node_data: for func_data in file_data.get('functions', []): if func_data['name'] == context and func_data['line_number'] == parent_line: parent_label = "Function" parent_node_data = func_data break if parent_node_data: session.run(f""" MATCH (p:{parent_label} {{name: $parent_name, file_path: $file_path, line_number: $parent_line}}) MATCH (v:Variable {{name: $var_name, file_path: $file_path, line_number: $var_line}}) MERGE (p)-[:CONTAINS]->(v) """, parent_name=context, file_path=file_path, parent_line=parent_line, var_name=var['name'], var_line=var['line_number']) else: session.run(""" MATCH (f:File {path: $file_path}) MATCH (v:Variable {name: $var_name, file_path: $file_path, line_number: $var_line}) MERGE (f)-[:CONTAINS]->(v) """, file_path=file_path, var_name=var['name'], var_line=var['line_number']) def _create_function_calls(self, session, file_data: Dict, imports_map: dict): """ Create CALLS relationships with a unified, prioritized logic flow for all call types. """ caller_file_path = str(Path(file_data['file_path']).resolve()) local_function_names = {func['name'] for func in file_data.get('functions', [])} local_imports = {imp['alias'] or imp['name'].split('.')[-1]: imp['name'] for imp in file_data.get('imports', [])} for call in file_data.get('function_calls', []): called_name = call['name'] if called_name in __builtins__: continue resolved_path = None # Priority 1: Handle method calls (var.method(), self.attr.method(), etc.) # This is the most specific and reliable information we have. if call.get('inferred_obj_type'): obj_type = call['inferred_obj_type'] possible_paths = imports_map.get(obj_type, []) if len(possible_paths) > 0: # Simplistic choice for now; assumes the first found definition is correct. resolved_path = possible_paths[0] # Priority 2: Handle direct calls (func()) and class methods (Class.method()) else: # For class methods, the `called_name` will be the class itself lookup_name = call['full_name'].split('.')[0] if '.' in call['full_name'] else called_name possible_paths = imports_map.get(lookup_name, []) # A) Is it a local function? if lookup_name in local_function_names: resolved_path = caller_file_path # B) Is it an unambiguous global function/class? elif len(possible_paths) == 1: resolved_path = possible_paths[0] # C) Is it an ambiguous call we can resolve via this file's imports? elif len(possible_paths) > 1 and lookup_name in local_imports: full_import_name = local_imports[lookup_name] for path in possible_paths: if full_import_name.replace('.', '/') in path: resolved_path = path break # Fallback if no path could be resolved by any of the above rules if not resolved_path: # If the called name is in the imports map, use its path if called_name in imports_map and imports_map[called_name]: resolved_path = imports_map[called_name][0] # Take the first path for now else: resolved_path = caller_file_path caller_context = call.get('context') inferred_type = call.get('inferred_obj_type') if debug_mode: log_inferred_str = f" (via inferred type {inferred_type})" if inferred_type else "" debug_log(f"Resolved call: {caller_context} @ {caller_file_path} calls {called_name} @ {resolved_path}{log_inferred_str}") if caller_context: session.run(""" MATCH (caller:Function {name: $caller_name, file_path: $caller_file_path}) MATCH (called:Function {name: $called_name, file_path: $called_file_path}) MERGE (caller)-[:CALLS {line_number: $line_number, args: $args, full_call_name: $full_call_name}]->(called) """, caller_name=caller_context, caller_file_path=caller_file_path, called_name=called_name, called_file_path=resolved_path, line_number=call['line_number'], args=call.get('args', []), full_call_name=call.get('full_name', called_name)) else: # Handle calls from the top-level of a file session.run(""" MATCH (caller:File {path: $caller_file_path}) MATCH (called:Function {name: $called_name, file_path: $called_file_path}) MERGE (caller)-[:CALLS {line_number: $line_number, args: $args, full_call_name: $full_call_name}]->(called) """, caller_file_path=caller_file_path, called_name=called_name, called_file_path=resolved_path, line_number=call['line_number'], args=call.get('args', []), full_call_name=call.get('full_name', called_name)) def _create_all_function_calls(self, all_file_data: list[Dict], imports_map: dict): """Create CALLS relationships for all functions after all files have been processed.""" with self.driver.session() as session: for file_data in all_file_data: self._create_function_calls(session, file_data, imports_map) def _create_class_method_relationships(self, session, file_data: Dict): """Create CONTAINS relationships from classes to their methods""" file_path = str(Path(file_data['file_path']).resolve()) for func in file_data.get('functions', []): class_context = func.get('class_context') if class_context: session.run(""" MATCH (c:Class {name: $class_name, file_path: $file_path}) MATCH (fn:Function {name: $func_name, file_path: $file_path, line_number: $func_line}) MERGE (c)-[:CONTAINS]->(fn) """, class_name=class_context, file_path=file_path, func_name=func['name'], func_line=func['line_number']) def _resolve_class_path(self, class_name: str, current_file_path: str, current_file_imports: list, global_imports_map: dict) -> Optional[str]: debug_log(f"_resolve_class_path: Resolving '{class_name}' from '{current_file_path}'") """ Resolves the file path of a class based on import resolution priority. 1. Same file definition 2. Imports within the current file (direct or aliased) 3. Global imports map (anywhere in the indexed project) """ # Priority 1: Same file definition with self.driver.session() as session: result = session.run(""" MATCH (c:Class {name: $class_name, file_path: $current_file_path}) RETURN c.file_path AS file_path """, class_name=class_name, current_file_path=current_file_path).single() if result: debug_log(f"_resolve_class_path: Priority 1 match: {result['file_path']}") return result['file_path'] # Priority 2: Imports within the current file with self.driver.session() as session: result = session.run(""" MATCH (f:File {path: $current_file_path})-[:IMPORTS]->(m:Module) OPTIONAL MATCH (m)-[:CONTAINS]->(c:Class {name: $class_name}) RETURN c.file_path AS file_path """, current_file_path=current_file_path, class_name=class_name).single() if result and result["file_path"]: debug_log(f"_resolve_class_path: Priority 2 match: {result['file_path']}") return result['file_path'] # Priority 3: Global imports map (anywhere in the indexed project) - Fallback if class_name in global_imports_map: debug_log(f"_resolve_class_path: Priority 3 match: {global_imports_map[class_name][0]}") return global_imports_map[class_name][0] debug_log(f"_resolve_class_path: No path resolved for '{class_name}'") return None def delete_file_from_graph(self, file_path: str): """Deletes a file and all its contained elements and relationships.""" file_path_str = str(Path(file_path).resolve()) with self.driver.session() as session: # Get parent directories parents_res = session.run(""" MATCH (f:File {path: $path})<-[:CONTAINS*]-(d:Directory) RETURN d.path as path ORDER BY d.path DESC """, path=file_path_str) parent_paths = [record["path"] for record in parents_res] # Delete the file and its contents session.run( """ MATCH (f:File {path: $path}) OPTIONAL MATCH (f)-[:CONTAINS]->(element) DETACH DELETE f, element """, path=file_path_str, ) logger.info(f"Deleted file and its elements from graph: {file_path_str}") # Clean up empty parent directories, starting from the deepest for path in parent_paths: session.run(""" MATCH (d:Directory {path: $path}) WHERE NOT (d)-[:CONTAINS]->() DETACH DELETE d """, path=path) def delete_repository_from_graph(self, repo_path: str): """ Deletes a repository and all its contents from the graph, then cleans up any orphaned Module nodes that are no longer referenced. """ repo_path_str = str(Path(repo_path).resolve()) with self.driver.session() as session: # Delete the repository and all its contained elements, including parameters session.run(""" MATCH (r:Repository {path: $path}) OPTIONAL MATCH (r)-[:CONTAINS*]->(e) OPTIONAL MATCH (e)-[:HAS_PARAMETER]->(p) DETACH DELETE r, e, p """, path=repo_path_str) logger.info(f"Deleted repository and its contents from graph: {repo_path_str}") # Clean up orphaned Module nodes that are no longer imported by any file session.run(""" MATCH (m:Module) WHERE NOT ()-[:IMPORTS]->(m) DETACH DELETE m """) logger.info("Cleaned up orphaned Module nodes.") def update_file_in_graph(self, file_path: Path, repo_path: Path, imports_map: dict): """ Updates a single file's nodes in the graph and returns its new parsed data. This function does NOT handle re-linking the call graph. """ file_path_str = str(file_path.resolve()) repo_name = repo_path.name # --- STEP 1: Delete the old file from the graph --- debug_log(f"[update_file_in_graph] Deleting old file data for: {file_path_str}") try: self.delete_file_from_graph(file_path_str) debug_log(f"[update_file_in_graph] Old file data deleted for: {file_path_str}") except Exception as e: logger.error(f"Error deleting old file data for {file_path_str}: {e}") return None # Return None on failure # --- STEP 2: Re-parse and re-add the new file --- if file_path.exists(): debug_log(f"[update_file_in_graph] Parsing new file data for: {file_path_str}") # Pass imports_map to the parser file_data = self.parse_python_file(repo_path, file_path, imports_map) if "error" not in file_data: debug_log(f"[update_file_in_graph] Adding new file data to graph for: {file_path_str}") self.add_file_to_graph(file_data, repo_name, imports_map) debug_log(f"[update_file_in_graph] New file data added for: {file_path_str}") # --- CRITICAL: Return the new data --- return file_data else: logger.error(f"Skipping graph add for {file_path_str} due to parsing error: {file_data['error']}") return None # Return None on failure else: debug_log(f"[update_file_in_graph] File no longer exists: {file_path_str}") # Return a special marker for deleted files return {"deleted": True, "path": file_path_str} def parse_python_file(self, repo_path: Path, file_path: Path, imports_map: dict, is_dependency: bool = False) -> Dict: """Parse a Python file and extract code elements""" debug_log(f"[parse_python_file] Starting parsing for: {file_path}") try: with open(file_path, "r", encoding="utf-8") as f: source_code = f.read() tree = ast.parse(source_code) visitor = CodeVisitor(str(file_path), imports_map, is_dependency) visitor.visit(tree) if debug_mode: debug_log(f"[parse_python_file] Successfully parsed: {file_path}") return { "repo_path": str(repo_path), "file_path": str(file_path), "functions": visitor.functions, "classes": visitor.classes, "variables": visitor.variables, "imports": visitor.imports, "function_calls": visitor.function_calls, "is_dependency": is_dependency, } except Exception as e: logger.error(f"Error parsing {file_path}: {e}") debug_log(f"[parse_python_file] Error parsing {file_path}: {e}") return {"file_path": str(file_path), "error": str(e)} def estimate_processing_time(self, path: Path) -> Optional[Tuple[int, float]]: """Estimate processing time and file count""" try: if path.is_file(): files = [path] else: files = list(path.rglob("*.py")) total_files = len(files) # Simple heuristic: 0.1 seconds per file estimated_time = total_files * 0.1 return total_files, estimated_time except Exception as e: logger.error(f"Could not estimate processing time for {path}: {e}") return None async def build_graph_from_path_async( self, path: Path, is_dependency: bool = False, job_id: str = None ): """Builds graph from a directory or file path.""" try: if job_id: self.job_manager.update_job(job_id, status=JobStatus.RUNNING) self.add_repository_to_graph(path, is_dependency) repo_name = path.name files = list(path.rglob("*.py")) if path.is_dir() else [path] if job_id: self.job_manager.update_job(job_id, total_files=len(files)) debug_log("Starting pre-scan to build imports map...") imports_map = self._pre_scan_for_imports(files) debug_log(f"Pre-scan complete. Found {len(imports_map)} definitions.") all_function_calls_data = [] # Initialize list to collect all function call data processed_count = 0 for file in files: if file.is_file(): if job_id: self.job_manager.update_job(job_id, current_file=str(file)) repo_path = path.resolve() if path.is_dir() else file.parent.resolve() file_data = self.parse_python_file(repo_path, file, imports_map, is_dependency) if "error" not in file_data: self.add_file_to_graph(file_data, repo_name, imports_map) all_function_calls_data.append(file_data) # Collect for later processing processed_count += 1 if job_id: self.job_manager.update_job(job_id, processed_files=processed_count) await asyncio.sleep(0.01) # After all files are processed, create function call relationships self._create_all_function_calls(all_function_calls_data, imports_map) if debug_mode: with open("all_function_calls_data.json", "w") as f: import json json.dump(all_function_calls_data, f, indent=4) if job_id: self.job_manager.update_job(job_id, status=JobStatus.COMPLETED, end_time=datetime.now()) except Exception as e: logger.error(f"Failed to build graph for path {path}: {e}", exc_info=True) if job_id: self.job_manager.update_job( job_id, status=JobStatus.FAILED, end_time=datetime.now(), errors=[str(e)] ) def add_code_to_graph_tool( self, path: str, is_dependency: bool = False ) -> Dict[str, Any]: """Tool to add code to Neo4j graph with background processing""" try: path_obj = Path(path).resolve() if not path_obj.exists(): return {"error": f"Path {path} does not exist"} estimation = self.estimate_processing_time(path_obj) if estimation is None: return {"error": f"Could not analyze path {path}."} total_files, estimated_time = estimation job_id = self.job_manager.create_job(str(path_obj), is_dependency) self.job_manager.update_job( job_id, total_files=total_files, estimated_duration=estimated_time ) # Create the coroutine for the background task coro = self.build_graph_from_path_async(path_obj, is_dependency, job_id) # Safely schedule the coroutine to run on the main event loop from this thread asyncio.run_coroutine_threadsafe(coro, self.loop) debug_log(f"Started background job {job_id} for path: {str(path_obj)}") return { "success": True, "job_id": job_id, "message": f"Background processing started for {path_obj}", "estimated_files": total_files, "estimated_duration_seconds": round(estimated_time, 2), } except Exception as e: debug_log(f"Error creating background job: {str(e)}") return { "error": f"Failed to start background processing: {e.__class__.__name__}: {e}" }

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Shashankss1205/CodeGraphContext'

If you have feedback or need assistance with the MCP directory API, please join our Discord server