"""Code processing integration between scanner and parser."""
from __future__ import annotations
import ast
import asyncio
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, cast
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.future import select
from src.database.models import Class, File, Function, Import, Module
from src.database.session_manager import ParallelSessionManager # noqa: F401
from src.domain.indexer import DomainIndexer
from src.logger import get_logger
from src.parser.code_extractor import CodeExtractor
from src.parser.parser_factory import ParserFactory
from src.parser.reference_analyzer import ReferenceAnalyzer
logger = get_logger(__name__)
class CodeProcessor:
"""Process code files to extract and store code entities."""
def __init__(
self,
db_session: AsyncSession,
*,
repository_path: Path | str | None = None,
enable_domain_analysis: bool = False,
enable_parallel: bool = False,
) -> None:
self.db_session = db_session
self.repository_path = Path(repository_path) if repository_path else None
self.code_extractor = CodeExtractor()
self.parser_factory = ParserFactory()
self.enable_domain_analysis = enable_domain_analysis
self.domain_indexer = None
self._use_parallel = enable_parallel
if enable_domain_analysis:
self.domain_indexer = DomainIndexer(db_session)
async def process_file(self, file_record: File) -> dict[str, Any]:
"""Process a file to extract code entities."""
logger.info("Processing file: %s", file_record.path)
# Check if file is supported
file_path = Path(file_record.path)
# Get absolute path if repository path is set
if self.repository_path:
absolute_path = self.repository_path / file_path
else:
absolute_path = file_path
if not self.parser_factory.is_supported(file_path):
logger.debug("Skipping unsupported file type: %s", file_path)
return {
"file_id": cast("int", file_record.id),
"status": "skipped",
"reason": "unsupported_file_type",
}
try:
# Extract entities
entities = await self._extract_entities(
absolute_path, cast("int", file_record.id)
)
if not entities:
return {
"file_id": cast("int", file_record.id),
"status": "failed",
"reason": "extraction_failed",
}
# Store entities in database
store_stats = await self._store_entities(entities, file_record)
# Extract and store references
ref_stats = await self.extract_and_store_references(file_record, entities)
# Build full statistics with references included (different type)
stats: dict[str, Any] = {**store_stats, "references": ref_stats}
# Update file processing status
cast("Any", file_record).last_modified = datetime.now(UTC).replace(
tzinfo=None
)
await self.db_session.commit()
# Run domain analysis if enabled
domain_stats: dict[str, int] = {}
# Check if language supports domain analysis (object-oriented languages)
from src.parser.plugin_registry import LanguagePluginRegistry
plugin = LanguagePluginRegistry.get_plugin_by_file_path(file_path)
supports_domain_analysis = (
plugin
and plugin.supports_feature("classes")
and plugin.supports_feature("functions")
)
if self.domain_indexer and supports_domain_analysis:
try:
logger.info("Running domain analysis for %s", file_record.path)
domain_result = await self.domain_indexer.index_file(
cast("int", file_record.id)
)
domain_stats = {
"domain_entities": cast(
"int", domain_result.get("entities_extracted", 0)
),
"domain_relationships": cast(
"int",
domain_result.get(
"relationships_extracted",
0,
),
),
}
except (ValueError, TypeError, AttributeError) as e:
logger.warning(
"Domain analysis failed for %s: %s",
file_record.path,
e,
)
return {
"file_id": cast("int", file_record.id),
"status": "success",
"statistics": stats,
"domain_statistics": domain_stats,
}
except Exception as e:
logger.exception("Error processing file %s", file_record.path)
import traceback
logger.debug("Traceback: %s", traceback.format_exc())
# Update file with error
await self.db_session.commit()
return {
"file_id": cast("int", file_record.id),
"status": "failed",
"reason": "processing_error",
"error": str(e),
}
async def _extract_entities(
self,
file_path: Path,
file_id: int,
) -> dict[str, list[Any]] | None:
"""Extract code entities from a file."""
# Run extraction in thread pool to avoid blocking
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
self.code_extractor.extract_from_file,
file_path,
file_id,
)
async def _store_entities(
self,
entities: dict[str, list[Any]],
file_record: File,
) -> dict[str, int]:
"""Store extracted entities in the database."""
stats: dict[str, int] = {
"modules": 0,
"classes": 0,
"functions": 0,
"imports": 0,
}
# Clear existing entities for this file
await self._clear_file_entities(cast("int", file_record.id))
# Store modules and track their IDs
# For Python files, there's typically one module per file
module_id: int | None = None
module_map: dict[str, int] = {}
for module_data in entities.get("modules", []):
module = Module(
file_id=cast("int", file_record.id),
name=module_data["name"],
docstring=module_data.get("docstring"),
start_line=module_data["start_line"],
end_line=module_data["end_line"],
)
self.db_session.add(module)
await self.db_session.flush() # Get ID
# In tests, the mock might not assign an ID, so use a fallback
if not hasattr(module, "id") or cast("Any", module).id is None:
cast("Any", module).id = 1 # Default ID for tests
module_map[module_data["name"]] = cast("int", cast("Any", module).id)
# For single-module files, track the primary module ID
if module_id is None:
module_id = cast("int", cast("Any", module).id)
stats["modules"] += 1
# Store classes with proper module ID
class_map: dict[str, int] = {}
for class_data in entities.get("classes", []):
# Skip if no module ID is available (shouldn't happen for Python files)
if module_id is None:
logger.warning(
"No module ID available for class %s in file %s",
class_data["name"],
file_record.path,
)
continue
# Use the module_id from the stored module
class_obj = Class(
module_id=module_id, # Use the actual module ID
name=class_data["name"],
docstring=class_data.get("docstring"),
base_classes=class_data.get("base_classes", []),
decorators=class_data.get("decorators", []),
start_line=class_data["start_line"],
end_line=class_data["end_line"],
is_abstract=class_data.get("is_abstract", False),
)
self.db_session.add(class_obj)
await self.db_session.flush() # Get ID
class_map[class_data["name"]] = cast("int", cast("Any", class_obj).id)
stats["classes"] += 1
# Store functions and methods with proper module ID
for func_data in entities.get("functions", []):
# Skip if no module ID is available (shouldn't happen for Python files)
if module_id is None:
logger.warning(
"No module ID available for function %s in file %s",
func_data["name"],
file_record.path,
)
continue
class_name = func_data.get("class_name")
function = Function(
module_id=module_id, # Use the actual module ID
class_id=class_map.get(class_name) if class_name else None,
name=func_data["name"],
parameters=func_data.get("parameters", []),
return_type=cast("str | None", func_data.get("return_type")),
docstring=func_data.get("docstring"),
decorators=func_data.get("decorators", []),
is_async=func_data.get("is_async", False),
is_generator=func_data.get("is_generator", False),
is_property=func_data.get("is_property", False),
is_static=func_data.get("is_staticmethod", False),
is_classmethod=func_data.get("is_classmethod", False),
start_line=func_data["start_line"],
end_line=func_data["end_line"],
complexity=func_data.get("complexity", 1),
)
self.db_session.add(function)
stats["functions"] += 1
# Store imports
for import_data in entities.get("imports", []):
import_obj = Import(
file_id=cast("int", file_record.id),
import_statement=import_data["import_statement"],
module_name=import_data.get("imported_from"),
imported_names=import_data.get("imported_names", []),
is_relative=import_data.get("is_relative", False),
level=import_data.get("level", 0),
line_number=import_data["line_number"],
)
self.db_session.add(import_obj)
stats["imports"] += 1
await self.db_session.commit()
return stats
async def _clear_file_entities(self, file_id: int) -> None:
"""Clear existing entities for a file before re-parsing."""
# Delete in correct order to respect foreign key constraints
await self.db_session.execute(delete(Import).where(Import.file_id == file_id))
# Functions are linked to modules, not files directly
# First get modules for this file
modules_result = await self.db_session.execute(
select(Module).where(Module.file_id == file_id),
)
modules = modules_result.scalars().all()
module_ids = [cast("int", cast("Any", m).id) for m in modules]
if module_ids:
# Delete functions linked to these modules
await self.db_session.execute(
delete(Function).where(Function.module_id.in_(module_ids)),
)
# Delete classes linked to these modules
await self.db_session.execute(
delete(Class).where(Class.module_id.in_(module_ids)),
)
await self.db_session.execute(delete(Module).where(Module.file_id == file_id))
await self.db_session.commit()
async def _process_files_sequential(
self, file_records: list[File]
) -> list[dict[str, Any] | Exception]:
"""Process files sequentially."""
results: list[dict[str, Any] | Exception] = []
for file in file_records:
try:
result = await self.process_file(file)
results.append(result)
except Exception as e:
logger.exception("Error processing file %s", file.path)
results.append(e)
return results
async def _process_files_parallel(
self, file_records: list[File]
) -> list[dict[str, Any]]:
"""Process files in parallel using separate sessions."""
# Resolve async_sessionmaker and ParallelSessionManager dynamically so tests can
# patch either the original modules (sqlalchemy.ext.asyncio / src.database.session_manager)
# or the names on this module (src.scanner.code_processor).
from importlib import import_module
# Helper to detect mocks
def _is_mock(obj: object) -> bool:
return obj is not None and (
hasattr(obj, "assert_called") or obj.__class__.__name__.endswith("Mock")
)
def _choose_sessionmaker() -> Any:
module_async_sessionmaker = globals().get("async_sessionmaker")
try:
orig_async_sessionmaker = import_module(
"sqlalchemy.ext.asyncio"
).async_sessionmaker
except (ImportError, AttributeError):
orig_async_sessionmaker = None
if _is_mock(module_async_sessionmaker):
return module_async_sessionmaker
if _is_mock(orig_async_sessionmaker):
return orig_async_sessionmaker
return (
module_async_sessionmaker
or orig_async_sessionmaker
or async_sessionmaker
)
def _choose_parallel_pm() -> Any:
module_parallel_pm = globals().get("ParallelSessionManager")
try:
orig_parallel_pm = import_module(
"src.database.session_manager"
).ParallelSessionManager
except (ImportError, AttributeError):
orig_parallel_pm = None
if _is_mock(module_parallel_pm):
return module_parallel_pm
if _is_mock(orig_parallel_pm):
return orig_parallel_pm
return module_parallel_pm or orig_parallel_pm
# Create a session factory from the current session's bind
chosen_async_sessionmaker = _choose_sessionmaker()
bind = self.db_session.bind
session_maker_callable: Any = chosen_async_sessionmaker
session_factory = cast("Any", session_maker_callable)(
bind, expire_on_commit=False
)
# Resolve and create ParallelSessionManager
parallel_session_manager_cls = _choose_parallel_pm()
if parallel_session_manager_cls is None:
raise RuntimeError
parallel_manager = cast("Any", parallel_session_manager_cls)(session_factory)
# Define processing function for parallel execution
async def process_file_with_session(
file_record: File, session: AsyncSession
) -> dict[str, Any]:
# Create a new processor with the session
processor = CodeProcessor(
session,
repository_path=self.repository_path,
enable_domain_analysis=self.enable_domain_analysis,
)
return await processor.process_file(file_record)
# Process files in parallel
batch_size = min(10, max(2, len(file_records) // 4)) # Adaptive batch size
logger.info("Processing files in parallel with batch size: %d", batch_size)
results: list[dict[str, Any] | None] = await parallel_manager.execute_parallel(
file_records, process_file_with_session, batch_size=batch_size
)
# Convert None results to error dictionaries
processed_results: list[dict[str, Any]] = []
for i, result in enumerate(results):
if result is None:
processed_results.append(
{
"file_id": cast("int", file_records[i].id),
"status": "failed",
"reason": "parallel_processing_error",
"error": "Unknown error during parallel processing",
}
)
else:
processed_results.append(result)
return processed_results
async def process_files(self, file_records: list[File]) -> dict[str, Any]:
"""Process multiple files."""
logger.info("Processing %s files", len(file_records))
# Decide processing strategy
results: Sequence[dict[str, Any] | Exception]
# Check if we should use parallel processing
if (
len(file_records) > 5
and hasattr(self, "_use_parallel")
and self._use_parallel
):
results = await self._process_files_parallel(file_records)
else:
# Process files sequentially for small batches or when parallel is disabled
results = await self._process_files_sequential(file_records)
# Aggregate results
summary: dict[str, Any] = {
"total": len(file_records),
"success": 0,
"failed": 0,
"skipped": 0,
"errors": [],
"statistics": {
"modules": 0,
"classes": 0,
"functions": 0,
"imports": 0,
"references": 0,
},
}
for i, result in enumerate(results):
if isinstance(result, Exception):
summary["failed"] += 1
cast("list[dict[str, Any]]", summary["errors"]).append(
{
"file": file_records[i].path,
"error": str(result),
},
)
elif isinstance(result, dict):
if result["status"] == "success":
summary["success"] += 1
# Aggregate statistics
for key, value in cast(
"dict[str, Any]", result.get("statistics", {})
).items():
if key == "references" and isinstance(value, dict):
# Handle nested reference statistics
summary["statistics"]["references"] += cast(
"int", value.get("total", 0)
)
else:
summary["statistics"][key] += cast("int", value)
elif result["status"] == "skipped":
summary["skipped"] += 1
else:
summary["failed"] += 1
if "error" in result:
cast("list[dict[str, Any]]", summary["errors"]).append(
{
"file": file_records[i].path,
"error": result["error"],
},
)
logger.info(
"Processing complete: %s success, %s failed, %s skipped",
summary["success"],
summary["failed"],
summary["skipped"],
)
return summary
async def get_file_structure(self, file_record: File) -> dict[str, Any]:
"""Get the parsed structure of a file."""
structure: dict[str, Any] = {
"file": {
"id": cast("int", file_record.id),
"path": file_record.path,
"language": file_record.language,
},
"modules": [],
"classes": [],
"functions": [],
"imports": [],
}
# Get modules
modules_result = await self.db_session.execute(
select(Module).where(Module.file_id == cast("int", file_record.id)),
)
modules_struct: list[dict[str, Any]] = [
{
"id": cast("int", cast("Any", m).id),
"name": m.name,
"docstring": m.docstring,
"lines": f"{m.start_line}-{m.end_line}",
}
for m in modules_result.scalars()
]
structure["modules"] = modules_struct
# Get module IDs for related queries
module_ids = [cast("int", m["id"]) for m in modules_struct]
# Get classes through modules
if module_ids:
classes_result = await self.db_session.execute(
select(Class).where(Class.module_id.in_(module_ids)),
)
else:
classes_result = None
classes_struct: list[dict[str, Any]] = [
{
"id": cast("int", cast("Any", c).id),
"name": c.name,
"docstring": c.docstring,
"base_classes": c.base_classes,
"is_abstract": c.is_abstract,
"lines": f"{c.start_line}-{c.end_line}",
}
for c in (classes_result.scalars() if classes_result else [])
]
structure["classes"] = classes_struct
# Get functions through modules
if module_ids:
functions_result = await self.db_session.execute(
select(Function).where(Function.module_id.in_(module_ids)),
)
else:
functions_result = None
functions_struct: list[dict[str, Any]] = [
{
"id": cast("int", cast("Any", f).id),
"name": f.name,
"class_id": f.class_id,
"parameters": f.parameters,
"return_type": f.return_type,
"is_async": f.is_async,
"lines": f"{f.start_line}-{f.end_line}",
}
for f in (functions_result.scalars() if functions_result else [])
]
structure["functions"] = functions_struct
# Get imports
imports_result = await self.db_session.execute(
select(Import).where(Import.file_id == cast("int", file_record.id)),
)
imports_struct: list[dict[str, Any]] = [
{
"id": cast("int", cast("Any", i).id),
"statement": i.import_statement,
"from": i.module_name,
"names": i.imported_names,
"line": i.line_number,
}
for i in imports_result.scalars()
]
structure["imports"] = imports_struct
return structure
async def extract_and_store_references(
self,
file_record: File,
entities: dict[str, Any],
) -> dict[str, Any]:
"""Extract and store code references for a file.
Args:
file_record: File database record
entities: Extracted entities with their database IDs
Returns:
Statistics about extracted references
"""
logger.info("Extracting references for file %s", file_record.path)
stats: dict[str, Any] = {
"total": 0,
"imports": 0,
"calls": 0,
"inherits": 0,
"type_hints": 0,
"errors": [],
}
try:
# Read the file content
file_path = (
Path(self.repository_path) / file_record.path
if self.repository_path
else Path(file_record.path)
)
if not file_path.exists():
logger.warning("File not found for reference analysis: %s", file_path)
return stats
content = file_path.read_text()
tree = ast.parse(content)
# Get module path from entities
module_info = entities.get("modules", [{}])[0]
module_path = module_info.get(
"name", file_record.path.replace("/", ".").replace(".py", "")
)
# Analyze references
analyzer = ReferenceAnalyzer(module_path, file_path)
raw_references = analyzer.analyze(tree)
# Convert raw references to database references
# This requires looking up entity IDs from names
db_references = await self._resolve_references(
raw_references,
file_record,
entities,
)
# Store references in database
if db_references:
from src.database.repositories import CodeReferenceRepo
ref_repo = CodeReferenceRepo(self.db_session)
await ref_repo.bulk_create(db_references)
# Update statistics
stats["total"] = len(db_references)
for ref in db_references:
ref_type = cast("str", ref.get("reference_type", ""))
if ref_type == "import":
stats["imports"] = cast("int", stats["imports"]) + 1
elif ref_type == "call":
stats["calls"] = cast("int", stats["calls"]) + 1
elif ref_type == "inherit":
stats["inherits"] = cast("int", stats["inherits"]) + 1
elif ref_type == "type_hint":
stats["type_hints"] = cast("int", stats["type_hints"]) + 1
logger.info(
"Extracted %d references: %d imports, %d calls, %d inherits, %d type hints",
stats["total"],
stats["imports"],
stats["calls"],
stats["inherits"],
stats["type_hints"],
)
except Exception as e:
logger.exception("Error extracting references for %s", file_record.path)
cast("list[str]", stats["errors"]).append(str(e))
return stats
def _build_entity_map(
self, entities: list[Any], key: str = "name"
) -> dict[str, int]:
"""Build a mapping of entity names to IDs."""
entity_map: dict[str, int] = {}
for e in entities:
if isinstance(e, dict) and key in e and "id" in e:
entity_map[cast("str", e[key])] = cast("int", e["id"])
return entity_map
async def _load_entity_maps_from_db(
self, file_id: int
) -> tuple[dict[str, int], dict[str, int], dict[str, int]]:
"""Load entity maps from database for a file."""
from sqlalchemy import select
from src.database.models import Class, Function, Module
# Get modules
result = await self.db_session.execute(
select(Module).where(Module.file_id == file_id)
)
modules = result.scalars().all()
module_map: dict[str, int] = {cast("str", m.name): m.id for m in modules}
# Get classes (through modules)
result = await self.db_session.execute(
select(Class).join(Module).where(Module.file_id == file_id)
)
classes = result.scalars().all()
class_map: dict[str, int] = {cast("str", c.name): c.id for c in classes}
# Get functions (through modules)
result = await self.db_session.execute(
select(Function).join(Module).where(Module.file_id == file_id)
)
functions = result.scalars().all()
function_map: dict[str, int] = {cast("str", f.name): f.id for f in functions}
return module_map, class_map, function_map
def _resolve_source_entity_id(
self,
source_name: str,
source_type: str,
module_map: dict[str, int],
class_map: dict[str, int],
function_map: dict[str, int],
) -> int | None:
"""Resolve source entity name to ID."""
if source_type == "module":
return next(iter(module_map.values())) if module_map else None
if source_type == "class":
class_name = source_name.split(".")[-1]
return class_map.get(class_name)
if source_type == "function":
func_name = source_name.split(".")[-1]
return function_map.get(func_name)
return None
async def _process_single_reference(
self,
ref: dict[str, Any],
file_record: File,
module_map: dict[str, int],
class_map: dict[str, int],
function_map: dict[str, int],
) -> dict[str, Any] | None:
"""Process a single reference and return resolved data."""
try:
source_id = self._resolve_source_entity_id(
cast("str", ref["source_name"]),
cast("str", ref["source_type"]),
module_map,
class_map,
function_map,
)
if not source_id:
logger.debug(
"Could not resolve source entity: %s (%s)",
ref["source_name"],
ref["source_type"],
)
return None
# Resolve target entity
target_id, target_file_id = await self._resolve_target_entity(
cast("str", ref["target_name"]),
cast("str", ref["target_type"]),
)
if not target_id:
logger.debug(
"Could not resolve target entity: %s (%s)",
ref["target_name"],
ref["target_type"],
)
return None
return {
"source_type": ref["source_type"],
"source_id": source_id,
"source_file_id": cast("int", file_record.id),
"source_line": ref.get("source_line"),
"target_type": ref["target_type"],
"target_id": target_id,
"target_file_id": target_file_id,
"reference_type": ref["reference_type"],
"context": cast("str", ref.get("context", ""))[:500],
}
except (KeyError, AttributeError, ValueError) as e:
logger.debug("Error resolving reference: %s - %s", ref, e)
return None
async def _resolve_references(
self,
raw_references: list[dict[str, Any]],
file_record: File,
entities: dict[str, Any],
) -> list[dict[str, Any]]:
"""Resolve reference names to database entity IDs.
Args:
raw_references: Raw references from analyzer
file_record: Source file record
entities: Entities from the current file
Returns:
List of references ready for database storage
"""
# Build lookup maps for current file entities
module_map = self._build_entity_map(entities.get("modules", []))
class_map = self._build_entity_map(entities.get("classes", []))
function_map = self._build_entity_map(entities.get("functions", []))
# Load from DB if needed
if not module_map and not class_map and not function_map:
module_map, class_map, function_map = await self._load_entity_maps_from_db(
cast("int", file_record.id)
)
# Process references
resolved: list[dict[str, Any]] = []
for ref in raw_references:
resolved_ref = await self._process_single_reference(
ref, file_record, module_map, class_map, function_map
)
if resolved_ref:
resolved.append(resolved_ref)
return resolved
async def _resolve_target_entity(
self,
target_name: str,
target_type: str,
) -> tuple[int | None, int | None]:
"""Resolve a target entity name to its database ID.
Args:
target_name: Full entity name (e.g., 'src.parser.base_parser.BaseParser')
target_type: Entity type (module, class, function)
Returns:
Tuple of (entity_id, file_id) or (None, None) if not found
"""
# For now, use simple lookups - in production, this would use a name index
if target_type == "module":
# Look up module by name
result = await self.db_session.execute(
select(Module).where(Module.name == target_name)
)
module = result.scalar_one_or_none()
if module:
return cast("int", cast("Any", module).id), cast(
"int", cast("Any", module).file_id
)
elif target_type == "class":
# Extract class name from full path
class_name = target_name.split(".")[-1]
result = await self.db_session.execute(
select(Class).where(Class.name == class_name)
)
cls = result.scalar_one_or_none()
if cls:
# Get file ID through module
module_result = await self.db_session.execute(
select(Module).where(Module.id == cast("Any", cls).module_id)
)
module = module_result.scalar_one_or_none()
if module:
return cast("int", cast("Any", cls).id), cast(
"int", cast("Any", module).file_id
)
elif target_type == "function":
# Extract function name from full path
func_name = target_name.split(".")[-1]
result = await self.db_session.execute(
select(Function).where(Function.name == func_name)
)
func = result.scalar_one_or_none()
if func:
# Get file ID through module
module_result = await self.db_session.execute(
select(Module).where(Module.id == cast("Any", func).module_id)
)
module = module_result.scalar_one_or_none()
if module:
return cast("int", cast("Any", func).id), cast(
"int", cast("Any", module).file_id
)
return None, None
def enable_parallel_processing(self, enabled: bool = True) -> None:
"""Enable or disable parallel processing for multiple files."""
self._use_parallel = enabled
logger.info("Parallel processing %s", "enabled" if enabled else "disabled")