"""Service for managing embeddings in the database."""
from datetime import UTC, datetime
from typing import Any, cast
from sqlalchemy import and_, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from src.database.domain_models import DomainEntity
from src.database.models import (
Class,
CodeEmbedding,
File,
Function,
Module,
)
from src.embeddings.embedding_generator import EmbeddingGenerator
from src.logger import get_logger
from src.utils.exceptions import EmbeddingError, NotFoundError
# Module-level logger used by the service methods below for progress and failure reporting.
logger = get_logger(__name__)
class EmbeddingService:
"""Service for creating and managing code embeddings."""
def __init__(self, db_session: AsyncSession) -> None:
    """Set up the service with a database session and an embedding generator.

    Args:
        db_session: Async database session used for every query and commit
    """
    # Keep the caller-provided session; the generator is created here.
    self.db_session = db_session
    self.embedding_generator = EmbeddingGenerator()
async def create_file_embeddings(self, file_id: int) -> dict[str, Any]:
    """Create embeddings for every module, class and function of one file.

    Entities are processed in that order. A failure on a single entity is
    logged and recorded in the returned ``errors`` list; the remaining
    entities are still processed.

    Args:
        file_id: File ID to process

    Returns:
        Summary dict with ``file_id``, ``file_path``, the counters
        ``modules``, ``classes`` and ``functions`` (entities handled without
        error), their sum in ``total`` and the per-entity error strings in
        ``errors``

    Raises:
        NotFoundError: If no file with ``file_id`` exists
        EmbeddingError: If processing fails outside the per-entity handlers
            (for example while listing the entities of the file)
    """
    logger.info("Creating embeddings for file %d", file_id)

    # Get file record
    sql_result = await self.db_session.execute(
        select(File).where(File.id == file_id)
    )
    file_record = sql_result.scalar_one_or_none()
    if not file_record:
        msg = "File not found"
        raise NotFoundError(msg)

    # Resolve the path once; every embedding call below needs it.
    file_path = cast("str", file_record.path)

    stats: dict[str, Any] = {
        "file_id": file_id,
        "file_path": file_path,
        "modules": 0,
        "classes": 0,
        "functions": 0,
        "total": 0,
        "errors": [],
    }

    try:
        # Process modules
        modules = await self._get_file_modules(file_id)
        for module in modules:
            try:
                await self._create_module_embedding(module, file_path)
                stats["modules"] += 1
            except Exception as e:
                logger.exception(
                    "Failed to create embedding for module %s",
                    module.id,
                )
                stats["errors"].append(f"Module {module.name}: {e!s}")

        # Process classes
        classes = await self._get_file_classes(file_id)
        for cls in classes:
            try:
                await self._create_class_embedding(cls, file_path)
                stats["classes"] += 1
            except Exception as e:
                logger.exception(
                    "Failed to create embedding for class %s",
                    cls.id,
                )
                stats["errors"].append(f"Class {cls.name}: {e!s}")

        # Process functions. The loop variable is deliberately not called
        # ``func`` so it does not shadow the imported ``sqlalchemy.func``.
        functions = await self._get_file_functions(file_id)
        for function in functions:
            try:
                await self._create_function_embedding(function, file_path)
                stats["functions"] += 1
            except Exception as e:
                logger.exception(
                    "Failed to create embedding for function %s",
                    function.id,
                )
                stats["errors"].append(f"Function {function.name}: {e!s}")

        stats["total"] = stats["modules"] + stats["classes"] + stats["functions"]

        # Log the number of errors, not the list itself: the message
        # advertises a count.
        logger.info(
            "Created %s embeddings for file %s (%s errors)",
            stats["total"],
            file_id,
            len(stats["errors"]),
        )
    except Exception as e:
        logger.exception("Failed to create embeddings for file %s", file_id)
        msg = "Failed to create embeddings"
        raise EmbeddingError(msg) from e

    return stats
