"""
GraphCodeBERT Embedding Generator
Generates vector embeddings for code frames using Microsoft's GraphCodeBERT model.
GraphCodeBERT adds data flow awareness to CodeBERT for better semantic understanding.
"""
import logging
from typing import Dict, Any, Optional, List, TYPE_CHECKING
import torch
from transformers import AutoTokenizer, AutoModel
import math
if TYPE_CHECKING:
    import torch.cuda
    # Type-checking-only import for the 'AstFrameBase' annotations used below;
    # the module path is assumed here and may need adjusting.
    from nabu.core.frame_types import AstFrameBase
from nabu.embeddings.base import EmbeddingGenerator, EmbeddingModel
from nabu.core.skeleton_builder import SkeletonFormatter, SkeletonOptions
logger = logging.getLogger(__name__)
class CodeBERTGenerator(EmbeddingGenerator):
"""
GraphCodeBERT embedding generator.
GraphCodeBERT extends CodeBERT with data flow awareness for improved
semantic understanding of code structure and relationships.
Used for: Non-linear consensus fusion (not persisted independently)
Strengths:
- Structure and data flow aware
- Better semantic understanding than vanilla CodeBERT
- Designed for code-related tasks (search, clone detection, etc.)
Best for:
- Semantic code search with natural language queries
- Understanding code semantics beyond token sequences
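    Example (illustrative; assumes the pretrained weights can be downloaded or are cached):
        generator = CodeBERTGenerator()
        vector = generator.generate_embedding_from_text("def add(a, b): return a + b")
        # vector is a 768-element list of floats, or None on failure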
"""
def __init__(self):
"""Initialize GraphCodeBERT model and tokenizer."""
from .cache_config import get_model_cache_dir
model_name = "microsoft/graphcodebert-base"
cache_dir = get_model_cache_dir()
logger.info(f"Loading GraphCodeBERT model: {model_name}")
if cache_dir:
logger.info(f"Using cache directory: {cache_dir}")
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=cache_dir
)
self.model = AutoModel.from_pretrained(
model_name,
cache_dir=cache_dir
)
self.model.eval() # Inference mode
# Use GPU if available
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
# CUDA streams for pipeline parallelism
self.cuda_streams = []
        self.num_streams = 4  # Pipeline depth (reasonable default for 8-16 GB of VRAM)
if self.device.type == 'cuda':
try:
self.cuda_streams = [torch.cuda.Stream() for _ in range(self.num_streams)]
logger.info(f"CodeBERT: Created {self.num_streams} CUDA streams for pipeline parallelism")
except Exception as e:
logger.warning(f"Failed to create CUDA streams: {e}. Falling back to default stream.")
self.cuda_streams = []
# Initialize skeleton formatter for consistent skeleton generation
self.skeleton_formatter = SkeletonFormatter()
logger.info(f"GraphCodeBERT loaded on device: {self.device}")
@property
def model_type(self) -> EmbeddingModel:
return EmbeddingModel.CODEBERT
@property
def embedding_dim(self) -> int:
return 768
@property
def max_tokens(self) -> int:
return 512 # CodeBERT limit
def _truncate_to_token_limit(self, text: str, max_tokens: int) -> str:
"""
Truncate text to fit within token limit using tokenizer.
Args:
text: Text to truncate
max_tokens: Maximum number of tokens allowed
Returns:
Truncated text that fits within token limit
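        Example (illustrative; 460 is the ~90% budget used by callers in this class,
        and long_source is a placeholder for an over-length code string):
            short = self._truncate_to_token_limit("x = 1", max_tokens=460)        # returned unchanged
            clipped = self._truncate_to_token_limit(long_source, max_tokens=460)  # decoded from the first 460 tokens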
"""
if not text:
return ""
        # Tokenize with truncation at the model limit; this avoids the tokenizer's
        # sequence-length warning and gives an exact token count for the check below.
tokens = self.tokenizer.encode(
text,
add_special_tokens=False,
truncation=True,
max_length=self.max_tokens # Use model's actual limit (512 for CodeBERT)
)
if len(tokens) <= max_tokens:
return text
# Truncate tokens and decode back
truncated_tokens = tokens[:max_tokens]
return self.tokenizer.decode(truncated_tokens, skip_special_tokens=True)
def generate_embedding_from_ast_frame(self, frame: 'AstFrameBase') -> Optional[List[float]]:
"""
Generate embedding from AstFrame with adaptive strategy.
Same strategy as UniXcoder for now.
Future: Could experiment with different strategies optimized for semantic matching.
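        Strategy (as implemented below):
            - Small callables (within ~90% of the 512-token budget): embed the
              qualified name plus the full source.
            - Large callables: embed the qualified name plus a control-flow
              skeleton (docstrings kept), falling back to truncated source if
              skeleton generation fails.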
Args:
frame: AstFrame with children loaded (includes control flow children)
Returns:
768-dimensional embedding as list of floats, or None if generation fails
"""
from nabu.core.frame_types import FrameNodeType
from nabu.core.skeleton_builder import SkeletonBuilder, SkeletonOptions
# Only embed CALLABLE frames
if frame.type != FrameNodeType.CALLABLE:
return None
try:
parts = []
# 1. Qualified name (context)
if frame.qualified_name:
parts.append(frame.qualified_name)
# 2. Adaptive content strategy based on function size
full_content = frame.content or ''
if full_content:
# Safely measure token count with defensive truncation
# Use simple char estimate: ~4 chars per token
char_limit = self.max_tokens * 4 # 512 * 4 = 2048 chars
safe_sample = full_content[:char_limit] if len(full_content) > char_limit else full_content
try:
full_tokens = len(self.tokenizer.encode(
safe_sample,
add_special_tokens=False,
truncation=True,
max_length=self.max_tokens
))
# If we truncated the sample, assume it's large
if len(full_content) > char_limit:
full_tokens = self.max_tokens # Conservative estimate
except Exception as e:
logger.warning(f"Failed to measure tokens for {frame.qualified_name}: {e}, assuming large")
full_tokens = self.max_tokens # Assume large on error
if full_tokens <= int(self.max_tokens * 0.9):
# SMALL FUNCTION: Use full implementation
parts.append(full_content)
logger.debug(f"Embedding FULL code for {frame.qualified_name} ({full_tokens} tokens)")
else:
# LARGE FUNCTION: Use skeleton with control flow fingerprints
builder = SkeletonBuilder(db_manager=None)
options = SkeletonOptions(
detail_level="structure", # Full control flow tree
include_docstrings=True, # Critical for semantic signal
structure_detail_depth=2 # Nested control flow
)
skeleton = builder.build_skeleton_from_ast(frame, options, max_recursion_depth=0)
if skeleton:
parts.append(skeleton)
logger.debug(f"Embedding SKELETON for {frame.qualified_name} (full content: {full_tokens} tokens)")
else:
# Fallback: truncate full content
truncated = self._truncate_to_token_limit(full_content, int(self.max_tokens * 0.9))
parts.append(truncated)
logger.warning(f"Skeleton generation failed for {frame.qualified_name}, using truncated content")
# 3. Join and apply safety truncation (CodeBERT has 512 token limit vs UniXcoder's 1024)
text = "\n\n".join(parts)
text = self._truncate_to_token_limit(text, max_tokens=int(self.max_tokens * 0.9))
if not text:
return None
# 4. Generate embedding
return self.generate_embedding_from_text(text)
except Exception as e:
logger.error(f"Failed to generate embedding from AST frame {frame.qualified_name}: {e}")
return None
def generate_embedding_from_text(self, text: str) -> Optional[List[float]]:
"""
        Generate an embedding for the given text using GraphCodeBERT.
Args:
text: Text to embed
Returns:
768-dimensional embedding or None on failure
"""
try:
# Tokenize with truncation (CodeBERT max length is 512 tokens)
tokens = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=self.max_tokens,
padding=True
)
tokens = {k: v.to(self.device) for k, v in tokens.items()}
# Generate embedding
with torch.no_grad():
outputs = self.model(**tokens)
# Use [CLS] token embedding (standard for CodeBERT/RoBERTa models)
embedding = outputs.last_hidden_state[:, 0, :].squeeze()
# Convert to list and return
return embedding.cpu().numpy().tolist()
except Exception as e:
logger.error(f"Failed to generate embedding: {e}")
return None
async def generate_embeddings_batch(
self,
frames: List['AstFrameBase'],
batch_size: int = 32
) -> List[Optional[List[float]]]:
"""
        Generate embeddings for multiple frames using GPU-optimized batched inference.
        Same implementation as the UniXcoder generator, but with the GraphCodeBERT
        model and its 512-token limit.
Args:
frames: List of frames to embed (CALLABLE frames only)
batch_size: Frames per batch (default: 32)
Returns:
List of embeddings in same order as input frames
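        Example (illustrative; frames are CALLABLE AST frames prepared elsewhere):
            embeddings = await generator.generate_embeddings_batch(frames, batch_size=16)
            # embeddings[i] is None for non-CALLABLE frames or frames whose text could not be prepared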
"""
from nabu.core.frame_types import FrameNodeType
from nabu.core.skeleton_builder import SkeletonBuilder
import asyncio
logger.info(f"[CodeBERT] Generating embeddings for {len(frames)} frames (batch_size={batch_size})")
# Prepare texts for all frames (adaptive strategy per frame)
texts = []
for frame in frames:
if frame.type != FrameNodeType.CALLABLE:
texts.append(None)
continue
# Use same adaptive logic as single-frame generation
try:
parts = []
if frame.qualified_name:
parts.append(frame.qualified_name)
full_content = frame.content or ''
if full_content:
# Quick token estimate
char_limit = self.max_tokens * 4
safe_sample = full_content[:char_limit] if len(full_content) > char_limit else full_content
try:
full_tokens = len(self.tokenizer.encode(
safe_sample,
add_special_tokens=False,
truncation=True,
max_length=self.max_tokens
))
if len(full_content) > char_limit:
full_tokens = self.max_tokens
except Exception:
full_tokens = self.max_tokens
if full_tokens <= int(self.max_tokens * 0.9):
parts.append(full_content)
else:
# Use skeleton for large functions
builder = SkeletonBuilder(db_manager=None)
options = SkeletonOptions(
detail_level="structure",
include_docstrings=True,
structure_detail_depth=2
)
skeleton = builder.build_skeleton_from_ast(frame, options, max_recursion_depth=0)
if skeleton:
parts.append(skeleton)
else:
truncated = self._truncate_to_token_limit(full_content, int(self.max_tokens * 0.9))
parts.append(truncated)
text = "\n\n".join(parts)
text = self._truncate_to_token_limit(text, max_tokens=int(self.max_tokens * 0.9))
texts.append(text if text else None)
except Exception as e:
logger.error(f"Failed to prepare text for {frame.qualified_name}: {e}")
texts.append(None)
# Generate embeddings with stream-based pipeline parallelism
results = [None] * len(frames)
total_batches = (len(texts) + batch_size - 1) // batch_size
# Determine pipeline depth (how many batches to process concurrently)
pipeline_depth = len(self.cuda_streams) if self.cuda_streams else 1
logger.info(f"[CodeBERT] Processing {total_batches} batches with pipeline_depth={pipeline_depth}")
# Process batches in groups (pipeline windows)
for window_start in range(0, total_batches, pipeline_depth):
# Prepare batch tasks for this pipeline window
batch_tasks = []
batch_metadata = [] # Track (batch_idx, batch_num, valid_indices)
for i in range(pipeline_depth):
batch_num = window_start + i
if batch_num >= total_batches:
break
batch_idx = batch_num * batch_size
batch_texts = texts[batch_idx:batch_idx + batch_size]
# Filter out None texts
valid_indices = [j for j, text in enumerate(batch_texts) if text is not None]
valid_texts = [batch_texts[j] for j in valid_indices]
if not valid_texts:
logger.debug(f"[CodeBERT] Batch {batch_num + 1}/{total_batches}: Skipped (no valid frames)")
continue
# Tokenize batch (CPU operation)
tokens = self.tokenizer(
valid_texts,
return_tensors="pt",
truncation=True,
max_length=self.max_tokens,
padding=True
)
# Select stream for this batch
stream = self.cuda_streams[i] if self.cuda_streams else None
# Create async task for this batch
                loop = asyncio.get_running_loop()
task = loop.run_in_executor(
None,
self._generate_batch_embeddings_with_stream,
tokens,
stream
)
batch_tasks.append(task)
batch_metadata.append((batch_idx, batch_num + 1, valid_indices))
# Execute all batches in this window concurrently
if batch_tasks:
try:
batch_results = await asyncio.gather(*batch_tasks)
# Map embeddings back to results
for (batch_idx, batch_num, valid_indices), embeddings in zip(batch_metadata, batch_results):
for j, embedding in zip(valid_indices, embeddings):
results[batch_idx + j] = embedding
progress_pct = (batch_num * 100) // total_batches
logger.info(f"[CodeBERT] Batch {batch_num}/{total_batches} complete ({progress_pct}%)")
except Exception as e:
logger.error(f"[CodeBERT] Pipeline window failed: {e}")
                    raise RuntimeError(f"Embedding generation failed: {e}") from e
successful = sum(1 for r in results if r is not None)
logger.info(f"[CodeBERT] Completed: {successful}/{len(frames)} embeddings generated")
return results
def _generate_batch_embeddings_with_stream(
self,
tokens: dict,
stream: Optional['torch.cuda.Stream'] = None
) -> List[List[float]]:
"""
Synchronous batch embedding generation with optional CUDA stream.
Args:
tokens: Tokenized batch (dict with 'input_ids', 'attention_mask', etc.)
stream: Optional CUDA stream for pipeline parallelism
Returns:
List of embeddings (one per sample in batch)
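        Note:
            Called from executor threads by generate_embeddings_batch. When CUDA
            streams are available, each call runs its batch on its own stream so
            host-to-device copies and kernels from different batches can overlap;
            otherwise it falls back to the default stream.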
"""
# Move tokens to device on specified stream
if stream is not None and self.device.type == 'cuda':
with torch.cuda.stream(stream):
tokens = {k: v.to(self.device, non_blocking=True) for k, v in tokens.items()}
with torch.no_grad():
outputs = self.model(**tokens)
embeddings = outputs.last_hidden_state[:, 0, :].cpu()
# Synchronize stream before returning to ensure data is ready
stream.synchronize()
else:
# Fallback: default stream (backward compatible)
tokens = {k: v.to(self.device) for k, v in tokens.items()}
with torch.no_grad():
outputs = self.model(**tokens)
embeddings = outputs.last_hidden_state[:, 0, :].cpu()
return embeddings.numpy().tolist()