"""
UniXcoder Embedding Generator
Generates vector embeddings for code frames using Microsoft's UniXcoder model.
Optimized for operational differentiation and code→code similarity.
"""
import logging
from typing import Dict, Any, Optional, List, TYPE_CHECKING
import torch
from transformers import AutoTokenizer, AutoModel
if TYPE_CHECKING:
import torch.cuda
from nabu.embeddings.base import EmbeddingGenerator, EmbeddingModel
from nabu.core.skeleton_builder import SkeletonFormatter, SkeletonOptions
logger = logging.getLogger(__name__)
class UniXcoderGenerator(EmbeddingGenerator):
"""
UniXcoder embedding generator.
Optimized for operational differentiation and code→code similarity.
Stores embeddings in: embedding_unixcoder column
Strengths:
- Excellent operational differentiation
- Low false positive rate for code similarity
- Code→code pattern matching
Best for: clone detection, code→code search
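    Example (illustrative; assumes microsoft/unixcoder-base is downloadable or cached):
        gen = UniXcoderGenerator()
        vec = gen.generate_embedding_from_text("def add(a, b): return a + b")
        # vec is a 768-element list of floats, or None on failure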
"""
def __init__(self):
"""Initialize UniXcoder model and tokenizer."""
from .cache_config import get_model_cache_dir
model_name = "microsoft/unixcoder-base"
cache_dir = get_model_cache_dir()
logger.info(f"Loading UniXcoder model: {model_name}")
if cache_dir:
logger.info(f"Using cache directory: {cache_dir}")
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=cache_dir
)
self.model = AutoModel.from_pretrained(
model_name,
cache_dir=cache_dir
)
self.model.eval() # Inference mode
# Use GPU if available
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
# CUDA streams for pipeline parallelism
self.cuda_streams = []
        self.num_streams = 4  # Pipeline depth (works well for 8-16GB VRAM)
if self.device.type == 'cuda':
try:
self.cuda_streams = [torch.cuda.Stream() for _ in range(self.num_streams)]
logger.info(f"UniXcoder: Created {self.num_streams} CUDA streams for pipeline parallelism")
except Exception as e:
logger.warning(f"Failed to create CUDA streams: {e}. Falling back to default stream.")
self.cuda_streams = []
# Initialize skeleton formatter for consistent skeleton generation
self.skeleton_formatter = SkeletonFormatter()
logger.info(f"UniXcoder loaded on device: {self.device}")
@property
def model_type(self) -> EmbeddingModel:
return EmbeddingModel.UNIXCODER
@property
def embedding_dim(self) -> int:
return 768
@property
def max_tokens(self) -> int:
return 1024 # UniXcoder's longer context
def _truncate_to_token_limit(self, text: str, max_tokens: int) -> str:
"""
Truncate text to fit within token limit using tokenizer.
Args:
text: Text to truncate
max_tokens: Maximum number of tokens allowed
Returns:
Truncated text that fits within token limit
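        Example (illustrative; exact round-trip depends on the tokenizer):
            # a short snippet already fits within 512 tokens, so it is returned unchanged
            gen._truncate_to_token_limit("def f(): pass", 512)  # -> "def f(): pass"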
"""
if not text:
return ""
        # Encode with one token of headroom so overflow can be detected
        # without emitting long-sequence warnings.
        tokens = self.tokenizer.encode(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_tokens + 1
        )
        if len(tokens) <= max_tokens:
            return text
        # Truncate tokens and decode back to text
        return self.tokenizer.decode(tokens[:max_tokens], skip_special_tokens=True)
def generate_embedding_from_ast_frame(self, frame: 'AstFrameBase') -> Optional[List[float]]:
"""
Generate embedding from AstFrame with adaptive strategy.
Adaptive Approach:
        - Small functions (within ~90% of the token budget, 921 tokens for
          max_tokens=1024): embed the FULL implementation for maximum fidelity
        - Large functions (above that cutoff): use a skeleton with control
          flow fingerprints
This solves the "..." placeholder problem: UniXcoder was trained on real code,
not abstracted skeletons. Small functions benefit from seeing actual implementation.
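        For example, a 900-token function is embedded verbatim, while a
        1500-token function is reduced to its skeleton before embedding.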
Args:
frame: AstFrame with children loaded (includes control flow children)
Returns:
768-dimensional embedding as list of floats, or None if generation fails
"""
from nabu.core.frame_types import FrameNodeType
        from nabu.core.skeleton_builder import SkeletonBuilder
# Only embed CALLABLE frames
if frame.type != FrameNodeType.CALLABLE:
return None
try:
parts = []
# 1. Qualified name (context)
if frame.qualified_name:
parts.append(frame.qualified_name)
# 2. Adaptive content strategy based on function size
full_content = frame.content or ''
if full_content:
# Safely measure token count with defensive truncation
# Use simple char estimate: ~4 chars per token
char_limit = self.max_tokens * 4 # 1024 * 4 = 4096 chars
safe_sample = full_content[:char_limit] if len(full_content) > char_limit else full_content
try:
full_tokens = len(self.tokenizer.encode(
safe_sample,
add_special_tokens=False,
truncation=True,
max_length=self.max_tokens
))
# If we truncated the sample, assume it's large
if len(full_content) > char_limit:
full_tokens = self.max_tokens # Conservative estimate
except Exception as e:
logger.warning(f"Failed to measure tokens for {frame.qualified_name}: {e}, assuming large")
                    full_tokens = self.max_tokens  # Assume large on error
if full_tokens <= int(self.max_tokens * 0.9):
# SMALL FUNCTION: Use full implementation
# UniXcoder can see actual operations (dict lookup, math, logic, etc.)
# This solves the 99.69% false positive problem for simple functions
parts.append(full_content)
logger.debug(f"Embedding FULL code for {frame.qualified_name} ({full_tokens} tokens)")
else:
# LARGE FUNCTION: Use skeleton with control flow fingerprints
# This keeps token budget manageable while providing structural differentiation
builder = SkeletonBuilder(db_manager=None)
options = SkeletonOptions(
detail_level="structure", # Full control flow tree
include_docstrings=True, # Critical for semantic signal
structure_detail_depth=2 # Nested control flow
)
skeleton = builder.build_skeleton_from_ast(frame, options, max_recursion_depth=0)
if skeleton:
parts.append(skeleton)
logger.debug(f"Embedding SKELETON for {frame.qualified_name} (full content: {full_tokens} tokens)")
else:
# Fallback: truncate full content
truncated = self._truncate_to_token_limit(full_content, int(self.max_tokens * 0.9))
parts.append(truncated)
logger.warning(f"Skeleton generation failed for {frame.qualified_name}, using truncated content")
            # 3. Join and apply a final safety truncation at the model's token limit
            text = "\n\n".join(parts)
            text = self._truncate_to_token_limit(text, max_tokens=self.max_tokens)
if not text:
return None
# 4. Generate embedding
return self.generate_embedding_from_text(text)
except Exception as e:
logger.error(f"Failed to generate embedding from AST frame {frame.qualified_name}: {e}")
return None
def generate_embedding_from_text(self, text: str) -> Optional[List[float]]:
"""
Generate embedding for text using UniXcoder.
Args:
text: Text to embed
Returns:
768-dimensional embedding or None on failure
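        Example (illustrative):
            emb = gen.generate_embedding_from_text("def add(a, b): return a + b")
            # emb has length 768, or is None if inference failed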
"""
try:
# Tokenize with truncation (UniXcoder max length is 1024 tokens)
tokens = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=self.max_tokens,
padding=True
)
tokens = {k: v.to(self.device) for k, v in tokens.items()}
# Generate embedding
with torch.no_grad():
outputs = self.model(**tokens)
# Use [CLS] token embedding (standard for UniXcoder/RoBERTa models)
embedding = outputs.last_hidden_state[:, 0, :].squeeze()
# Convert to list and return
return embedding.cpu().numpy().tolist()
except Exception as e:
logger.error(f"Failed to generate embedding: {e}")
return None
async def generate_embeddings_batch(
self,
frames: List['AstFrameBase'],
batch_size: int = 32
) -> List[Optional[List[float]]]:
"""
Generate embeddings for multiple frames using GPU-optimized batched inference.
Processes frames in batches to maximize GPU utilization (~10-15x speedup vs sequential).
Args:
frames: List of frames to embed (CALLABLE frames only)
batch_size: Frames per batch (default: 32, optimal for 8GB VRAM)
Returns:
List of embeddings in same order as input frames
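        Example (illustrative; ``frames`` holds already-loaded CALLABLE frames):
            import asyncio
            embeddings = asyncio.run(gen.generate_embeddings_batch(frames, batch_size=32))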
"""
from nabu.core.frame_types import FrameNodeType
from nabu.core.skeleton_builder import SkeletonBuilder
import asyncio
logger.info(f"[UniXcoder] Generating embeddings for {len(frames)} frames (batch_size={batch_size})")
# Prepare texts for all frames (adaptive strategy per frame)
texts = []
for frame in frames:
if frame.type != FrameNodeType.CALLABLE:
texts.append(None)
continue
# Use same adaptive logic as single-frame generation
try:
parts = []
if frame.qualified_name:
parts.append(frame.qualified_name)
full_content = frame.content or ''
if full_content:
# Quick token estimate
char_limit = self.max_tokens * 4
safe_sample = full_content[:char_limit] if len(full_content) > char_limit else full_content
try:
full_tokens = len(self.tokenizer.encode(
safe_sample,
add_special_tokens=False,
truncation=True,
max_length=self.max_tokens
))
if len(full_content) > char_limit:
full_tokens = self.max_tokens
except Exception:
full_tokens = self.max_tokens
if full_tokens <= int(self.max_tokens * 0.9):
parts.append(full_content)
else:
# Use skeleton for large functions
builder = SkeletonBuilder(db_manager=None)
options = SkeletonOptions(
detail_level="structure",
include_docstrings=True,
structure_detail_depth=2
)
skeleton = builder.build_skeleton_from_ast(frame, options, max_recursion_depth=0)
if skeleton:
parts.append(skeleton)
else:
truncated = self._truncate_to_token_limit(full_content, int(self.max_tokens * 0.9))
parts.append(truncated)
text = "\n\n".join(parts)
                text = self._truncate_to_token_limit(text, max_tokens=self.max_tokens)
texts.append(text if text else None)
except Exception as e:
logger.error(f"Failed to prepare text for {frame.qualified_name}: {e}")
texts.append(None)
# Generate embeddings with stream-based pipeline parallelism
results = [None] * len(frames)
total_batches = (len(texts) + batch_size - 1) // batch_size
# Determine pipeline depth (how many batches to process concurrently)
pipeline_depth = len(self.cuda_streams) if self.cuda_streams else 1
logger.info(f"[UniXcoder] Processing {total_batches} batches with pipeline_depth={pipeline_depth}")
# Process batches in groups (pipeline windows)
for window_start in range(0, total_batches, pipeline_depth):
# Prepare batch tasks for this pipeline window
batch_tasks = []
batch_metadata = [] # Track (batch_idx, batch_num, valid_indices)
for i in range(pipeline_depth):
batch_num = window_start + i
if batch_num >= total_batches:
break
batch_idx = batch_num * batch_size
batch_texts = texts[batch_idx:batch_idx + batch_size]
# Filter out None texts
valid_indices = [j for j, text in enumerate(batch_texts) if text is not None]
valid_texts = [batch_texts[j] for j in valid_indices]
if not valid_texts:
logger.debug(f"[UniXcoder] Batch {batch_num + 1}/{total_batches}: Skipped (no valid frames)")
continue
# Tokenize batch (CPU operation)
tokens = self.tokenizer(
valid_texts,
return_tensors="pt",
truncation=True,
max_length=self.max_tokens,
padding=True
)
# Select stream for this batch
stream = self.cuda_streams[i] if self.cuda_streams else None
# Create async task for this batch
                loop = asyncio.get_running_loop()
task = loop.run_in_executor(
None,
self._generate_batch_embeddings_with_stream,
tokens,
stream
)
batch_tasks.append(task)
batch_metadata.append((batch_idx, batch_num + 1, valid_indices))
# Execute all batches in this window concurrently
if batch_tasks:
try:
batch_results = await asyncio.gather(*batch_tasks)
# Map embeddings back to results
for (batch_idx, batch_num, valid_indices), embeddings in zip(batch_metadata, batch_results):
for j, embedding in zip(valid_indices, embeddings):
results[batch_idx + j] = embedding
progress_pct = (batch_num * 100) // total_batches
logger.info(f"[UniXcoder] Batch {batch_num}/{total_batches} complete ({progress_pct}%)")
except Exception as e:
logger.error(f"[UniXcoder] Pipeline window failed: {e}")
                    raise RuntimeError(f"Embedding generation failed: {e}") from e
successful = sum(1 for r in results if r is not None)
logger.info(f"[UniXcoder] Completed: {successful}/{len(frames)} embeddings generated")
return results
def _generate_batch_embeddings_with_stream(
self,
tokens: dict,
stream: Optional['torch.cuda.Stream'] = None
) -> List[List[float]]:
"""
Synchronous batch embedding generation with optional CUDA stream.
Args:
tokens: Tokenized batch (dict with 'input_ids', 'attention_mask', etc.)
stream: Optional CUDA stream for pipeline parallelism
Returns:
List of embeddings (one per sample in batch)
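        Example ``tokens`` value (illustrative; a batch of 2 padded to 12 tokens):
            {"input_ids": LongTensor of shape (2, 12),
             "attention_mask": LongTensor of shape (2, 12)}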
"""
# Move tokens to device on specified stream
if stream is not None and self.device.type == 'cuda':
with torch.cuda.stream(stream):
tokens = {k: v.to(self.device, non_blocking=True) for k, v in tokens.items()}
with torch.no_grad():
outputs = self.model(**tokens)
embeddings = outputs.last_hidden_state[:, 0, :].cpu()
# Synchronize stream before returning to ensure data is ready
stream.synchronize()
else:
# Fallback: default stream (backward compatible)
tokens = {k: v.to(self.device) for k, v in tokens.items()}
with torch.no_grad():
outputs = self.model(**tokens)
embeddings = outputs.last_hidden_state[:, 0, :].cpu()
return embeddings.numpy().tolist()