codebert_generator.py (11.9 kB)
""" GraphCodeBERT Embedding Generator Generates vector embeddings for code frames using Microsoft's GraphCodeBERT model. GraphCodeBERT adds data flow awareness to CodeBERT for better semantic understanding. """ import logging from typing import Optional, List, TYPE_CHECKING import torch import math if TYPE_CHECKING: import torch.cuda from nabu.embeddings.base import EmbeddingModel from nabu.embeddings.base_transformer_generator import BaseTransformerEmbeddingGenerator logger = logging.getLogger(__name__) class CodeBERTGenerator(BaseTransformerEmbeddingGenerator): """ GraphCodeBERT embedding generator. GraphCodeBERT extends CodeBERT with data flow awareness for improved semantic understanding of code structure and relationships. Used for: Non-linear consensus fusion (not persisted independently) Strengths: - Structure and data flow aware - Better semantic understanding than vanilla CodeBERT - Designed for code-related tasks (search, clone detection, etc.) Best for: - Semantic code search with natural language queries - Understanding code semantics beyond token sequences """ def __init__(self): """Initialize GraphCodeBERT model and tokenizer.""" super().__init__( model_name="microsoft/graphcodebert-base", embedding_dim=768, max_tokens=512 ) @property def model_type(self) -> EmbeddingModel: """Model type identifier for database columns.""" return EmbeddingModel.CODEBERT def generate_embedding_from_ast_frame(self, frame: 'AstFrameBase') -> Optional[List[float]]: """ Generate embedding from AstFrame with adaptive strategy. Same strategy as UniXcoder for now. Future: Could experiment with different strategies optimized for semantic matching. Args: frame: AstFrame with children loaded (includes control flow children) Returns: 768-dimensional embedding as list of floats, or None if generation fails """ from nabu.core.frame_types import FrameNodeType from nabu.core.skeleton_builder import SkeletonBuilder, SkeletonOptions # Only embed CALLABLE frames if frame.type != FrameNodeType.CALLABLE: return None try: parts = [] # 1. Qualified name (context) if frame.qualified_name: parts.append(frame.qualified_name) # 2. 
            # 2. Adaptive content strategy based on function size
            full_content = frame.content or ''
            if full_content:
                # Safely measure token count with defensive truncation
                # Use simple char estimate: ~4 chars per token
                char_limit = self.max_tokens * 4  # 512 * 4 = 2048 chars
                safe_sample = full_content[:char_limit] if len(full_content) > char_limit else full_content
                try:
                    full_tokens = len(self.tokenizer.encode(
                        safe_sample,
                        add_special_tokens=False,
                        truncation=True,
                        max_length=self.max_tokens
                    ))
                    # If we truncated the sample, assume it's large
                    if len(full_content) > char_limit:
                        full_tokens = self.max_tokens  # Conservative estimate
                except Exception as e:
                    logger.warning(f"Failed to measure tokens for {frame.qualified_name}: {e}, assuming large")
                    full_tokens = self.max_tokens  # Assume large on error

                if full_tokens <= int(self.max_tokens * 0.9):
                    # SMALL FUNCTION: Use full implementation
                    parts.append(full_content)
                    logger.debug(f"Embedding FULL code for {frame.qualified_name} ({full_tokens} tokens)")
                else:
                    # LARGE FUNCTION: Use skeleton with control flow fingerprints
                    builder = SkeletonBuilder(db_manager=None)
                    options = SkeletonOptions(
                        detail_level="structure",     # Full control flow tree
                        include_docstrings=True,      # Critical for semantic signal
                        structure_detail_depth=2      # Nested control flow
                    )
                    skeleton = builder.build_skeleton_from_ast(frame, options, max_recursion_depth=0)
                    if skeleton:
                        parts.append(skeleton)
                        logger.debug(f"Embedding SKELETON for {frame.qualified_name} (full content: {full_tokens} tokens)")
                    else:
                        # Fallback: truncate full content
                        truncated = self._truncate_to_token_limit(full_content, int(self.max_tokens * 0.9))
                        parts.append(truncated)
                        logger.warning(f"Skeleton generation failed for {frame.qualified_name}, using truncated content")

            # 3. Join and apply safety truncation (CodeBERT has 512 token limit vs UniXcoder's 1024)
            text = "\n\n".join(parts)
            text = self._truncate_to_token_limit(text, max_tokens=int(self.max_tokens * 0.9))

            if not text:
                return None

            # 4. Generate embedding
            return self.generate_embedding_from_text(text)

        except Exception as e:
            logger.error(f"Failed to generate embedding from AST frame {frame.qualified_name}: {e}")
            return None
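    # NOTE: the batch path below rebuilds the same adaptive per-frame text as
    # generate_embedding_from_ast_frame, then tokenizes and runs the model in
    # fixed-size batches instead of one forward pass per frame.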
    async def generate_embeddings_batch(self, frames: List['AstFrameBase'], batch_size: int = 32) -> List[Optional[List[float]]]:
        """
        Generate embeddings for multiple frames in parallel batches.

        Uses CUDA streams for pipeline parallelism when available.

        Args:
            frames: List of AstFrame objects to embed
            batch_size: Number of frames per batch (default 32)

        Returns:
            List of embeddings (same order as input), None for failed frames
        """
        from nabu.core.frame_types import FrameNodeType
        from nabu.core.skeleton_builder import SkeletonBuilder, SkeletonOptions

        # Filter to CALLABLE frames
        callable_frames = [f for f in frames if f.type == FrameNodeType.CALLABLE]
        if not callable_frames:
            return [None] * len(frames)

        logger.info(f"[CodeBERT] Generating embeddings for {len(callable_frames)} frames (batch_size={batch_size})")

        # Build adaptive text for each frame
        texts = []
        builder = SkeletonBuilder(db_manager=None)
        for frame in callable_frames:
            parts = []
            if frame.qualified_name:
                parts.append(frame.qualified_name)

            full_content = frame.content or ''
            if full_content:
                char_limit = self.max_tokens * 4
                safe_sample = full_content[:char_limit] if len(full_content) > char_limit else full_content
                try:
                    full_tokens = len(self.tokenizer.encode(
                        safe_sample,
                        add_special_tokens=False,
                        truncation=True,
                        max_length=self.max_tokens
                    ))
                    if len(full_content) > char_limit:
                        full_tokens = self.max_tokens
                except Exception:
                    full_tokens = self.max_tokens

                if full_tokens <= int(self.max_tokens * 0.9):
                    parts.append(full_content)
                else:
                    options = SkeletonOptions(
                        detail_level="structure",
                        include_docstrings=True,
                        structure_detail_depth=2
                    )
                    skeleton = builder.build_skeleton_from_ast(frame, options, max_recursion_depth=0)
                    if skeleton:
                        parts.append(skeleton)
                    else:
                        truncated = self._truncate_to_token_limit(full_content, int(self.max_tokens * 0.9))
                        parts.append(truncated)

            text = "\n\n".join(parts)
            text = self._truncate_to_token_limit(text, max_tokens=int(self.max_tokens * 0.9))
            texts.append(text if text else "")

        # Batch generation with CUDA streams
        embeddings = []
        num_batches = math.ceil(len(texts) / batch_size)
        logger.info(f"[CodeBERT] Processing {num_batches} batches")

        for i in range(num_batches):
            batch_texts = texts[i * batch_size:(i + 1) * batch_size]
            try:
                # Tokenize batch
                tokens = self.tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    truncation=True,
                    max_length=self.max_tokens,
                    padding=True
                )
                tokens = {k: v.to(self.device) for k, v in tokens.items()}

                # Use CUDA stream if available
                stream = self.cuda_streams[i % self.num_streams] if self.cuda_streams else None
                batch_embeddings = self._generate_batch_embeddings_with_stream(tokens, stream)
                embeddings.extend(batch_embeddings)

                # Log progress
                progress_pct = ((i + 1) * 100) // num_batches
                logger.info(f"[CodeBERT] Batch {i + 1}/{num_batches} complete ({progress_pct}%)")
            except Exception as e:
                logger.error(f"[CodeBERT] Batch {i + 1} failed: {e}")
                embeddings.extend([None] * len(batch_texts))

        # Map back to original frame order
        result = []
        emb_idx = 0
        for frame in frames:
            if frame.type == FrameNodeType.CALLABLE:
                result.append(embeddings[emb_idx] if emb_idx < len(embeddings) else None)
                emb_idx += 1
            else:
                result.append(None)

        successful = sum(1 for e in result if e is not None)
        logger.info(f"[CodeBERT] Completed: {successful}/{len(frames)} embeddings generated")
        return result
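    # The helper below pools the [CLS] position (last_hidden_state[:, 0, :]) as
    # the frame-level embedding. When a CUDA stream is supplied, the forward
    # pass is enqueued on that stream and the current stream waits on it before
    # results are read back to the CPU.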
    def _generate_batch_embeddings_with_stream(self, tokens: dict, stream: Optional['torch.cuda.Stream'] = None) -> List[Optional[List[float]]]:
        """
        Generate embeddings for a tokenized batch, optionally using a CUDA stream.

        Args:
            tokens: Tokenized batch (dict with input_ids, attention_mask, etc.)
            stream: Optional CUDA stream for async execution

        Returns:
            List of embedding vectors, or a list of None values if the batch fails
        """
        try:
            if stream and self.device.type == 'cuda':
                with torch.cuda.stream(stream):
                    with torch.no_grad():
                        outputs = self.model(**tokens)
                        embeddings = outputs.last_hidden_state[:, 0, :].cpu()
                torch.cuda.current_stream().wait_stream(stream)
            else:
                with torch.no_grad():
                    outputs = self.model(**tokens)
                    embeddings = outputs.last_hidden_state[:, 0, :].cpu()

            return [emb.numpy().tolist() for emb in embeddings]
        except Exception as e:
            logger.error(f"Batch embedding generation failed: {e}")
            return [None] * len(tokens['input_ids'])
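For orientation, a minimal usage sketch. The import path and the load_callable_frames() helper are hypothetical (frames would normally come from nabu's AST loading layer, which is not shown here); only CodeBERTGenerator and its two public methods above come from the file itself.

import asyncio

# Hypothetical import path, assuming this file lives at nabu/embeddings/codebert_generator.py
from nabu.embeddings.codebert_generator import CodeBERTGenerator


async def main() -> None:
    generator = CodeBERTGenerator()

    # Hypothetical helper: returns AstFrameBase objects with children loaded
    frames = load_callable_frames()

    # Single-frame path: 768-dim vector, or None for non-CALLABLE/failed frames
    single = generator.generate_embedding_from_ast_frame(frames[0])

    # Batch path: order-preserving, with None entries for skipped/failed frames
    vectors = await generator.generate_embeddings_batch(frames, batch_size=32)
    print(sum(v is not None for v in vectors), "embeddings generated")


asyncio.run(main())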