async def create_repository_embeddings(
    self,
    repository_id: int,
    limit: int | None = None,
) -> dict[str, Any]:
    """Create embeddings for the non-deleted files of a repository.

    Each file is handed to ``create_file_embeddings``. A file that raises is
    logged and recorded in ``errors``; the remaining files are still
    processed.

    Args:
        repository_id: Repository ID to process
        limit: Optional cap on the number of files; a falsy value
            (``None`` or ``0``) applies no cap

    Returns:
        Summary dict with ``repository_id``, ``files_processed``,
        ``total_embeddings`` and ``errors``
    """
    logger.info("Creating embeddings for repository %d", repository_id)

    # Files of this repository that are not flagged as deleted
    file_stmt = select(File).where(
        and_(
            File.repository_id == repository_id,
            File.is_deleted.is_(False),
        ),
    )
    if limit:
        file_stmt = file_stmt.limit(limit)

    rows = await self.db_session.execute(file_stmt)
    repo_files = list(rows.scalars().all())

    summary: dict[str, Any] = {
        "repository_id": repository_id,
        "files_processed": 0,
        "total_embeddings": 0,
        "errors": [],
    }

    for repo_file in repo_files:
        try:
            per_file = await self.create_file_embeddings(cast("int", repo_file.id))
            summary["files_processed"] += 1
            summary["total_embeddings"] += per_file["total"]
            # Prefix each per-entity error with the file it came from
            summary["errors"].extend(
                f"File {repo_file.path}: {err}" for err in per_file["errors"]
            )
        except Exception as e:
            logger.exception("Failed to process file %s", repo_file.id)
            summary["errors"].append(f"File {repo_file.path}: {e!s}")

    logger.info(
        "Created %s embeddings for %s files in repository %s",
        summary["total_embeddings"],
        summary["files_processed"],
        repository_id,
    )
    return summary
async def _create_module_embedding(self, module: Module, file_path: str) -> int:
    """Return the embedding ID for a module, creating the embedding if missing.

    Args:
        module: Module record
        file_path: Path to the source file

    Returns:
        ID of the existing embedding, or of the newly stored one
    """
    module_id = cast("int", module.id)

    # Short-circuit when an embedding is already stored for this module
    stored = await self._get_existing_embedding("module", module_id)
    if stored:
        logger.debug("Embedding already exists for module %d", module.id)
        # mypy: CodeEmbedding.id is Column[int]; cast to runtime int
        return cast("int", stored.id)

    # Description of the module handed to the generator
    payload = {
        "name": module.name,
        "docstring": module.docstring,
        "start_line": module.start_line,
        "end_line": module.end_line,
    }

    # Class / function counts of the module
    module_stats = await self._get_module_stats(module_id)

    generated = await self.embedding_generator.generate_module_embedding(
        payload,
        file_path,
        module_stats,
    )

    # Persist the generated vector together with its source text and metadata
    record = CodeEmbedding(
        entity_type="module",
        entity_id=module_id,
        file_id=module.file_id,
        embedding_type="interpreted",
        embedding=generated["embedding"],
        content=generated["text"],
        tokens=generated.get("tokens"),
        repo_metadata=generated["metadata"],
        created_at=datetime.now(tz=UTC),
    )
    self.db_session.add(record)
    await self.db_session.commit()

    return cast("int", record.id)
async def _create_class_embedding(self, cls: Class, file_path: str) -> int:
    """Return the embedding ID for a class, creating the embedding if missing.

    Args:
        cls: Class record
        file_path: Path to the source file

    Returns:
        ID of the existing embedding, or of the newly stored one

    Raises:
        EmbeddingError: If the generator returns no result or a result
            without an embedding
    """
    # Short-circuit when an embedding is already stored for this class
    stored = await self._get_existing_embedding("class", cast("int", cls.id))
    if stored:
        logger.debug("Embedding already exists for class %d", cls.id)
        return cast("int", stored.id)

    # Summarise the class's methods for the generator
    class_methods = await self._get_class_methods(cast("int", cls.id))
    method_summaries = [
        {
            "name": member.name,
            "parameters": member.parameters,
            "return_type": member.return_type,
        }
        for member in class_methods
    ]
    payload = {
        "name": cls.name,
        "docstring": cls.docstring,
        "base_classes": cls.base_classes,
        "decorators": cls.decorators,
        "is_abstract": cls.is_abstract,
        "start_line": cls.start_line,
        "end_line": cls.end_line,
        "methods": method_summaries,
    }

    # The generator works on batches; send a batch of one
    batch = await self.embedding_generator.generate_class_embeddings(
        [payload],
        file_path,
    )

    msg = "Failed to generate class embedding"
    if not batch:
        raise EmbeddingError(msg)
    generated = batch[0]
    if not generated.get("embedding"):
        raise EmbeddingError(msg)

    # Persist the generated vector together with its source text and metadata
    record = CodeEmbedding(
        entity_type="class",
        entity_id=cls.id,
        file_id=cls.module.file_id,
        embedding_type="interpreted",
        embedding=generated["embedding"],
        content=generated["text"],
        tokens=generated.get("tokens"),
        repo_metadata=generated["metadata"],
        created_at=datetime.now(tz=UTC),
    )
    self.db_session.add(record)
    await self.db_session.commit()

    return cast("int", record.id)
