// utils.rs (14.4 kB)
//! Utility functions for name extraction and tree-sitter node handling
use tree_sitter::Node;
/// Namespace type grouping helpers that extract names and identifiers from
/// tree-sitter nodes.
///
/// The type holds no state; every helper is an associated function that is
/// called as `NameExtractor::helper(..)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NameExtractor;
impl NameExtractor {
    /// Return the text of the first identifier found in `node`'s subtree.
    ///
    /// The lookup is delegated to [`Self::find_identifier_recursive`]. When it
    /// finds nothing, the result is `Ok` holding an empty string; this function
    /// never produces an `Err`.
    pub fn extract_name_from_node(node: Node<'_>, content: &str) -> Result<String, String> {
        let name = Self::find_identifier_recursive(node, content).unwrap_or_default();
        Ok(name)
    }
    /// Public entry point for the identifier search.
    ///
    /// Forwards `node` and `content` unchanged to the private
    /// `find_identifier_recursive_impl` and returns its result: the text of the
    /// first identifier-like node found, or `None` when there is none.
    pub fn find_identifier_recursive(node: Node<'_>, content: &str) -> Option<String> {
        Self::find_identifier_recursive_impl(node, content)
    }
    /// Internal implementation of recursive identifier finding
    fn find_identifier_recursive_impl(node: Node<'_>, content: &str) -> Option<String> {
        // Check if this node is an identifier
        match node.kind() {
            "identifier" | "property_identifier" | "type_identifier" => {
                let start_byte = node.start_byte();
                let end_byte = node.end_byte();
                if let Some(name) = content.get(start_byte..end_byte) {
                    return Some(name.to_string());
                }
            }
            _ => {}
        }
        // Search children recursively (but limit depth to avoid infinite recursion)
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if let Some(name) = Self::find_identifier_recursive_impl(child, content) {
                return Some(name);
            }
        }
        None
    }
    /// Find a child node by its kind
    pub fn find_child_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if child.kind() == kind {
                return Some(child);
            }
            // Also search recursively in children
            if let Some(found) = Self::find_child_by_kind(child, kind) {
                return Some(found);
            }
        }
        None
    }
    /// Gather identifier texts from `node` and everything beneath it.
    ///
    /// Creates an empty vector, lets `collect_identifiers_recursive` fill it,
    /// and returns it. The vector is empty when nothing was collected.
    pub fn collect_identifiers_from_node(node: Node<'_>, content: &str) -> Vec<String> {
        let mut found = Vec::new();
        Self::collect_identifiers_recursive(node, content, &mut found);
        found
    }
    /// Recursively collect identifiers
    fn collect_identifiers_recursive(node: Node<'_>, content: &str, identifiers: &mut Vec<String>) {
        if node.kind() == "identifier" {
            let start_byte = node.start_byte();
            let end_byte = node.end_byte();
            if let Some(name) = content.get(start_byte..end_byte) {
                identifiers.push(name.to_string());
            }
        }
        
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            Self::collect_identifiers_recursive(child, content, identifiers);
        }
    }
    /// Borrow the slice of `content` covered by `node`'s byte range.
    ///
    /// Returns `None` when `content.get` rejects the range from the node's
    /// start byte to its end byte.
    pub fn extract_node_text<'a>(node: Node<'_>, content: &'a str) -> Option<&'a str> {
        let span = node.start_byte()..node.end_byte();
        content.get(span)
    }
    /// Report whether any direct child of `node` is an identifier-like node.
    ///
    /// Only immediate children are inspected; deeper descendants are ignored.
    /// The kinds that count are `identifier`, `property_identifier` and
    /// `type_identifier`.
    pub fn is_named_construct(node: Node<'_>) -> bool {
        let mut cursor = node.walk();
        // Bound to a local so the iterator borrowing `cursor` is finished
        // before `cursor` goes out of scope.
        let has_identifier_child = node.children(&mut cursor).any(|child| {
            matches!(child.kind(), "identifier" | "property_identifier" | "type_identifier")
        });
        has_identifier_child
    }
    /// Return `(start_row, start_column, end_row, end_column)` for `node`.
    ///
    /// Each component is the corresponding tree-sitter position value narrowed
    /// with `as u32` and then incremented by one.
    pub fn get_position_info(node: Node<'_>) -> (u32, u32, u32, u32) {
        let start = node.start_position();
        let end = node.end_position();
        (
            start.row as u32 + 1,
            start.column as u32 + 1,
            end.row as u32 + 1,
            end.column as u32 + 1,
        )
    }
    /// Check if a node or any of its children have errors
    pub fn has_syntax_errors(node: Node<'_>) -> bool {
        if node.is_error() || node.is_missing() {
            return true;
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if Self::has_syntax_errors(child) {
                return true;
            }
        }
        false
    }
    /// Count how many direct children of `node` have a kind equal to `kind`.
    ///
    /// Only immediate children are considered; descendants further down are
    /// not counted.
    pub fn count_children_by_kind(node: Node<'_>, kind: &str) -> usize {
        let mut cursor = node.walk();
        let mut total = 0;
        for child in node.children(&mut cursor) {
            if child.kind() == kind {
                total += 1;
            }
        }
        total
    }
    /// Collect the immediate children of `node` into a vector, in the order
    /// tree-sitter yields them.
    pub fn get_direct_children(node: Node<'_>) -> Vec<Node<'_>> {
        let mut cursor = node.walk();
        let mut children = Vec::new();
        children.extend(node.children(&mut cursor));
        children
    }
    /// Check whether `name` has the shape of an identifier.
    ///
    /// The first character must be alphabetic or `_`; every following
    /// character must be alphanumeric or `_`. The empty string is rejected.
    /// Character classes follow `char::is_alphabetic` / `char::is_alphanumeric`,
    /// so non-ASCII letters are accepted.
    pub fn is_valid_identifier(name: &str) -> bool {
        let mut chars = name.chars();
        match chars.next() {
            Some(first) if first.is_alphabetic() || first == '_' => {
                chars.all(|c| c.is_alphanumeric() || c == '_')
            }
            _ => false,
        }
    }
    /// Return a copy of `name` with every character that is neither
    /// alphanumeric nor `_` removed.
    ///
    /// The remaining characters keep their original order. The result may be
    /// empty, and it is not checked against `is_valid_identifier` (for example
    /// a leading digit is kept).
    pub fn sanitize_name(name: &str) -> String {
        let mut cleaned = name.to_owned();
        cleaned.retain(|c| c.is_alphanumeric() || c == '_');
        cleaned
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ParserManager;
    /// Parse a one-function JavaScript snippet and return the resulting tree
    /// together with the source text it was parsed from.
    fn create_test_tree_and_code() -> (tree_sitter::Tree, String) {
        let source = String::from("function calculateTotal(price, tax) { return price + tax; }");
        let mut parser_manager = ParserManager::new().unwrap();
        let tree = parser_manager.parse(&source, "javascript").unwrap();
        (tree, source)
    }
    /// Extracting a name from the function declaration yields the function name.
    #[test]
    fn test_extract_name_from_node() {
        let (tree, code) = create_test_tree_and_code();
        let root = tree.root_node();

        // Locate the function declaration among the root's direct children.
        let mut cursor = root.walk();
        let function_node = root
            .children(&mut cursor)
            .find(|child| child.kind() == "function_declaration")
            .unwrap();

        let extracted = NameExtractor::extract_name_from_node(function_node, &code);
        assert_eq!(extracted, Ok(String::from("calculateTotal")));
    }
    /// Searching from the root finds the function name as the first identifier.
    #[test]
    fn test_find_identifier_recursive() {
        let (tree, code) = create_test_tree_and_code();

        let found = NameExtractor::find_identifier_recursive(tree.root_node(), &code);
        assert_eq!(found.as_deref(), Some("calculateTotal"));
    }
    /// A present kind is found and reported with that kind; an unknown kind
    /// yields `None`.
    #[test]
    fn test_find_child_by_kind() {
        let (tree, _code) = create_test_tree_and_code();
        let root = tree.root_node();

        let function_kind = NameExtractor::find_child_by_kind(root, "function_declaration")
            .map(|found| found.kind());
        assert_eq!(function_kind, Some("function_declaration"));

        assert!(NameExtractor::find_child_by_kind(root, "nonexistent_kind").is_none());
    }
    /// Collecting from the root yields the function name and both parameters.
    #[test]
    fn test_collect_identifiers_from_node() {
        let (tree, code) = create_test_tree_and_code();

        let identifiers = NameExtractor::collect_identifiers_from_node(tree.root_node(), &code);

        for expected in ["calculateTotal", "price", "tax"] {
            assert!(identifiers.contains(&expected.to_string()));
        }
    }
    /// The root node's text is the entire source string.
    #[test]
    fn test_extract_node_text() {
        let (tree, code) = create_test_tree_and_code();

        let text = NameExtractor::extract_node_text(tree.root_node(), &code);
        assert_eq!(text, Some(code.as_str()));
    }
    /// `is_named_construct` looks at direct children only.
    #[test]
    fn test_is_named_construct() {
        let (tree, _code) = create_test_tree_and_code();
        let root = tree.root_node();

        // The program root's only direct child is the `function_declaration`
        // statement, which is not an identifier node, so the root itself is
        // not a named construct.
        assert!(!NameExtractor::is_named_construct(root));

        // The function declaration has its name identifier as a direct child.
        let function_node = NameExtractor::find_child_by_kind(root, "function_declaration")
            .expect("fixture contains a function declaration");
        assert!(NameExtractor::is_named_construct(function_node));
    }
    /// The root starts at line 1, column 1 and does not end before it starts.
    #[test]
    fn test_get_position_info() {
        let (tree, _code) = create_test_tree_and_code();

        let (start_row, start_col, end_row, end_col) =
            NameExtractor::get_position_info(tree.root_node());

        assert_eq!((start_row, start_col), (1, 1));
        assert!(start_row <= end_row);
        assert!(start_col <= end_col);
    }
    /// Well-formed code reports no errors; malformed code reports errors.
    #[test]
    fn test_has_syntax_errors() {
        let mut manager = ParserManager::new().unwrap();

        // (source, whether a syntax error is expected)
        let cases = [
            ("function test() { return 42; }", false),
            ("function {{{ invalid syntax", true),
        ];

        for (source, expect_errors) in cases {
            let tree = manager.parse(source, "javascript").unwrap();
            assert_eq!(NameExtractor::has_syntax_errors(tree.root_node()), expect_errors);
        }
    }
    /// The root has exactly one function declaration and no unknown kinds.
    #[test]
    fn test_count_children_by_kind() {
        let (tree, _code) = create_test_tree_and_code();
        let root = tree.root_node();

        for (kind, expected) in [("function_declaration", 1), ("nonexistent", 0)] {
            assert_eq!(NameExtractor::count_children_by_kind(root, kind), expected);
        }
    }
    /// The root has at least one child, and one of them is the function
    /// declaration.
    #[test]
    fn test_get_direct_children() {
        let (tree, _code) = create_test_tree_and_code();

        let children = NameExtractor::get_direct_children(tree.root_node());

        assert!(!children.is_empty());
        assert!(children.iter().any(|child| child.kind() == "function_declaration"));
    }
    /// Accepts identifier-shaped names and rejects everything else.
    #[test]
    fn test_is_valid_identifier() {
        let accepted = [
            "hello",
            "_private",
            "camelCase",
            "snake_case",
            "PascalCase",
            "a123",
            "_",
        ];
        for name in accepted {
            assert!(NameExtractor::is_valid_identifier(name));
        }

        let rejected = [
            "",
            "123abc",
            "hello-world",
            "hello.world",
            "hello world",
            "@special",
        ];
        for name in rejected {
            assert!(!NameExtractor::is_valid_identifier(name));
        }
    }
    /// Characters other than alphanumerics and `_` are dropped.
    #[test]
    fn test_sanitize_name() {
        // (input, expected sanitized output)
        let cases = [
            ("hello", "hello"),
            ("hello-world", "helloworld"),
            ("hello.world", "helloworld"),
            ("hello world", "helloworld"),
            ("hello123", "hello123"),
            ("hello_world", "hello_world"),
            ("@#$%", ""),
        ];
        for (input, expected) in cases {
            assert_eq!(NameExtractor::sanitize_name(input), expected);
        }
    }
    /// Identifiers are collected from a class with a constructor and a method.
    #[test]
    fn test_complex_javascript_structure() {
        let complex_code = r#"
            class Calculator {
                constructor(name) {
                    this.name = name;
                }
                
                add(a, b) {
                    return a + b;
                }
            }
        "#;

        let mut manager = ParserManager::new().unwrap();
        let tree = manager.parse(complex_code, "javascript").unwrap();

        let identifiers =
            NameExtractor::collect_identifiers_from_node(tree.root_node(), complex_code);

        // Class name, constructor, method name, and parameter names.
        for expected in ["Calculator", "constructor", "name", "add", "a", "b"] {
            assert!(identifiers.contains(&expected.to_string()));
        }
    }
    /// Empty source yields an empty name and no identifiers.
    #[test]
    fn test_empty_node_handling() {
        let source = "";
        let mut manager = ParserManager::new().unwrap();
        let tree = manager.parse(source, "javascript").unwrap();
        let root = tree.root_node();

        let extracted = NameExtractor::extract_name_from_node(root, source);
        assert_eq!(extracted, Ok(String::new()));

        let identifiers = NameExtractor::collect_identifiers_from_node(root, source);
        assert!(identifiers.is_empty());
    }
    /// Interface and method names are collected from TypeScript source.
    #[test]
    fn test_typescript_types() {
        let ts_code = "interface UserService { getName(): string; }";
        let mut manager = ParserManager::new().unwrap();
        let tree = manager.parse(ts_code, "typescript").unwrap();

        let identifiers = NameExtractor::collect_identifiers_from_node(tree.root_node(), ts_code);

        // Interface name and method name.
        for expected in ["UserService", "getName"] {
            assert!(identifiers.contains(&expected.to_string()));
        }
    }
}