"""
Embedding repository for generating and managing embeddings
"""
import logging
import json
import re
from typing import Optional, List
from sqlalchemy import Engine, text
from sqlalchemy.exc import SQLAlchemyError
from shared.exceptions import EmbeddingError
logger = logging.getLogger(__name__)
class EmbeddingRepository:
    """
    Repository for embedding lookup and storage in the ``doc_embeddings`` table.

    Extracted from llmDatabaseRouter._generate_embedding().

    No embedding model is invoked here: ``generate_embedding`` reuses an
    embedding that is already stored in ``doc_embeddings`` as a proxy for the
    input text. The SQL issued by this class uses PostgreSQL-specific features
    (``ILIKE``, ``jsonb`` and ``vector`` casts).
    """

    # Words ignored when extracting search terms from free text.
    _STOP_WORDS = frozenset({
        'the', 'is', 'are', 'was', 'were', 'a', 'an', 'and', 'or', 'but',
        'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'what', 'who',
        'where', 'when', 'why', 'how', 'can', 'could', 'would', 'should',
        'do', 'does', 'did', 'will', 'have', 'has', 'had', 'be', 'been',
        'being',
    })
    # Alphabetic words of three or more characters.
    _WORD_PATTERN = re.compile(r'\b[a-zA-Z]{3,}\b')
    # Maximum number of terms returned by _extract_search_terms().
    _MAX_SEARCH_TERMS = 5
    # Maximum number of terms turned into ILIKE conditions in one query.
    _MAX_QUERY_TERMS = 3

    def __init__(self, engine: "Engine"):
        """
        Args:
            engine: SQLAlchemy engine used for every query in this repository
        """
        self.engine = engine

    def generate_embedding(self, text_input: str) -> Optional[List[float]]:
        """
        Return a stored embedding to use as a proxy for ``text_input``.

        Extracted from llmDatabaseRouter._generate_embedding()

        Lookup order:
            1. An embedding of a document whose snippet contains one of the
               search terms extracted from ``text_input``.
            2. The most recent embedding of a document with a snippet longer
               than 50 characters.

        Args:
            text_input: Text to find an embedding for

        Returns:
            Embedding as list of floats, or None if not available
        """
        try:
            # Nothing to look up if this database has no embedding column.
            if not self._has_embedding_column():
                return None

            # Method 1: reuse the embedding of a document sharing keywords.
            search_terms = self._extract_search_terms(text_input)
            if search_terms:
                similar = self._find_similar_embedding(search_terms)
                if similar:
                    parsed = self._parse_embedding(similar)
                    # Fall through to method 2 when the stored value is
                    # unparseable rather than giving up immediately.
                    if parsed is not None:
                        return parsed

            # Method 2: fall back to a representative embedding.
            representative = self._get_representative_embedding()
            if representative:
                return self._parse_embedding(representative)
            return None
        except Exception as e:
            # Best-effort API: callers receive None instead of an exception.
            logger.error("Error generating embedding: %s", e)
            return None

    def _has_embedding_column(self) -> bool:
        """Check if the embedding column exists in doc_embeddings table"""
        try:
            with self.engine.connect() as connection:
                exists = connection.execute(text("""
                    SELECT EXISTS (
                        SELECT 1 FROM information_schema.columns
                        WHERE table_name = 'doc_embeddings'
                        AND column_name = 'embedding'
                        AND table_schema = 'public'
                    );
                """)).scalar()
            return bool(exists)
        except Exception as e:
            logger.warning("Could not check embedding column existence: %s", e)
            return False

    def _find_similar_embedding(self, search_terms: List[str]) -> Optional[str]:
        """
        Find an embedding from a document whose snippet contains a search term.

        Args:
            search_terms: Candidate terms; only the first few
                (``_MAX_QUERY_TERMS``) are used

        Returns:
            The stored embedding converted to a string, or None if no row
            matched or the query failed
        """
        terms = search_terms[:self._MAX_QUERY_TERMS]
        if not terms:
            # No terms means no WHERE clause to build; skip the round trip.
            return None

        # Only the generated parameter names are interpolated into the SQL;
        # the term values themselves are always passed as bound parameters.
        params = {f'term_{i}': f'%{term}%' for i, term in enumerate(terms)}
        conditions = ' OR '.join(f"snippet ILIKE :{name}" for name in params)
        try:
            with self.engine.connect() as connection:
                row = connection.execute(text(f"""
                    SELECT embedding
                    FROM doc_embeddings
                    WHERE ({conditions})
                    AND embedding IS NOT NULL
                    LIMIT 1
                """), params).fetchone()
            if row and row[0]:
                return str(row[0])
            return None
        except Exception as e:
            logger.warning("Could not find similar embedding: %s", e)
            return None

    def _get_representative_embedding(self) -> Optional[str]:
        """
        Get the most recently created embedding whose snippet is longer than
        50 characters.

        Returns:
            The stored embedding converted to a string, or None if no row
            matched or the query failed
        """
        try:
            with self.engine.connect() as connection:
                row = connection.execute(text("""
                    SELECT embedding
                    FROM doc_embeddings
                    WHERE embedding IS NOT NULL
                    AND LENGTH(snippet) > 50 -- Prefer documents with good content
                    ORDER BY created_at DESC
                    LIMIT 1
                """)).fetchone()
            if row and row[0]:
                return str(row[0])
            return None
        except Exception as e:
            logger.warning("Could not get representative embedding: %s", e)
            return None

    def _extract_search_terms(self, text_input: str) -> List[str]:
        """
        Extract meaningful search terms from text.

        Terms are lower-cased alphabetic words of at least three characters
        that are not stop words. Duplicates are dropped (first occurrence
        wins) so that repeated words do not use up the limited term slots.

        Args:
            text_input: Free text to extract terms from

        Returns:
            At most ``_MAX_SEARCH_TERMS`` terms in order of first appearance;
            an empty list when the input is not a string or has no usable
            words
        """
        if not isinstance(text_input, str):
            return []
        words = self._WORD_PATTERN.findall(text_input.lower())
        # dict.fromkeys de-duplicates while preserving insertion order.
        unique_terms = dict.fromkeys(
            word for word in words if word not in self._STOP_WORDS
        )
        return list(unique_terms)[:self._MAX_SEARCH_TERMS]

    def _parse_embedding(self, embedding_str: str) -> Optional[List[float]]:
        """
        Parse embedding string to list of floats.

        Supported formats are a bracketed JSON array (``"[0.1, 0.2]"``) and
        bare comma-separated values (``"0.1, 0.2"``). Surrounding whitespace
        is ignored.

        Args:
            embedding_str: String form of a stored embedding

        Returns:
            List of floats, or None when the format is not recognised or a
            value cannot be converted to float
        """
        try:
            candidate = embedding_str.strip()
            if candidate.startswith('[') and candidate.endswith(']'):
                # JSON array format
                return [float(value) for value in json.loads(candidate)]
            if ',' in candidate:
                # Comma-separated values
                return [float(part.strip()) for part in candidate.split(',')]
            # Single value or unknown format
            logger.warning("Unknown embedding format: %s...", candidate[:50])
            return None
        except (ValueError, TypeError, AttributeError) as e:
            logger.error("Could not parse embedding: %s", e)
            return None

    def store_embedding(
        self,
        text: str,
        embedding: List[float],
        table_name: str,
        pk_json: dict,
        metadata: Optional[dict] = None
    ) -> bool:
        """
        Store an embedding in the database.

        Args:
            text: Original text content (stored in the ``snippet`` column)
            embedding: Embedding vector
            table_name: Source table name
            pk_json: Primary key JSON
            metadata: Optional metadata; accepted for interface
                compatibility but not written by the INSERT below

        Returns:
            True if stored successfully, False otherwise
        """
        # The ``text`` parameter shadows the module-level sqlalchemy ``text``
        # inside this method, so the SQL constructor is imported under an
        # alias instead of calling the string argument by mistake.
        from sqlalchemy import text as sql_text

        try:
            # CAST(... AS type) is used instead of the ``:name::type``
            # shorthand, which clashes with text()'s ``:name`` bind syntax.
            statement = sql_text("""
                INSERT INTO doc_embeddings (
                    table_name, pk_json, snippet, embedding, created_at
                ) VALUES (
                    :table_name, CAST(:pk_json AS jsonb), :snippet,
                    CAST(:embedding AS vector), NOW()
                )
            """)
            params = {
                'table_name': table_name,
                'pk_json': json.dumps(pk_json),
                'snippet': text,
                # Serialised as "[v1, v2, ...]" for the vector cast.
                'embedding': json.dumps([float(value) for value in embedding]),
            }
            # begin() commits on success and rolls back if execute() raises.
            with self.engine.begin() as connection:
                connection.execute(statement, params)
            return True
        except Exception as e:
            logger.error("Could not store embedding: %s", e)
            return False

    def get_embedding_stats(self) -> dict:
        """
        Get statistics about stored embeddings.

        Returns:
            Dict with ``total_embeddings``, ``unique_tables`` and
            ``avg_snippet_length`` computed over rows with a non-NULL
            embedding, or ``{'error': <message>}`` if the query failed
        """
        try:
            with self.engine.connect() as connection:
                row = connection.execute(text("""
                    SELECT
                        COUNT(*) as total,
                        COUNT(DISTINCT table_name) as unique_tables,
                        AVG(LENGTH(snippet)) as avg_snippet_length
                    FROM doc_embeddings
                    WHERE embedding IS NOT NULL
                """)).fetchone()
            if row:
                return {
                    'total_embeddings': row[0],
                    'unique_tables': row[1],
                    # AVG() yields NULL when no rows match.
                    'avg_snippet_length': float(row[2]) if row[2] is not None else 0.0
                }
            return {'total_embeddings': 0, 'unique_tables': 0, 'avg_snippet_length': 0.0}
        except Exception as e:
            logger.error("Could not get embedding stats: %s", e)
            return {'error': str(e)}