"""Unified Tree-sitter parser for multiple languages"""
import tree_sitter
import ctypes
import os
import site
import json
import re
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
from .base_parser import BaseParser
class TreeSitterUnifiedParser(BaseParser):
"""Unified parser using Tree-sitter for all supported languages"""
# Class-level cache for language objects and parsers
_language_cache = {}
_parser_cache = {}
_lib = None
# Language mapping
LANGUAGE_MAP = {
'.js': 'javascript',
'.jsx': 'javascript',
'.mjs': 'javascript',
'.ts': 'typescript',
'.tsx': 'tsx',
'.py': 'python',
'.go': 'go',
'.rs': 'rust',
'.java': 'java',
'.c': 'c',
'.cpp': 'cpp',
'.cc': 'cpp',
'.h': 'c',
'.hpp': 'cpp',
}
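    # Illustrative only (not used by this class): a caller would typically resolve
    # a file's extension to a language before constructing a parser, e.g.
    #
    #     lang = TreeSitterUnifiedParser.LANGUAGE_MAP.get(Path('src/App.tsx').suffix)
    #     parser = TreeSitterUnifiedParser(lang) if lang else None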
def __init__(self, language: str):
"""Initialize parser for specific language
Args:
language: Language name (e.g., 'javascript', 'typescript', 'python')
"""
self.language = language
self.parser = self._get_parser(language)
self.lang_obj = self._get_language(language)
@classmethod
def _load_lib(cls):
"""Load the tree-sitter-languages shared library"""
if cls._lib is None:
tsl_path = os.path.join(
site.getusersitepackages(),
'tree_sitter_languages',
'languages.so'
)
            if not os.path.exists(tsl_path):
                # Fall back to the system site-packages directories
                for sp in site.getsitepackages():
                    alt_path = os.path.join(sp, 'tree_sitter_languages', 'languages.so')
                    if os.path.exists(alt_path):
                        tsl_path = alt_path
                        break
                else:
                    raise FileNotFoundError(
                        'tree_sitter_languages shared library (languages.so) not found '
                        'in user or system site-packages'
                    )
            cls._lib = ctypes.CDLL(tsl_path)
return cls._lib
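    # Note: tree-sitter-languages ships all grammars in a single shared library that
    # exports one C symbol per language (tree_sitter_python, tree_sitter_javascript,
    # ...), each returning a pointer to that grammar's TSLanguage. _get_language
    # below resolves those symbols through ctypes.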
@classmethod
def _get_language(cls, language: str):
"""Get or create a tree-sitter Language object"""
if language not in cls._language_cache:
lib = cls._load_lib()
            # The exported symbol is simply named after the language, e.g.
            # tree_sitter_python, tree_sitter_tsx, tree_sitter_cpp.
            symbol_name = f'tree_sitter_{language}'
lang_func = getattr(lib, symbol_name)
lang_func.restype = ctypes.c_void_p
lang_ptr = lang_func()
cls._language_cache[language] = tree_sitter.Language(lang_ptr)
return cls._language_cache[language]
@classmethod
def _get_parser(cls, language: str):
"""Get or create a parser for the language"""
if language not in cls._parser_cache:
parser = tree_sitter.Parser()
parser.language = cls._get_language(language)
cls._parser_cache[language] = parser
return cls._parser_cache[language]
def parse(self, code: str, file_path: Optional[Path] = None) -> tree_sitter.Tree:
"""Parse source code and return parse tree"""
if isinstance(code, str):
code = code.encode('utf-8')
return self.parser.parse(code)
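    # Sketch of how the returned tree is typically consumed (illustrative; the
    # root node type depends on the grammar, e.g. 'module' for Python):
    #
    #     tree = parser.parse('def f():\n    pass\n')
    #     tree.root_node.type        # 'module'
    #     tree.root_node.children    # top-level statement nodes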
def extract_symbols(self, code: str) -> List[Dict[str, Any]]:
"""Extract all symbols from code"""
symbols = []
symbols.extend(self.extract_functions(code))
symbols.extend(self.extract_classes(code))
# Add imports as symbols too for dependency tracking
for imp in self.extract_imports(code):
symbols.append({
'name': imp,
'type': 'import',
'line_num': 1, # Will be updated with actual line
'signature': f'import {imp}',
'docstring': ''
})
return symbols
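    # Function and class entries carry the dicts built by _extract_function_info /
    # _extract_class_info (name, type, line_num, end_line, signature, docstring, ...);
    # import entries only record the module name with a placeholder line number.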
def extract_functions(self, code: str) -> List[Dict[str, Any]]:
"""Extract function definitions from code"""
tree = self.parse(code)
functions = []
code_bytes = code.encode('utf-8')
# Language-specific function patterns
if self.language in ['javascript', 'typescript', 'tsx']:
patterns = [
('function_declaration', 'name'),
('function', 'name'),
('arrow_function', None), # Name comes from parent
('method_definition', 'name'),
('generator_function_declaration', 'name'),
]
elif self.language == 'python':
patterns = [
('function_definition', 'name'),
]
elif self.language == 'go':
patterns = [
('function_declaration', 'name'),
('method_declaration', 'name'),
]
        elif self.language == 'rust':
            patterns = [
                ('function_item', 'name'),
            ]
        elif self.language == 'java':
            patterns = [
                ('method_declaration', 'name'),
            ]
elif self.language in ['c', 'cpp']:
patterns = [
('function_definition', 'declarator'),
]
else:
patterns = []
        def walk_tree(node, parent_name=None):
            """Walk the tree looking for function nodes"""
            node_type = node.type
            for pattern, name_field in patterns:
                if node_type == pattern:
                    # Anonymous arrow functions assigned to a variable are handled
                    # via their variable_declarator parent below; skip them here to
                    # avoid reporting the same function twice.
                    if pattern == 'arrow_function' and node.parent and node.parent.type == 'variable_declarator':
                        break
                    func_info = self._extract_function_info(
                        node, code_bytes, name_field, parent_name
                    )
                    if func_info:
                        functions.append(func_info)
                    break
# Check for arrow functions assigned to variables
if node_type == 'variable_declarator' and self.language in ['javascript', 'typescript', 'tsx']:
# Check if this is an arrow function assignment
for child in node.children:
if child.type == 'arrow_function':
# Get the variable name
name_node = node.child_by_field_name('name')
if name_node:
name = code_bytes[name_node.start_byte:name_node.end_byte].decode('utf-8')
func_info = self._extract_function_info(child, code_bytes, None, name)
if func_info:
functions.append(func_info)
# Recurse into children
for child in node.children:
walk_tree(child, parent_name)
walk_tree(tree.root_node)
return functions
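    # Rough sketch of the expected output for a small JavaScript input (illustrative,
    # fields abbreviated):
    #
    #     parser = TreeSitterUnifiedParser('javascript')
    #     parser.extract_functions('const add = (a, b) => a + b;')
    #     # -> [{'name': 'add', 'type': 'function', 'line_num': 1, ...}]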
def _extract_function_info(self, node, code_bytes, name_field=None, given_name=None):
"""Extract information about a function node"""
# Get function name
if given_name:
name = given_name
        elif name_field:
            name_node = node.child_by_field_name(name_field)
            # C/C++ hand back a declarator subtree here; unwrap nested declarators
            # (pointer_declarator, function_declarator, ...) down to the identifier.
            while name_node is not None and name_node.type.endswith('declarator'):
                inner = name_node.child_by_field_name('declarator')
                if inner is None:
                    break
                name_node = inner
            if name_node:
                name = code_bytes[name_node.start_byte:name_node.end_byte].decode('utf-8')
else:
# Try to find identifier child
for child in node.children:
if 'identifier' in child.type or child.type == 'property_identifier':
name = code_bytes[child.start_byte:child.end_byte].decode('utf-8')
break
else:
name = '<anonymous>'
else:
name = '<anonymous>'
# Get full signature
signature = code_bytes[node.start_byte:node.end_byte].decode('utf-8')
if len(signature) > 200:
# Get just the first line for long functions
signature = signature.split('\n')[0] + '...'
# Get parameters
params = self._extract_parameters(node, code_bytes)
# Get return type if available
return_type = self._extract_return_type(node, code_bytes)
return {
'name': name,
'type': 'function',
'line_num': node.start_point[0] + 1,
'end_line': node.end_point[0] + 1,
'signature': signature,
'docstring': self._extract_docstring(node, code_bytes),
'parameters': params,
'return_type': return_type
}
def _extract_parameters(self, node, code_bytes):
"""Extract function parameters"""
params = []
# Find parameters node
params_node = node.child_by_field_name('parameters')
if not params_node:
# Try to find formal_parameters or parameter_list
for child in node.children:
if 'parameter' in child.type:
params_node = child
break
if params_node:
for child in params_node.children:
if 'identifier' in child.type or 'parameter' in child.type:
param_text = code_bytes[child.start_byte:child.end_byte].decode('utf-8')
# Parse out name and type if present
if ':' in param_text:
name, type_str = param_text.split(':', 1)
params.append({'name': name.strip(), 'type': type_str.strip()})
else:
params.append({'name': param_text.strip(), 'type': None})
return params
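    # Each parameter comes back as a small dict: 'x: int' -> {'name': 'x', 'type': 'int'},
    # while an untyped 'x' -> {'name': 'x', 'type': None}.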
def _extract_return_type(self, node, code_bytes):
"""Extract function return type"""
# Look for return_type or type_annotation field
return_node = node.child_by_field_name('return_type')
if return_node:
return code_bytes[return_node.start_byte:return_node.end_byte].decode('utf-8')
# For TypeScript/Flow, look for type annotations
for child in node.children:
if child.type == 'type_annotation':
# Skip the ':' and get the type
type_text = code_bytes[child.start_byte:child.end_byte].decode('utf-8')
if type_text.startswith(':'):
return type_text[1:].strip()
return type_text
return None
def _extract_docstring(self, node, code_bytes):
"""Extract docstring or JSDoc comment for a node"""
# Look for comment nodes before the function
prev_sibling = node.prev_sibling
if prev_sibling and 'comment' in prev_sibling.type:
comment = code_bytes[prev_sibling.start_byte:prev_sibling.end_byte].decode('utf-8')
# Clean up comment markers
if comment.startswith('/**'):
# JSDoc style
comment = comment[3:-2] if comment.endswith('*/') else comment[3:]
return comment.strip()
elif comment.startswith('//'):
return comment[2:].strip()
# For Python, look for docstring as first statement
if self.language == 'python' and node.type == 'function_definition':
body = node.child_by_field_name('body')
if body and body.children:
first_stmt = body.children[0]
if first_stmt.type == 'expression_statement':
for child in first_stmt.children:
if child.type == 'string':
docstring = code_bytes[child.start_byte:child.end_byte].decode('utf-8')
# Remove quotes
return docstring.strip('"\'')
return ""
def extract_classes(self, code: str) -> List[Dict[str, Any]]:
"""Extract class definitions from code"""
tree = self.parse(code)
classes = []
code_bytes = code.encode('utf-8')
# Language-specific class patterns
if self.language in ['javascript', 'typescript', 'tsx']:
patterns = ['class_declaration', 'class']
elif self.language == 'python':
patterns = ['class_definition']
elif self.language == 'go':
patterns = ['type_declaration'] # Go uses type for structs
elif self.language == 'rust':
patterns = ['struct_item', 'impl_item']
elif self.language == 'java':
patterns = ['class_declaration']
elif self.language == 'cpp':
patterns = ['class_specifier', 'struct_specifier']
elif self.language == 'c':
patterns = ['struct_specifier']
else:
patterns = []
def walk_tree(node):
if node.type in patterns:
class_info = self._extract_class_info(node, code_bytes)
if class_info:
classes.append(class_info)
for child in node.children:
walk_tree(child)
walk_tree(tree.root_node)
return classes
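    # Rough sketch of the expected output for a small Python input (illustrative):
    #
    #     parser = TreeSitterUnifiedParser('python')
    #     parser.extract_classes('class Foo(Base):\n    pass\n')
    #     # -> [{'name': 'Foo', 'type': 'class', 'line_num': 1, ...}]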
def _extract_class_info(self, node, code_bytes):
"""Extract information about a class node"""
# Get class name
name_node = node.child_by_field_name('name')
if name_node:
name = code_bytes[name_node.start_byte:name_node.end_byte].decode('utf-8')
else:
# Try to find identifier
for child in node.children:
if 'identifier' in child.type:
name = code_bytes[child.start_byte:child.end_byte].decode('utf-8')
break
else:
name = '<anonymous>'
# Get inheritance info
inherits = []
superclass = node.child_by_field_name('superclass')
if superclass:
inherits.append(code_bytes[superclass.start_byte:superclass.end_byte].decode('utf-8'))
# For implements/extends
for child in node.children:
if child.type in ['extends_clause', 'implements_clause']:
for subchild in child.children:
if 'identifier' in subchild.type or 'type' in subchild.type:
inherits.append(
code_bytes[subchild.start_byte:subchild.end_byte].decode('utf-8')
)
return {
'name': name,
'type': 'class',
'line_num': node.start_point[0] + 1,
'end_line': node.end_point[0] + 1,
'signature': f'class {name}',
'docstring': self._extract_docstring(node, code_bytes),
'inherits_from': inherits
}
def extract_imports(self, code: str) -> List[str]:
"""Extract import statements from code"""
tree = self.parse(code)
imports = []
code_bytes = code.encode('utf-8')
# Language-specific import patterns
        if self.language in ['javascript', 'typescript', 'tsx']:
def walk_tree(node):
if node.type == 'import_statement':
# Extract the module being imported
for child in node.children:
if child.type == 'string':
module = code_bytes[child.start_byte:child.end_byte].decode('utf-8')
imports.append(module.strip('"\''))
elif node.type == 'call_expression':
# Check for require() calls
func = node.child_by_field_name('function')
if func and code_bytes[func.start_byte:func.end_byte].decode('utf-8') == 'require':
args = node.child_by_field_name('arguments')
if args:
for child in args.children:
if child.type == 'string':
module = code_bytes[child.start_byte:child.end_byte].decode('utf-8')
imports.append(module.strip('"\''))
for child in node.children:
walk_tree(child)
elif self.language == 'python':
def walk_tree(node):
if node.type in ['import_statement', 'import_from_statement']:
# Get the module name
module_node = node.child_by_field_name('module_name')
if module_node:
module = code_bytes[module_node.start_byte:module_node.end_byte].decode('utf-8')
imports.append(module)
else:
# Try to extract from the full import statement
import_text = code_bytes[node.start_byte:node.end_byte].decode('utf-8')
# Parse out the module name
if import_text.startswith('from '):
parts = import_text.split()
if len(parts) > 1:
imports.append(parts[1])
elif import_text.startswith('import '):
parts = import_text.split()
if len(parts) > 1:
imports.append(parts[1].split(',')[0])
for child in node.children:
walk_tree(child)
elif self.language == 'go':
def walk_tree(node):
if node.type == 'import_declaration':
for child in node.children:
if child.type == 'import_spec':
# Extract package path
for subchild in child.children:
if subchild.type == 'interpreted_string_literal':
pkg = code_bytes[subchild.start_byte:subchild.end_byte].decode('utf-8')
imports.append(pkg.strip('"'))
for child in node.children:
walk_tree(child)
else:
# Generic import extraction
def walk_tree(node):
if 'import' in node.type:
import_text = code_bytes[node.start_byte:node.end_byte].decode('utf-8')
imports.append(import_text)
for child in node.children:
walk_tree(child)
walk_tree(tree.root_node)
return imports
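    # Rough sketch: for 'import os\nfrom pathlib import Path' the Python branch
    # above is expected to yield ['os', 'pathlib'] -- module names only, not the
    # imported symbols.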
    def extract_type_info(self, code: str, symbol_name: Optional[str] = None) -> Dict[str, Any]:
"""Extract type information for symbols"""
tree = self.parse(code)
type_info = {
'parameters': [],
'return_type': None,
'exceptions_raised': [],
'type_annotations': {}
}
# This is handled per-function in extract_functions
# Aggregate type info from all functions if no specific symbol
functions = self.extract_functions(code)
if symbol_name:
for func in functions:
if func['name'] == symbol_name:
type_info['parameters'] = func.get('parameters', [])
type_info['return_type'] = func.get('return_type')
break
else:
# Aggregate all type info
all_params = []
all_returns = []
for func in functions:
all_params.extend(func.get('parameters', []))
if func.get('return_type'):
all_returns.append(func['return_type'])
type_info['parameters'] = all_params
type_info['return_type'] = ', '.join(set(all_returns)) if all_returns else None
return type_info
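    # With symbol_name given, the result describes that single function; without it,
    # parameters and return types from every function in the file are pooled into
    # one summary dict.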
def extract_dependencies(self, code: str) -> Dict[str, Any]:
"""Extract dependency information"""
dependencies = {
'imports': self.extract_imports(code),
'calls': [],
'inherits_from': []
}
# Extract inheritance from classes
classes = self.extract_classes(code)
for cls in classes:
if cls.get('inherits_from'):
dependencies['inherits_from'].extend(cls['inherits_from'])
# Extract function calls (simplified - just function names called)
tree = self.parse(code)
code_bytes = code.encode('utf-8')
        def walk_for_calls(node):
            # tree-sitter-python names call nodes 'call'; the JS/C-family grammars
            # use 'call_expression'
            if node.type in ('call_expression', 'call'):
                func_node = node.child_by_field_name('function')
                if func_node:
                    if func_node.type == 'identifier':
                        call_name = code_bytes[func_node.start_byte:func_node.end_byte].decode('utf-8')
                        dependencies['calls'].append(call_name)
                    elif func_node.type in ('member_expression', 'attribute'):
                        # Handle method calls like obj.method()
                        prop = (func_node.child_by_field_name('property')
                                or func_node.child_by_field_name('attribute'))
                        if prop:
                            call_name = code_bytes[prop.start_byte:prop.end_byte].decode('utf-8')
                            dependencies['calls'].append(call_name)
for child in node.children:
walk_for_calls(child)
walk_for_calls(tree.root_node)
# Remove duplicates
dependencies['calls'] = list(set(dependencies['calls']))
dependencies['inherits_from'] = list(set(dependencies['inherits_from']))
return dependencies
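    # Rough sketch of the returned shape (illustrative values):
    #
    #     {'imports': ['react'], 'calls': ['fetch', 'render'], 'inherits_from': ['Component']}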
def extract_documentation(self, code: str) -> Dict[str, Any]:
"""Extract documentation and comments"""
tree = self.parse(code)
code_bytes = code.encode('utf-8')
docs = {
'todo_items': [],
'inline_comments': [],
'docstrings': []
}
# Walk tree looking for comments
def walk_for_comments(node):
if 'comment' in node.type:
comment_text = code_bytes[node.start_byte:node.end_byte].decode('utf-8')
# Check for TODO/FIXME/XXX/HACK/NOTE
todo_pattern = r'(TODO|FIXME|XXX|HACK|NOTE)[:\s]+(.+)'
match = re.search(todo_pattern, comment_text, re.IGNORECASE)
if match:
docs['todo_items'].append({
'type': match.group(1).upper(),
'text': match.group(2).strip(),
'line': node.start_point[0] + 1
})
else:
# Regular comment
cleaned = comment_text.strip('/*# \n')
if cleaned:
docs['inline_comments'].append(cleaned)
# Check if it's a docstring-style comment
if comment_text.startswith('/**') or comment_text.startswith('"""'):
docs['docstrings'].append(comment_text.strip('/*" \n'))
for child in node.children:
walk_for_comments(child)
walk_for_comments(tree.root_node)
return docs
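# Sketch of typical usage, assuming tree-sitter-languages is installed and its
# shared library can be located by _load_lib():
#
#     parser = TreeSitterUnifiedParser('python')
#     src = 'def greet(name: str) -> str:\n    return "hi " + name\n'
#     parser.extract_functions(src)     # [{'name': 'greet', 'return_type': 'str', ...}]
#     parser.extract_dependencies(src)  # {'imports': [], 'calls': [], 'inherits_from': []}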