async def _create_function_embedding(self, func: Function, file_path: str) -> int:
    """Return the embedding ID for a function, creating the embedding if missing.

    Args:
        func: Function record
        file_path: Path to the source file

    Returns:
        ID of the existing embedding, or of the newly stored one

    Raises:
        EmbeddingError: If the generator returns no result or a result
            without an embedding
    """
    # Short-circuit when an embedding is already stored for this function
    stored = await self._get_existing_embedding("function", cast("int", func.id))
    if stored:
        logger.debug("Embedding already exists for function %d", func.id)
        return cast("int", stored.id)

    # Only functions with a class_id carry an owning class name
    owner_name = None
    if func.class_id:
        owner_name = func.parent_class.name

    payload = {
        "name": func.name,
        "parameters": func.parameters,
        "return_type": func.return_type,
        "docstring": func.docstring,
        "decorators": func.decorators,
        "is_async": func.is_async,
        "is_generator": func.is_generator,
        "is_property": func.is_property,
        "is_staticmethod": func.is_static,
        "is_classmethod": func.is_classmethod,
        "start_line": func.start_line,
        "end_line": func.end_line,
        "class_name": owner_name,
    }

    # The generator works on batches; send a batch of one
    batch = await self.embedding_generator.generate_function_embeddings(
        [payload],
        file_path,
    )

    msg = "Failed to generate function embedding"
    if not batch:
        raise EmbeddingError(msg)
    generated = batch[0]
    if not generated.get("embedding"):
        raise EmbeddingError(msg)

    # Persist the generated vector together with its source text and metadata
    record = CodeEmbedding(
        entity_type="function",
        entity_id=func.id,
        file_id=func.module.file_id,
        embedding_type="interpreted",
        embedding=generated["embedding"],
        content=generated["text"],
        tokens=generated.get("tokens"),
        repo_metadata=generated["metadata"],
        created_at=datetime.now(tz=UTC),
    )
    self.db_session.add(record)
    await self.db_session.commit()

    return cast("int", record.id)
async def _get_existing_embedding(
    self,
    entity_type: str,
    entity_id: int,
) -> CodeEmbedding | None:
    """Look up a stored embedding for an entity.

    Args:
        entity_type: Type of entity
        entity_id: Entity ID

    Returns:
        The matching embedding with the lowest ID, or None when the entity
        has no embedding
    """
    # Fetch at most one row: this is an existence check, so an entity that
    # has several embedding rows must not raise MultipleResultsFound.
    # Ordering by ID keeps the returned row deterministic.
    sql_result = await self.db_session.execute(
        select(CodeEmbedding)
        .where(
            and_(
                CodeEmbedding.entity_type == entity_type,
                CodeEmbedding.entity_id == entity_id,
            ),
        )
        .order_by(CodeEmbedding.id)
        .limit(1),
    )
    return sql_result.scalars().first()
async def _get_file_modules(self, file_id: int) -> list[Module]:
    """Return the modules whose ``file_id`` matches, with ``Module.file`` eagerly loaded."""
    stmt = select(Module).where(Module.file_id == file_id)
    stmt = stmt.options(selectinload(Module.file))
    rows = await self.db_session.execute(stmt)
    return list(rows.scalars().all())
async def _get_file_classes(self, file_id: int) -> list[Class]:
    """Return the classes of a file, matched through their module's ``file_id``.

    ``Class.module`` is eagerly loaded on every returned class.
    """
    stmt = select(Class).join(Module).where(Module.file_id == file_id)
    stmt = stmt.options(selectinload(Class.module))
    rows = await self.db_session.execute(stmt)
    return list(rows.scalars().all())
