"""
Python code analyzer for Neo4j graph extraction.
Analyzes Python source files using AST to extract:
- Classes with attributes and methods
- Functions with parameters and return types
- Import relationships
"""
import ast
import logging
from pathlib import Path
from typing import Any
from src.core.exceptions import AnalysisError, ParsingError
# Configure logging
logger = logging.getLogger(__name__)
class Neo4jCodeAnalyzer:
"""Analyzes code for direct Neo4j insertion"""
def __init__(self) -> None:
# External modules to ignore
self.external_modules = {
# Python standard library
"os",
"sys",
"json",
"logging",
"datetime",
"pathlib",
"typing",
"collections",
"asyncio",
"subprocess",
"ast",
"re",
"string",
"urllib",
"http",
"email",
"time",
"uuid",
"hashlib",
"base64",
"itertools",
"functools",
"operator",
"contextlib",
"copy",
"pickle",
"tempfile",
"shutil",
"glob",
"fnmatch",
"io",
"codecs",
"locale",
"platform",
"socket",
"ssl",
"threading",
"queue",
"multiprocessing",
"concurrent",
"warnings",
"traceback",
"inspect",
"importlib",
"pkgutil",
"types",
"weakref",
"gc",
"dataclasses",
"enum",
"abc",
"numbers",
"decimal",
"fractions",
"math",
"cmath",
"random",
"statistics",
# Common third-party libraries
"requests",
"urllib3",
"httpx",
"aiohttp",
"flask",
"django",
"fastapi",
"pydantic",
"sqlalchemy",
"alembic",
"psycopg2",
"pymongo",
"redis",
"celery",
"pytest",
"unittest",
"mock",
"faker",
"factory",
"hypothesis",
"numpy",
"pandas",
"matplotlib",
"seaborn",
"scipy",
"sklearn",
"torch",
"tensorflow",
"keras",
"opencv",
"pillow",
"boto3",
"botocore",
"azure",
"google",
"openai",
"anthropic",
"langchain",
"transformers",
"huggingface_hub",
"click",
"typer",
"rich",
"colorama",
"tqdm",
"python-dotenv",
"pyyaml",
"toml",
"configargparse",
"marshmallow",
"attrs",
"dataclasses-json",
"jsonschema",
"cerberus",
"voluptuous",
"schema",
"jinja2",
"mako",
"cryptography",
"bcrypt",
"passlib",
"jwt",
"authlib",
"oauthlib",
}
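    # Matching note: _is_likely_internal compares the first dotted segment
    # of an import against this set, so every entry must be a top-level
    # importable name (hence "cv2"/"PIL", not "opencv"/"pillow").
    # Illustrative check with a hypothetical import string:
    #
    #     "numpy.linalg".split(".")[0] in self.external_modules  # True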
def analyze_python_file(
self,
file_path: Path,
repo_root: Path,
project_modules: set[str],
) -> dict[str, Any] | None:
"""Extract structure for direct Neo4j insertion"""
try:
with file_path.open(encoding="utf-8") as f:
content = f.read()
tree = ast.parse(content)
relative_path = str(file_path.relative_to(repo_root))
module_name = self._get_importable_module_name(
repo_root,
relative_path,
)
            # Extract structure
            classes: list[dict[str, Any]] = []
            functions: list[dict[str, Any]] = []
            imports: list[str] = []
            # Precompute the direct children of every class body so methods
            # are not double-counted as module-level functions below
            class_body_nodes = {
                child
                for cls_node in ast.walk(tree)
                if isinstance(cls_node, ast.ClassDef)
                for child in cls_node.body
            }
            for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
# Extract class with its methods and comprehensive attributes
methods = []
for item in node.body:
if isinstance(
item, ast.FunctionDef | ast.AsyncFunctionDef
) and not item.name.startswith("_"):
# Public methods only
# Extract comprehensive parameter info
params = self._extract_function_parameters(item)
# Get return type annotation
return_type = (
self._get_name(item.returns) if item.returns else "Any"
)
# Create detailed parameter list for Neo4j storage
params_detailed = []
for p in params:
param_str = f"{p['name']}:{p['type']}"
if p["optional"] and p["default"] is not None:
param_str += f"={p['default']}"
elif p["optional"]:
param_str += "=None"
if p["kind"] != "positional":
param_str = f"[{p['kind']}] {param_str}"
params_detailed.append(param_str)
# Backwards compatibility arg list
arg_list = [
arg.arg for arg in item.args.args if arg.arg != "self"
]
methods.append(
{
"name": item.name,
"params": params,
"params_detailed": params_detailed,
"return_type": return_type,
"args": arg_list,
}
)
# Use comprehensive attribute extraction
attributes = self._extract_class_attributes(node)
classes.append(
{
"name": node.name,
"full_name": f"{module_name}.{node.name}",
"methods": methods,
"attributes": attributes,
}
)
                elif isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
                    # Only public functions that are not methods; note that
                    # defs nested inside other functions are still visited
                    # by ast.walk and treated as module-level here
                    is_in_class = node in class_body_nodes
                    if not is_in_class and not node.name.startswith("_"):
# Extract comprehensive parameter info
params = self._extract_function_parameters(node)
# Get return type annotation
return_type = (
self._get_name(node.returns) if node.returns else "Any"
)
# Create detailed parameter list for Neo4j storage
params_detailed = []
for p in params:
param_str = f"{p['name']}:{p['type']}"
if p["optional"] and p["default"] is not None:
param_str += f"={p['default']}"
elif p["optional"]:
param_str += "=None"
if p["kind"] != "positional":
param_str = f"[{p['kind']}] {param_str}"
params_detailed.append(param_str)
# Simple format for backwards compatibility
params_list = [f"{p['name']}:{p['type']}" for p in params]
# Backwards compatibility args list
arg_list = [arg.arg for arg in node.args.args]
functions.append(
{
"name": node.name,
"full_name": f"{module_name}.{node.name}",
"params": params,
"params_detailed": params_detailed,
"params_list": params_list,
"return_type": return_type,
"args": arg_list,
}
)
                elif isinstance(node, ast.Import | ast.ImportFrom):
                    # Track internal imports only
                    if isinstance(node, ast.Import):
                        imports.extend(
                            alias.name
                            for alias in node.names
                            if self._is_likely_internal(
                                alias.name,
                                project_modules,
                            )
                        )
                    elif (
                        isinstance(node, ast.ImportFrom)
                        and node.module
                        and (
                            # ast stores relative imports in node.level,
                            # not as leading dots on node.module
                            node.level > 0
                            or self._is_likely_internal(
                                node.module,
                                project_modules,
                            )
                        )
                    ):
                        imports.append(node.module)
return {
"module_name": module_name,
"file_path": relative_path,
"classes": classes,
"functions": functions,
"imports": list(set(imports)), # Remove duplicates
"line_count": len(content.splitlines()),
}
except (SyntaxError, ValueError) as e:
logger.exception("Failed to parse Python file %s", file_path)
msg = f"Python parsing failed for {file_path}: {e}"
raise ParsingError(msg) from e
except ParsingError:
raise
except Exception:
logger.exception("Unexpected error analyzing %s", file_path)
return None
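    # Minimal usage sketch (hypothetical paths and project layout):
    #
    #     analyzer = Neo4jCodeAnalyzer()
    #     info = analyzer.analyze_python_file(
    #         Path("/repo/src/app/models.py"),
    #         Path("/repo"),
    #         project_modules={"app"},
    #     )
    #     if info:
    #         print(info["module_name"], len(info["classes"]))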
def _extract_class_attributes(
self,
class_node: ast.ClassDef,
) -> list[dict[str, Any]]:
"""Comprehensively extract all class attributes including:
- Instance attributes from __init__ methods (self.attr = value)
- Type annotated attributes in __init__ (self.attr: Type = value)
- Property decorators (@property def attr)
- Class-level attributes (both annotated and non-annotated)
- __slots__ definitions
- Dataclass and attrs field definitions
"""
attributes = []
attribute_stats = {
"total": 0,
"dataclass": 0,
"attrs": 0,
"class_vars": 0,
"properties": 0,
"slots": 0,
}
try:
# Check if class has dataclass or attrs decorators
is_dataclass = self._has_dataclass_decorator(class_node)
is_attrs_class = self._has_attrs_decorator(class_node)
# Extract class-level attributes
for item in class_node.body:
try:
# Type annotated class attributes
if isinstance(item, ast.AnnAssign) and isinstance(
item.target, ast.Name
):
if not item.target.id.startswith("_"):
# Check for ClassVar annotations before
# assuming dataclass/attrs semantics
is_class_var = self._is_class_var_annotation(
item.annotation,
)
# Determine classification based on
# ClassVar and framework
if is_class_var:
# ClassVar always class attribute,
# regardless of framework
is_instance_attr = False
is_class_attr = True
attribute_stats["class_vars"] += 1
elif is_dataclass or is_attrs_class:
# In dataclass/attrs, non-ClassVar
# annotations are instance variables
is_instance_attr = True
is_class_attr = False
if is_dataclass:
attribute_stats["dataclass"] += 1
if is_attrs_class:
attribute_stats["attrs"] += 1
else:
# Regular classes: annotations without
# assignment typically class-level
is_instance_attr = False
is_class_attr = True
type_hint = (
self._get_name(item.annotation)
if item.annotation
else "Any"
)
default_val = (
self._get_default_value(item.value)
if item.value
else None
)
attr_info = {
"name": item.target.id,
"type": type_hint,
"is_instance": is_instance_attr,
"is_class": is_class_attr,
"is_property": False,
"has_type_hint": True,
"default_value": default_val,
"line_number": item.lineno,
"from_dataclass": is_dataclass,
"from_attrs": is_attrs_class,
"is_class_var": is_class_var,
}
attributes.append(attr_info)
attribute_stats["total"] += 1
# Non-annotated class attributes
elif isinstance(item, ast.Assign):
# Check for __slots__
for target in item.targets:
if isinstance(target, ast.Name):
if target.id == "__slots__":
slots = self._extract_slots(item.value)
for slot_name in slots:
if not slot_name.startswith("_"):
attributes.append(
{
"name": slot_name,
"type": "Any",
# Slots are instance attributes
"is_instance": True,
"is_class": False,
"is_property": False,
"has_type_hint": False,
"default_value": None,
"line_number": item.lineno,
"from_slots": True,
"from_dataclass": False,
"from_attrs": False,
"is_class_var": False,
}
)
attribute_stats["slots"] += 1
attribute_stats["total"] += 1
elif not target.id.startswith("_"):
# Regular class attribute
inferred_type = (
self._infer_type_from_value(
item.value,
)
if item.value
else "Any"
)
default_val = (
self._get_default_value(
item.value,
)
if item.value
else None
)
attributes.append(
{
"name": target.id,
"type": inferred_type,
"is_instance": False,
"is_class": True,
"is_property": False,
"has_type_hint": False,
"default_value": default_val,
"line_number": item.lineno,
"from_dataclass": False,
"from_attrs": False,
"is_class_var": False,
}
)
attribute_stats["total"] += 1
# Properties with @property decorator
elif (
isinstance(item, ast.FunctionDef)
and not item.name.startswith("_")
and any(
isinstance(dec, ast.Name) and dec.id == "property"
for dec in item.decorator_list
)
):
return_type = (
self._get_name(item.returns) if item.returns else "Any"
)
# Properties accessed on instances but
# defined at class level
attributes.append(
{
"name": item.name,
"type": return_type,
"is_instance": False,
"is_class": False,
"is_property": True,
"has_type_hint": item.returns is not None,
"default_value": None,
"line_number": item.lineno,
"from_dataclass": False,
"from_attrs": False,
"is_class_var": False,
}
)
attribute_stats["properties"] += 1
attribute_stats["total"] += 1
except AnalysisError as e:
logger.debug(
"Failed to extract attribute from class body: %s",
e,
)
continue
except Exception:
logger.exception(
"Unexpected error extracting attribute",
)
continue
            # Extract attributes assigned on self in __init__ (returns an
            # empty list when the class defines no __init__)
init_attributes = self._extract_init_attributes(
class_node,
)
for init_attr in init_attributes:
# Ensure init attributes have framework metadata
init_attr.setdefault("from_dataclass", False)
init_attr.setdefault("from_attrs", False)
init_attr.setdefault("is_class_var", False)
attributes.extend(init_attributes)
attribute_stats["total"] += len(init_attributes)
            # Deduplicate by name, preferring richer definitions:
            # dataclass/attrs fields, then type-hinted entries, then
            # instance attributes, then properties
            unique_attributes: dict[str, dict[str, Any]] = {}
            for attr in attributes:
                name = attr["name"]
                existing = unique_attributes.get(name)
                if existing is None:
                    unique_attributes[name] = attr
                    continue
                attr_is_framework = bool(
                    attr.get("from_dataclass") or attr.get("from_attrs")
                )
                existing_is_framework = bool(
                    existing.get("from_dataclass") or existing.get("from_attrs")
                )
                should_replace = False
                if attr_is_framework and not existing_is_framework:
                    # Priority 1: dataclass/attrs fields take precedence
                    should_replace = True
                elif attr["has_type_hint"] and not existing["has_type_hint"]:
                    # Priority 2: type-hinted over non-hinted, unless that
                    # would demote a framework field
                    should_replace = not (
                        existing_is_framework and not attr_is_framework
                    )
                elif (
                    attr["is_instance"]
                    and not existing["is_instance"]
                    and attr["has_type_hint"] == existing["has_type_hint"]
                ):
                    # Priority 3: instance over class attributes when
                    # framework and type-hint status match
                    should_replace = existing_is_framework == attr_is_framework
                elif attr.get("is_property") and not existing.get("is_property"):
                    # Priority 4: properties win over plain attributes
                    should_replace = True
                if should_replace:
                    unique_attributes[name] = attr
# Log attribute extraction statistics
final_count = len(unique_attributes)
if attribute_stats["total"] > 0:
logger.debug(
"Extracted %s unique attributes from %s: "
"dataclass=%s, attrs=%s, class_vars=%s, "
"properties=%s, slots=%s, total_processed=%s",
final_count,
class_node.name,
attribute_stats["dataclass"],
attribute_stats["attrs"],
attribute_stats["class_vars"],
attribute_stats["properties"],
attribute_stats["slots"],
attribute_stats["total"],
)
return list(unique_attributes.values())
except AnalysisError:
logger.exception(
"Failed to extract class attributes from %s",
class_node.name,
)
return []
except Exception:
logger.exception(
"Unexpected error extracting class attributes from %s",
class_node.name,
)
return []
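    # Illustrative sketch of what the extraction above produces. Given
    #
    #     @dataclass
    #     class User:
    #         name: str
    #         count: ClassVar[int] = 0
    #
    # the result contains (abridged) records such as
    #     {"name": "name", "type": "str", "is_instance": True,
    #      "from_dataclass": True, "is_class_var": False, ...}
    #     {"name": "count", "type": "ClassVar[int]", "is_class": True,
    #      "is_class_var": True, "default_value": "0", ...}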
def _has_dataclass_decorator(self, class_node: ast.ClassDef) -> bool:
"""Check if class has @dataclass decorator"""
try:
for decorator in class_node.decorator_list:
if isinstance(decorator, ast.Name):
if decorator.id in ["dataclass", "dataclasses"]:
return True
elif isinstance(decorator, ast.Attribute):
# Handle dataclasses.dataclass
attr_name = self._get_name(decorator)
if "dataclass" in attr_name.lower():
return True
elif isinstance(decorator, ast.Call):
# Handle @dataclass() with parameters
func_name = self._get_name(decorator.func)
if "dataclass" in func_name.lower():
return True
except AnalysisError as e:
logger.debug("Failed to check dataclass decorator: %s", e)
except Exception:
logger.exception("Unexpected error checking dataclass decorator")
return False
def _has_attrs_decorator(self, class_node: ast.ClassDef) -> bool:
"""Check if class has @attrs decorator"""
try:
for decorator in class_node.decorator_list:
                if isinstance(decorator, ast.Name):
                    # Bare names, including "from attrs import define, frozen"
                    if decorator.id in ["attrs", "attr", "define", "frozen"]:
                        return True
elif isinstance(decorator, ast.Attribute):
# Handle attr.s, attrs.define, etc.
attr_name = self._get_name(decorator)
patterns = [
"attr.s",
"attr.define",
"attrs.define",
"attrs.frozen",
]
if any(x in attr_name.lower() for x in patterns):
return True
elif isinstance(decorator, ast.Call):
# Handle @attr.s() with parameters
func_name = self._get_name(decorator.func)
patterns = [
"attr.s",
"attr.define",
"attrs.define",
"attrs.frozen",
]
if any(x in func_name.lower() for x in patterns):
return True
except AnalysisError as e:
logger.debug("Failed to check attrs decorator: %s", e)
except Exception:
logger.exception("Unexpected error checking attrs decorator")
return False
def _is_class_var_annotation(self, annotation_node: Any) -> bool:
"""Check if an annotation is a ClassVar type.
Handles patterns like ClassVar[int], typing.ClassVar[str], etc.
"""
if annotation_node is None:
return False
try:
annotation_str = self._get_name(annotation_node)
except AnalysisError as e:
logger.debug("Failed to check ClassVar annotation: %s", e)
return False
except Exception:
logger.exception("Unexpected error checking ClassVar annotation")
return False
else:
return "ClassVar" in annotation_str
def _extract_init_attributes(
self,
class_node: ast.ClassDef,
) -> list[dict[str, Any]]:
"""Extract attributes from __init__ method"""
attributes: list[dict[str, Any]] = []
# Find __init__ method
init_method = None
for item in class_node.body:
if isinstance(item, ast.FunctionDef) and item.name == "__init__":
init_method = item
break
if not init_method:
return attributes
try:
for node in ast.walk(init_method):
try:
# Handle annotated assignments (e.g.,
# self.attr: Type = value)
if isinstance(node, ast.AnnAssign) and isinstance(
node.target,
ast.Attribute,
):
is_self = (
isinstance(node.target.value, ast.Name)
and node.target.value.id == "self"
)
if is_self and not node.target.attr.startswith(
"_",
):
type_hint = (
self._get_name(node.annotation)
if node.annotation
else "Any"
)
default = (
self._get_default_value(node.value)
if node.value
else None
)
attributes.append(
{
"name": node.target.attr,
"type": type_hint,
"is_instance": True,
"is_class": False,
"is_property": False,
"has_type_hint": True,
"default_value": default,
"line_number": node.lineno,
}
)
# Handle regular assignments: self.attr = value
elif isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Attribute):
is_self = (
isinstance(
target.value,
ast.Name,
)
and target.value.id == "self"
)
if is_self and not target.attr.startswith(
"_",
):
# Infer type from assignment
inferred = self._infer_type_from_value(
node.value,
)
default = self._get_default_value(
node.value,
)
attributes.append(
{
"name": target.attr,
"type": inferred,
"is_instance": True,
"is_class": False,
"is_property": False,
"has_type_hint": False,
"default_value": default,
"line_number": node.lineno,
}
)
                            # Handle tuple unpacking, e.g.
                            # self.x, self.y = values
                            elif isinstance(target, ast.Tuple):
                                for elt in target.elts:
                                    if not (
                                        isinstance(elt, ast.Attribute)
                                        and isinstance(elt.value, ast.Name)
                                        and elt.value.id == "self"
                                    ):
                                        continue
                                    if elt.attr.startswith("_"):
                                        continue
                                    # The inferred type reflects the whole
                                    # right-hand side, not each element
                                    inferred = self._infer_type_from_value(
                                        node.value,
                                    )
                                    default = self._get_default_value(
                                        node.value,
                                    )
                                    attributes.append(
                                        {
                                            "name": elt.attr,
                                            "type": inferred,
                                            "is_instance": True,
                                            "is_class": False,
                                            "is_property": False,
                                            "has_type_hint": False,
                                            "default_value": default,
                                            "line_number": node.lineno,
                                        }
                                    )
except AnalysisError as e:
logger.debug("Failed to extract __init__ attribute: %s", e)
continue
except Exception:
logger.exception("Unexpected error extracting __init__ attribute")
continue
except AnalysisError as e:
logger.debug("Failed to walk __init__ method: %s", e)
except Exception:
logger.exception("Unexpected error walking __init__ method")
return attributes
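    # Illustrative sketch: given
    #
    #     def __init__(self):
    #         self.host: str = "localhost"
    #         self.port = 7687
    #
    # this yields {"name": "host", "type": "str", "has_type_hint": True,
    # "default_value": "'localhost'"} and {"name": "port", "type": "int",
    # "has_type_hint": False, "default_value": "7687"}.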
def _extract_slots(self, slots_node: Any) -> list[str]:
"""Extract slot names from __slots__ definition"""
slots: list[str] = []
try:
            # String literals are ast.Constant on Python 3.8+ (the old
            # ast.Str node was removed in Python 3.12)
            if isinstance(slots_node, ast.List | ast.Tuple):
                for elt in slots_node.elts:
                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
                        slots.append(elt.value)
            elif isinstance(slots_node, ast.Constant) and isinstance(
                slots_node.value, str
            ):
                slots.append(slots_node.value)
except AnalysisError as e:
logger.debug("Failed to extract slots: %s", e)
except Exception:
logger.exception("Unexpected error extracting slots")
return slots
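    # Example: __slots__ = ("host", "port") yields ["host", "port"], and a
    # single string __slots__ = "host" yields ["host"].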
def _infer_type_from_value(self, value_node: Any) -> str:
"""Attempt to infer type from assignment value with enhanced patterns"""
try:
if isinstance(value_node, ast.Constant):
if isinstance(value_node.value, bool):
return "bool"
if isinstance(value_node.value, int):
return "int"
if isinstance(value_node.value, float):
return "float"
if isinstance(value_node.value, str):
return "str"
if isinstance(value_node.value, bytes):
return "bytes"
if value_node.value is None:
return "Optional[Any]"
elif isinstance(value_node, ast.List | ast.ListComp):
return "List[Any]"
elif isinstance(value_node, ast.Dict | ast.DictComp):
return "Dict[Any, Any]"
elif isinstance(value_node, ast.Set | ast.SetComp):
return "Set[Any]"
elif isinstance(value_node, ast.Tuple):
return "Tuple[Any, ...]"
elif isinstance(value_node, ast.Call):
# Try to get type from function call
func_name = self._get_name(value_node.func)
builtins = [
"list",
"dict",
"set",
"tuple",
"str",
"int",
"float",
"bool",
]
if func_name in builtins:
return func_name
if func_name in ["defaultdict", "Counter", "OrderedDict"]:
return f"collections.{func_name}"
if func_name in ["deque"]:
return "collections.deque"
if func_name in ["Path"]:
return "pathlib.Path"
if func_name in ["datetime", "date", "time"]:
return f"datetime.{func_name}"
if func_name in ["UUID"]:
return "uuid.UUID"
if func_name in ["re.compile", "compile"]:
return "re.Pattern"
# Handle dataclass/attrs field calls
if "field" in func_name.lower():
return "Any" # Field type should be inferred from annotation
return "Any" # Unknown function call
            elif isinstance(value_node, ast.Attribute | ast.Name):
                # Attribute access (self.other_attr, module.CONSTANT) or a
                # variable reference; the static type is unknown. True,
                # False, and None parse as ast.Constant in Python 3, so no
                # special-casing is needed here.
                return "Any"
            elif isinstance(value_node, ast.BinOp):
                # Binary operations could yield various types
                return "Any"
except AnalysisError as e:
logger.debug("Failed to infer type: %s", e)
except Exception:
logger.exception("Unexpected error in type inference")
return "Any"
def _is_likely_internal(self, import_name: str, project_modules: set[str]) -> bool:
"""Check if an import is likely internal to the project"""
if not import_name:
return False
# Relative imports are definitely internal
if import_name.startswith("."):
return True
# Check if it's a known external module
base_module = import_name.split(".")[0]
if base_module in self.external_modules:
return False
        # Check if it matches a project module; require an exact match or a
        # dotted subpackage so "foo" does not match "foobar"
        for project_module in project_modules:
            if import_name == project_module or import_name.startswith(
                project_module + ".",
            ):
                return True
# If it's not obviously external, consider it internal
min_module_length = 2
        test_keywords = ["test", "mock", "fake"]
        is_not_test = not any(kw in base_module.lower() for kw in test_keywords)
is_not_private = not base_module.startswith("_")
is_long_enough = len(base_module) > min_module_length
return bool(is_not_test and is_not_private and is_long_enough)
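    # Illustrative decisions, assuming project_modules={"app"}:
    #     ".utils"       -> True   (relative import)
    #     "app.models"   -> True   (project module)
    #     "numpy.linalg" -> False  (known external)
    #     "os"           -> False  (stdlib)
    #     "my_helpers"   -> True   (unknown, assumed internal)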
def _get_importable_module_name(
self,
repo_root: Path,
relative_path: str,
) -> str:
"""Determine the actual importable module name for a Python file"""
        # Start with the default: convert the file path to a dotted module
        # path, stripping only the trailing .py suffix (a blanket
        # .replace(".py", "") would mangle names that contain ".py")
        default_module = (
            relative_path.replace("/", ".").replace("\\", ".").removesuffix(".py")
        )
# Common patterns to detect the actual package root
path_parts = Path(relative_path).parts
# Look for common package indicators
package_roots = []
# Check each directory level for __init__.py to find package boundaries
current_path = repo_root
for i, part in enumerate(path_parts[:-1]): # Exclude the .py file itself
current_path = current_path / part
if (current_path / "__init__.py").exists():
# This is a package directory, mark it as a potential root
package_roots.append(i)
        if package_roots:
            # Use the first (outermost) package as the root
            package_start = package_roots[0]
            module_parts = list(path_parts[package_start:])
            module_parts[-1] = module_parts[-1].removesuffix(".py")
            return ".".join(module_parts)
# Fallback: look for common Python project structures
# Skip common non-package directories
skip_dirs = {
"src",
"lib",
"source",
"python",
"pkg",
"packages",
}
# Find the first directory that's not in skip_dirs
filtered_parts: list[str] = []
for part in path_parts:
# Once included, include everything
if part.lower() not in skip_dirs or filtered_parts:
filtered_parts.append(part)
        if filtered_parts:
            filtered_parts[-1] = filtered_parts[-1].removesuffix(".py")
            return ".".join(filtered_parts)
# Final fallback: use the default
return default_module
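    # Illustrative mapping for a hypothetical layout where only
    # src/pkg/__init__.py exists:
    #     src/pkg/sub/mod.py -> "pkg.sub.mod"  (package root detected)
    #     scripts/tool.py    -> "scripts.tool" (no packages; path join)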
def _extract_function_parameters(
self,
func_node: Any,
) -> list[dict[str, Any]]:
"""Comprehensive parameter extraction from function def"""
params: list[dict[str, Any]] = []
# Regular positional arguments
for i, arg in enumerate(func_node.args.args):
if arg.arg == "self":
continue
param_type = self._get_name(arg.annotation) if arg.annotation else "Any"
param_info = {
"name": arg.arg,
"type": param_type,
"kind": "positional",
"optional": False,
"default": None,
}
# Check if argument has a default value
num_args = len(func_node.args.args)
num_defaults = len(func_node.args.defaults)
defaults_start = num_args - num_defaults
if i >= defaults_start:
default_idx = i - defaults_start
if default_idx < num_defaults:
param_info["optional"] = True
default_node = func_node.args.defaults[default_idx]
param_info["default"] = self._get_default_value(default_node)
params.append(param_info)
# *args parameter
if func_node.args.vararg:
vararg_type = (
self._get_name(
func_node.args.vararg.annotation,
)
if func_node.args.vararg.annotation
else "Any"
)
params.append(
{
"name": f"*{func_node.args.vararg.arg}",
"type": vararg_type,
"kind": "var_positional",
"optional": True,
"default": None,
}
)
# Keyword-only arguments (after *)
for i, arg in enumerate(func_node.args.kwonlyargs):
param_type = self._get_name(arg.annotation) if arg.annotation else "Any"
param_info = {
"name": arg.arg,
"type": param_type,
"kind": "keyword_only",
"optional": True,
"default": None,
}
# Check for default value
has_default = (
i < len(func_node.args.kw_defaults)
and func_node.args.kw_defaults[i] is not None
)
if has_default:
default_node = func_node.args.kw_defaults[i]
param_info["default"] = self._get_default_value(default_node)
else:
# No default = required kwonly arg
param_info["optional"] = False
params.append(param_info)
# **kwargs parameter
if func_node.args.kwarg:
kwarg_type = (
self._get_name(
func_node.args.kwarg.annotation,
)
if func_node.args.kwarg.annotation
else "Dict[str, Any]"
)
params.append(
{
"name": f"**{func_node.args.kwarg.arg}",
"type": kwarg_type,
"kind": "var_keyword",
"optional": True,
"default": None,
}
)
return params
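    # Illustrative sketch for a hypothetical signature
    # def connect(uri: str, *, timeout: int = 30, **opts) -> params like
    #     {"name": "uri", "type": "str", "kind": "positional",
    #      "optional": False, "default": None}
    #     {"name": "timeout", "type": "int", "kind": "keyword_only",
    #      "optional": True, "default": "30"}
    #     {"name": "**opts", "type": "Dict[str, Any]", "kind": "var_keyword",
    #      "optional": True, "default": None}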
def _get_default_value(self, default_node: Any) -> str:
"""Extract default value from AST node"""
try:
if isinstance(default_node, ast.Constant):
return repr(default_node.value)
if isinstance(default_node, ast.Name):
return default_node.id
if isinstance(default_node, ast.Attribute):
return self._get_name(default_node)
if isinstance(default_node, ast.List):
return "[]"
if isinstance(default_node, ast.Dict):
return "{}"
except AnalysisError:
return "..."
except Exception:
logger.exception("Unexpected error extracting default value")
return "..."
else:
return "..."
def _get_name(self, node: Any) -> str:
"""Extract name from AST node, handling complex types safely"""
if node is None:
return "Any"
try:
if isinstance(node, ast.Name):
return node.id
            if isinstance(node, ast.Attribute):
                # Dotted name such as module.ClassName
                return f"{self._get_name(node.value)}.{node.attr}"
if isinstance(node, ast.Subscript):
# Handle List[Type], Dict[K,V], etc.
base = self._get_name(node.value)
if hasattr(node, "slice"):
if isinstance(node.slice, ast.Name):
return f"{base}[{node.slice.id}]"
if isinstance(node.slice, ast.Tuple):
elts = [self._get_name(elt) for elt in node.slice.elts]
return f"{base}[{', '.join(elts)}]"
if isinstance(node.slice, ast.Constant):
return f"{base}[{node.slice.value!r}]"
if isinstance(node.slice, ast.Attribute | ast.Subscript):
return f"{base}[{self._get_name(node.slice)}]"
# Try to get the name of the slice, fallback to Any
try:
slice_name = self._get_name(node.slice)
except Exception:
return f"{base}[Any]"
else:
return f"{base}[{slice_name}]"
return base
            if isinstance(node, ast.Constant):
                # Covers string annotations too; ast.Str was removed in
                # Python 3.12 and ast.parse emits ast.Constant since 3.8
                return str(node.value)
if isinstance(node, ast.Tuple):
elts = [self._get_name(elt) for elt in node.elts]
return f"({', '.join(elts)})"
if isinstance(node, ast.List):
elts = [self._get_name(elt) for elt in node.elts]
return f"[{', '.join(elts)}]"
except AnalysisError:
return "Any"
except Exception:
logger.exception("Unexpected error extracting name from AST node")
return "Any"
        else:
            # Fallback for unhandled node types
            return "Any"