test_symbol_extraction.py•24.2 kB
"""
Tests for symbol extraction and dependency analysis issues.
This module contains tests specifically focused on the symbol extraction and
dependency analysis issues identified in FEATURES.md.
"""
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, Generator
import pytest
from tests.test_helpers import (
get_ast,
get_dependencies,
get_symbols,
register_project_tool,
)
@pytest.fixture
def test_project(request) -> Generator[Dict[str, Any], None, None]:
"""Create a test project with Python files containing known symbols and imports."""
with tempfile.TemporaryDirectory() as temp_dir:
project_path = Path(temp_dir)
# Create a Python file with known symbols and dependencies
test_file = project_path / "test.py"
with open(test_file, "w") as f:
f.write(
"""
import os
import sys
from typing import List, Dict, Optional
from datetime import datetime as dt
class Person:
def __init__(self, name: str, age: int):
self.name = name
self.age = age
def greet(self) -> str:
return f"Hello, my name is {self.name} and I'm {self.age} years old."
class Employee(Person):
def __init__(self, name: str, age: int, employee_id: str):
super().__init__(name, age)
self.employee_id = employee_id
def greet(self) -> str:
basic_greeting = super().greet()
return f"{basic_greeting} I am employee {self.employee_id}."
def process_data(items: List[str]) -> Dict[str, int]:
result = {}
for item in items:
result[item] = len(item)
return result
def calculate_age(birthdate: dt) -> int:
today = dt.now()
age = today.year - birthdate.year
if (today.month, today.day) < (birthdate.month, birthdate.day):
age -= 1
return age
if __name__ == "__main__":
p = Person("Alice", 30)
e = Employee("Bob", 25, "E12345")
print(p.greet())
print(e.greet())
data = process_data(["apple", "banana", "cherry"])
print(data)
bob_birthday = dt(1998, 5, 15)
bob_age = calculate_age(bob_birthday)
print(f"Bob's age is {bob_age}")
"""
)
# Create a second file with additional imports and symbols
utils_file = project_path / "utils.py"
with open(utils_file, "w") as f:
f.write(
"""
import json
import csv
import random
from typing import Any, List, Dict, Tuple
from pathlib import Path
def save_json(data: Dict[str, Any], filename: str) -> None:
with open(filename, 'w') as f:
json.dump(data, f, indent=2)
def load_json(filename: str) -> Dict[str, Any]:
with open(filename, 'r') as f:
return json.load(f)
def generate_random_data(count: int) -> List[Dict[str, Any]]:
result = []
for i in range(count):
person = {
"id": i,
"name": f"Person {i}",
"age": random.randint(18, 80),
"active": random.choice([True, False])
}
result.append(person)
return result
class FileHandler:
def __init__(self, base_path: str):
self.base_path = Path(base_path)
def save_data(self, data: Dict[str, Any], filename: str) -> str:
file_path = self.base_path / filename
save_json(data, str(file_path))
return str(file_path)
def load_data(self, filename: str) -> Dict[str, Any]:
file_path = self.base_path / filename
return load_json(str(file_path))
"""
)
# Generate a unique project name based on the test name
test_name = request.node.name
unique_id = abs(hash(test_name)) % 10000
project_name = f"symbol_test_project_{unique_id}"
# Register project
try:
register_project_tool(path=str(project_path), name=project_name)
except Exception:
# If registration fails, try with an even more unique name
import time
project_name = f"symbol_test_project_{unique_id}_{int(time.time())}"
register_project_tool(path=str(project_path), name=project_name)
yield {
"name": project_name,
"path": str(project_path),
"files": ["test.py", "utils.py"],
}
def test_symbol_extraction_diagnostics(test_project) -> None:
"""Test symbol extraction to diagnose specific issues in the implementation."""
# Get symbols from first file, excluding class methods
symbols = get_symbols(project=test_project["name"], file_path="test.py")
# Also get symbols with class methods excluded for comparison
from mcp_server_tree_sitter.api import get_language_registry, get_project_registry
from mcp_server_tree_sitter.tools.analysis import extract_symbols
project = get_project_registry().get_project(test_project["name"])
language_registry = get_language_registry()
symbols_excluding_methods = extract_symbols(project, "test.py", language_registry, exclude_class_methods=True)
# Verify the result structure
assert "functions" in symbols, "Result should contain 'functions' key"
assert "classes" in symbols, "Result should contain 'classes' key"
assert "imports" in symbols, "Result should contain 'imports' key"
# Print diagnostic information
print("\nSymbol extraction results for test.py:")
print(f"Functions: {symbols['functions']}")
print(f"Functions (excluding methods): {symbols_excluding_methods['functions']}")
print(f"Classes: {symbols['classes']}")
print(f"Imports: {symbols['imports']}")
# Check symbol counts
expected_function_count = 2 # process_data, calculate_age
expected_class_count = 2 # Person, Employee
expected_import_count = 4 # os, sys, typing, datetime
# Verify extracted symbols
if symbols_excluding_methods["functions"] and len(symbols_excluding_methods["functions"]) > 0:
# Instead of checking exact counts, just verify we found the main functions
function_names = [f["name"] for f in symbols_excluding_methods["functions"]]
# Check for process_data function - handle both bytes and strings
process_data_found = False
for name in function_names:
if (isinstance(name, bytes) and b"process_data" in name) or (
isinstance(name, str) and "process_data" in name
):
process_data_found = True
break
# Check for calculate_age function - handle both bytes and strings
calculate_age_found = False
for name in function_names:
if (isinstance(name, bytes) and b"calculate_age" in name) or (
isinstance(name, str) and "calculate_age" in name
):
calculate_age_found = True
break
assert process_data_found, "Expected to find 'process_data' function"
assert calculate_age_found, "Expected to find 'calculate_age' function"
else:
print(f"KNOWN ISSUE: Expected {expected_function_count} functions, but got empty list")
if symbols["classes"] and len(symbols["classes"]) > 0:
assert len(symbols["classes"]) == expected_class_count
else:
print(f"KNOWN ISSUE: Expected {expected_class_count} classes, but got empty list")
if symbols["imports"] and len(symbols["imports"]) > 0:
# Our improved import detection now finds individual import names plus the statements
# So we'll just check that we found all expected import modules
import_texts = [imp.get("name", "") for imp in symbols["imports"]]
for module in ["os", "sys", "typing", "datetime"]:
assert any(
(isinstance(text, bytes) and module.encode() in text) or (isinstance(text, str) and module in text)
for text in import_texts
), f"Should find '{module}' import"
else:
print(f"KNOWN ISSUE: Expected {expected_import_count} imports, but got empty list")
# Now check the second file to ensure results are consistent
symbols_utils = get_symbols(project=test_project["name"], file_path="utils.py")
print("\nSymbol extraction results for utils.py:")
print(f"Functions: {symbols_utils['functions']}")
print(f"Classes: {symbols_utils['classes']}")
print(f"Imports: {symbols_utils['imports']}")
def test_dependency_analysis_diagnostics(test_project) -> None:
"""Test dependency analysis to diagnose specific issues in the implementation."""
# Get dependencies from the first file
dependencies = get_dependencies(project=test_project["name"], file_path="test.py")
# Print diagnostic information
print("\nDependency analysis results for test.py:")
print(f"Dependencies: {dependencies}")
# Expected dependencies based on imports
expected_dependencies = ["os", "sys", "typing", "datetime"]
# Check dependencies that should be found
if dependencies and len(dependencies) > 0:
# If we have a module list, check against that directly
if "module" in dependencies:
# Modify test to be more flexible with datetime imports
for dep in ["os", "sys", "typing"]:
assert any(
(isinstance(mod, bytes) and dep.encode() in mod) or (isinstance(mod, str) and dep in mod)
for mod in dependencies["module"]
), f"Expected dependency '{dep}' not found"
else:
# Otherwise check in the entire dependencies dictionary
for dep in expected_dependencies:
assert dep in str(dependencies), f"Expected dependency '{dep}' not found"
else:
print(f"KNOWN ISSUE: Expected dependencies {expected_dependencies}, but got empty result")
# Check the second file for consistency
dependencies_utils = get_dependencies(project=test_project["name"], file_path="utils.py")
print("\nDependency analysis results for utils.py:")
print(f"Dependencies: {dependencies_utils}")
def test_symbol_extraction_with_ast_access(test_project) -> None:
"""Test symbol extraction with direct AST access to identify where processing breaks."""
# Get the AST for the file
ast_result = get_ast(
project=test_project["name"],
path="test.py",
max_depth=10, # Deep enough to capture all relevant nodes
include_text=True,
)
# Verify the AST is properly formed
assert "tree" in ast_result, "AST result should contain 'tree'"
# Extract the tree structure for analysis
tree = ast_result["tree"]
# Manually search for symbols in the AST
functions = []
classes = []
imports = []
def extract_symbols_manually(node, path=()) -> None:
"""Recursively extract symbols from the AST."""
if not isinstance(node, dict):
return
node_type = node.get("type")
# Identify function definitions
if node_type == "function_definition":
# Find the name node which is usually a direct child with type 'identifier'
if "children" in node:
for child in node["children"]:
if child.get("type") == "identifier":
functions.append(
{
"name": child.get("text"),
"path": path,
"node_id": node.get("id"),
"text": node.get("text", "").split("\n")[0][:50], # First line, truncated
}
)
break
# Identify class definitions
elif node_type == "class_definition":
# Find the name node
if "children" in node:
for child in node["children"]:
if child.get("type") == "identifier":
classes.append(
{
"name": child.get("text"),
"path": path,
"node_id": node.get("id"),
"text": node.get("text", "").split("\n")[0][:50], # First line, truncated
}
)
break
# Identify imports
elif node_type in ("import_statement", "import_from_statement"):
imports.append(
{
"type": node_type,
"path": path,
"node_id": node.get("id"),
"text": node.get("text", "").split("\n")[0], # First line
}
)
# Recurse into children
if "children" in node:
for i, child in enumerate(node["children"]):
extract_symbols_manually(child, path + (i,))
# Extract symbols from the AST
extract_symbols_manually(tree)
# Print diagnostic information
print("\nManual symbol extraction results:")
print(f"Functions found: {len(functions)}")
for func in functions:
print(f" {func['name']} - {func['text']}")
print(f"Classes found: {len(classes)}")
for cls in classes:
print(f" {cls['name']} - {cls['text']}")
print(f"Imports found: {len(imports)}")
for imp in imports:
print(f" {imp['type']} - {imp['text']}")
# Expected counts
assert len(functions) > 0, "Should find at least one function by manual extraction"
assert len(classes) > 0, "Should find at least one class by manual extraction"
assert len(imports) > 0, "Should find at least one import by manual extraction"
# Compare with get_symbols results
symbols = get_symbols(project=test_project["name"], file_path="test.py")
print("\nComparison with get_symbols:")
print(f"Manual functions: {len(functions)}, get_symbols: {len(symbols['functions'])}")
print(f"Manual classes: {len(classes)}, get_symbols: {len(symbols['classes'])}")
print(f"Manual imports: {len(imports)}, get_symbols: {len(symbols['imports'])}")
def test_query_based_symbol_extraction(test_project) -> None:
"""
Test symbol extraction using direct tree-sitter queries to identify issues.
This test demonstrates how query-based symbol extraction should work,
which can help identify where the implementation breaks down.
"""
try:
# Import necessary components for direct query execution
from tree_sitter import Parser, Query
from tree_sitter_language_pack import get_language
# Get Python language
language_obj = get_language("python")
# Create a parser
parser = Parser()
try:
# Try set_language method first
parser.set_language(language_obj) # type: ignore
except (AttributeError, TypeError):
# Fall back to setting language property
parser.language = language_obj
# Read the file content
file_path = os.path.join(test_project["path"], "test.py")
with open(file_path, "rb") as f:
content = f.read()
# Parse the content
tree = parser.parse(content)
# Define queries for different symbol types
function_query = """
(function_definition
name: (identifier) @function.name
parameters: (parameters) @function.params
body: (block) @function.body
) @function.def
"""
class_query = """
(class_definition
name: (identifier) @class.name
body: (block) @class.body
) @class.def
"""
import_query = """
(import_statement
name: (dotted_name) @import.module
) @import
(import_from_statement
module_name: (dotted_name) @import.from
name: (dotted_name) @import.item
) @import
"""
# Run the queries
functions_q = Query(language_obj, function_query)
classes_q = Query(language_obj, class_query)
imports_q = Query(language_obj, import_query)
function_captures = functions_q.captures(tree.root_node)
class_captures = classes_q.captures(tree.root_node)
import_captures = imports_q.captures(tree.root_node)
# Process and extract unique symbols
functions: Dict[str, Dict[str, Any]] = {}
classes: Dict[str, Dict[str, Any]] = {}
imports: Dict[str, Dict[str, Any]] = {}
# Helper function to process captures with different formats
def process_capture(captures, target_type, result_dict) -> None:
# Check if it's returning a dictionary format
if isinstance(captures, dict):
# Dictionary format: {capture_name: [node1, node2, ...], ...}
for capture_name, nodes in captures.items():
if capture_name == target_type:
for node in nodes:
name = node.text.decode("utf-8") if hasattr(node.text, "decode") else str(node.text)
result_dict[name] = {
"name": name,
"start": node.start_point,
"end": node.end_point,
}
else:
# Assume it's a list of matches
try:
# Try different formats
for item in captures:
# Could be tuple, object, or dict
if isinstance(item, tuple):
if len(item) == 2:
node, capture_name = item
else:
continue # Skip if unexpected tuple size
elif hasattr(item, "node") and hasattr(item, "capture_name"):
node, capture_name = item.node, item.capture_name
elif isinstance(item, dict) and "node" in item and "capture" in item:
node, capture_name = item["node"], item["capture"]
else:
continue # Skip if format unknown
if capture_name == target_type:
name = node.text.decode("utf-8") if hasattr(node.text, "decode") else str(node.text)
result_dict[name] = {
"name": name,
"start": node.start_point,
"end": node.end_point,
}
except Exception as e:
print(f"Error processing captures: {str(e)}")
# Process each type of capture
process_capture(function_captures, "function.name", functions)
process_capture(class_captures, "class.name", classes)
# For imports, use a separate function since the comparison is different
def process_import_capture(captures) -> None:
# Check if it's returning a dictionary format
if isinstance(captures, dict):
# Dictionary format: {capture_name: [node1, node2, ...], ...}
for capture_name, nodes in captures.items():
if capture_name in ("import.module", "import.from", "import.item"):
for node in nodes:
name = node.text.decode("utf-8") if hasattr(node.text, "decode") else str(node.text)
imports[name] = {
"name": name,
"type": capture_name,
"start": node.start_point,
"end": node.end_point,
}
else:
# Assume it's a list of matches
try:
# Try different formats
for item in captures:
# Could be tuple, object, or dict
if isinstance(item, tuple):
if len(item) == 2:
node, capture_name = item
else:
continue # Skip if unexpected tuple size
elif hasattr(item, "node") and hasattr(item, "capture_name"):
node, capture_name = item.node, item.capture_name
elif isinstance(item, dict) and "node" in item and "capture" in item:
node, capture_name = item["node"], item["capture"]
else:
continue # Skip if format unknown
if capture_name in (
"import.module",
"import.from",
"import.item",
):
name = node.text.decode("utf-8") if hasattr(node.text, "decode") else str(node.text)
imports[name] = {
"name": name,
"type": capture_name,
"start": node.start_point,
"end": node.end_point,
}
except Exception as e:
print(f"Error processing import captures: {str(e)}")
# Call the import capture processing function
process_import_capture(import_captures)
# Print the direct query results
print("\nDirect query results:")
print(f"Functions: {list(functions.keys())}")
print(f"Classes: {list(classes.keys())}")
print(f"Imports: {list(imports.keys())}")
# Compare with get_symbols
symbols = get_symbols(project=test_project["name"], file_path="test.py")
print("\nComparison with get_symbols:")
print(f"Query functions: {len(functions)}, get_symbols: {len(symbols['functions'])}")
print(f"Query classes: {len(classes)}, get_symbols: {len(symbols['classes'])}")
print(f"Query imports: {len(imports)}, get_symbols: {len(symbols['imports'])}")
# Document any differences that might indicate where the issue lies
if len(functions) != len(symbols["functions"]):
print("ISSUE: Function count mismatch")
if len(classes) != len(symbols["classes"]):
print("ISSUE: Class count mismatch")
if len(imports) != len(symbols["imports"]):
print("ISSUE: Import count mismatch")
except Exception as e:
print(f"Error in direct query execution: {str(e)}")
pytest.fail(f"Direct query execution failed: {str(e)}")
def test_debug_file_saving(test_project) -> None:
"""Save debug information to files for further analysis."""
# Create a debug directory
debug_dir = os.path.join(test_project["path"], "debug")
os.makedirs(debug_dir, exist_ok=True)
# Get AST and symbol information
ast_result = get_ast(project=test_project["name"], path="test.py", max_depth=10, include_text=True)
symbols = get_symbols(project=test_project["name"], file_path="test.py")
dependencies = get_dependencies(project=test_project["name"], file_path="test.py")
# Define a custom JSON encoder for bytes objects
class BytesEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, bytes):
return obj.decode("utf-8", errors="replace")
return super().default(obj)
# Save the information to files
with open(os.path.join(debug_dir, "ast.json"), "w") as f:
json.dump(ast_result, f, indent=2, cls=BytesEncoder)
with open(os.path.join(debug_dir, "symbols.json"), "w") as f:
json.dump(symbols, f, indent=2, cls=BytesEncoder)
with open(os.path.join(debug_dir, "dependencies.json"), "w") as f:
json.dump(dependencies, f, indent=2, cls=BytesEncoder)
print(f"\nDebug information saved to {debug_dir}")