async def _get_file_functions(self, file_id: int) -> list[Function]:
    """Return the functions of a file, matched through their module's ``file_id``.

    ``Function.module`` and ``Function.parent_class`` are eagerly loaded on
    every returned function.
    """
    eager_loads = (
        selectinload(Function.module),
        selectinload(Function.parent_class),
    )
    stmt = (
        select(Function)
        .join(Module)
        .where(Module.file_id == file_id)
        .options(*eager_loads)
    )
    rows = await self.db_session.execute(stmt)
    return list(rows.scalars().all())
async def _get_class_methods(self, class_id: int) -> list[Function]:
    """Return the functions whose ``class_id`` equals the given class ID."""
    stmt = select(Function).where(Function.class_id == class_id)
    rows = await self.db_session.execute(stmt)
    return list(rows.scalars().all())
async def _get_module_stats(self, module_id: int) -> dict[str, int]:
    """Count the classes and class-less functions of a module.

    Args:
        module_id: Module ID

    Returns:
        Dict with ``classes`` (classes in the module) and ``functions``
        (functions in the module whose ``class_id`` is NULL)
    """
    # Number of classes in the module
    class_stmt = (
        select(func.count()).select_from(Class).where(Class.module_id == module_id)
    )
    n_classes = (await self.db_session.execute(class_stmt)).scalar() or 0

    # Number of functions in the module that do not belong to a class
    module_level_only = and_(
        Function.module_id == module_id,
        Function.class_id.is_(None),
    )
    function_stmt = (
        select(func.count()).select_from(Function).where(module_level_only)
    )
    n_functions = (await self.db_session.execute(function_stmt)).scalar() or 0

    return {"classes": n_classes, "functions": n_functions}
async def delete_entity_embeddings(
    self,
    entity_type: str,
    entity_id: int,
) -> int:
    """Remove every embedding stored for one entity and commit the change.

    Args:
        entity_type: Type of entity
        entity_id: Entity ID

    Returns:
        Number of deleted embeddings (0 when the row count is unavailable)
    """
    same_entity = and_(
        CodeEmbedding.entity_type == entity_type,
        CodeEmbedding.entity_id == entity_id,
    )
    outcome = await self.db_session.execute(
        delete(CodeEmbedding).where(same_entity)
    )
    await self.db_session.commit()

    removed = outcome.rowcount or 0
    return int(removed)
async def update_file_embeddings(self, file_id: int) -> dict[str, Any]:
    """Rebuild the embeddings of a file by deleting and recreating them.

    The delete is committed before the new embeddings are created, so the
    two steps are not a single atomic transaction.

    Args:
        file_id: File ID to update

    Returns:
        The summary returned by ``create_file_embeddings`` with an extra
        ``deleted`` key holding the number of removed embeddings

    Raises:
        NotFoundError: If no file with ``file_id`` exists (raised by
            ``create_file_embeddings`` after the delete has been committed)
        EmbeddingError: If recreating the embeddings fails
    """
    logger.info("Updating embeddings for file %d", file_id)

    # Delete existing embeddings
    del_sql_result = await self.db_session.execute(
        delete(CodeEmbedding).where(CodeEmbedding.file_id == file_id)
    )
    await self.db_session.commit()

    # Normalise the row count once so the log line and the summary agree and
    # a missing row count cannot break the "%d" formatting.
    deleted = int(del_sql_result.rowcount or 0)
    logger.info("Deleted %d existing embeddings", deleted)

    # Create new embeddings
    stats = await self.create_file_embeddings(file_id)
    stats["deleted"] = deleted

    return stats
async def create_domain_entity_embedding(
    self,
    entity_id: int,
) -> dict[str, Any]:
    """Create and store an embedding for a single domain entity.

    The entity is loaded together with its bounded contexts, described as a
    plain dict, passed to the embedding generator, and the result is stored
    as a ``CodeEmbedding`` row. Failures during generation or storage are
    logged and reported in the returned dict instead of being raised.

    Args:
        entity_id: Domain entity ID

    Returns:
        On success a dict with ``entity_id``, ``entity_name``,
        ``entity_type``, ``embedding_id``, ``tokens`` and
        ``status="success"``; on failure a dict with ``entity_id``,
        ``entity_name``, ``entity_type``, ``status="failed"`` and ``error``

    Raises:
        NotFoundError: If no domain entity with ``entity_id`` exists
    """
    logger.info("Creating embedding for domain entity %d", entity_id)

    # Get domain entity with relationships
    sql_result = await self.db_session.execute(
        select(DomainEntity)
        .where(DomainEntity.id == entity_id)
        .options(selectinload(DomainEntity.bounded_contexts))
    )
    entity = sql_result.scalar_one_or_none()
    if not entity:
        msg = f"Domain entity {entity_id} not found"
        raise NotFoundError(msg)

    # Prepare entity data
    entity_data = {
        "name": entity.name,
        "entity_type": entity.entity_type,
        "description": entity.description,
        "business_rules": entity.business_rules,
        "invariants": entity.invariants,
        "responsibilities": entity.responsibilities,
        "module_path": entity.module_path,
        "class_name": entity.class_name,
    }

    # Add bounded context info if available
    if entity.bounded_contexts:
        context_names = [bc.name for bc in entity.bounded_contexts]
        entity_data["bounded_context"] = ", ".join(context_names)

    # Generate and store the embedding. ``gen_result`` is the generator's
    # dict, kept distinct from the SQL result above.
    try:
        gen_result = (
            await self.embedding_generator.generate_domain_entity_embedding(
                entity_data
            )
        )

        # Store in database. ``content`` and ``repo_metadata`` are the
        # keyword names the module/class/function writers of this service
        # use; ``metadata`` is reserved by SQLAlchemy's Declarative API and
        # cannot be a mapped column attribute.
        # NOTE(review): the other CodeEmbedding writers also pass
        # ``embedding_type`` and ``file_id`` and do not pass ``model`` —
        # confirm against the CodeEmbedding model which of these apply here.
        embedding_record = CodeEmbedding(
            entity_type="domain_entity",
            entity_id=entity.id,
            content=gen_result["text"],
            embedding=gen_result["embedding"],
            repo_metadata=gen_result["metadata"],
            tokens=gen_result["tokens"],
            model=self.embedding_generator.embeddings.model,
            created_at=datetime.now(tz=UTC),
        )

        self.db_session.add(embedding_record)
        await self.db_session.commit()

        return {
            "entity_id": entity.id,
            "entity_name": entity.name,
            "entity_type": entity.entity_type,
            "embedding_id": embedding_record.id,
            "tokens": gen_result["tokens"],
            "status": "success",
        }

    except Exception as e:
        # Best-effort by design: report the failure to the caller instead of
        # raising, so batch callers can continue with other entities.
        # NOTE(review): if the failure came from the commit, the session may
        # need an explicit rollback before it is used again — confirm.
        logger.exception(
            "Failed to create embedding for domain entity %s",
            entity.name,
        )
        return {
            "entity_id": entity.id,
            "entity_name": entity.name,
            "entity_type": entity.entity_type,
            "status": "failed",
            "error": str(e),
        }
async def create_all_domain_entity_embeddings(self) -> dict[str, Any]:
    """Create an embedding for every domain entity in the database.

    Each entity is passed to ``create_domain_entity_embedding``; the status
    of each call is tallied.

    Returns:
        Summary dict with ``total`` (entities found), ``success``,
        ``failed`` and ``errors`` (one ``"<name>: <error>"`` string per
        failed entity)
    """
    logger.info("Creating embeddings for all domain entities")

    # Load every domain entity along with its bounded contexts
    all_stmt = select(DomainEntity).options(
        selectinload(DomainEntity.bounded_contexts)
    )
    rows = await self.db_session.execute(all_stmt)
    domain_entities = list(rows.scalars().all())

    summary: dict[str, Any] = {
        "total": len(domain_entities),
        "success": 0,
        "failed": 0,
        "errors": [],
    }

    for domain_entity in domain_entities:
        outcome = await self.create_domain_entity_embedding(
            cast("int", domain_entity.id)
        )
        succeeded = outcome["status"] == "success"
        bucket = "success" if succeeded else "failed"
        summary[bucket] = int(summary[bucket]) + 1
        if not succeeded:
            summary["errors"].append(
                f"{domain_entity.name}: {outcome.get('error', 'Unknown error')}"
            )

    return summary