"""
Vector/embedding repository for semantic search operations
"""
import logging
import json
import re
from typing import Dict, List, Optional, Any
from sqlalchemy import Engine, text
from sqlalchemy.exc import SQLAlchemyError
from shared.models import SemanticSearchResult
from shared.exceptions import EmbeddingError
logger = logging.getLogger(__name__)
class VectorRepository:
"""
Repository for vector/embedding operations
Handles semantic search, text search fallback, and embedding storage
"""
def __init__(self, engine: Engine):
self.engine = engine
self._has_vector_extension = None
def search_embeddings(
self,
query_embedding: List[float],
table_name: str = "doc_embeddings",
table_filter: Optional[str] = None,
fk_filter: Optional[Dict] = None,
limit: int = 20
) -> List[SemanticSearchResult]:
"""
Search for similar documents using vector similarity
Extracted and enhanced from llmDatabaseRouter.semantic_rows()
Args:
query_embedding: Query vector as list of floats
table_name: Name of the embeddings table
table_filter: Optional table name filter
fk_filter: Optional foreign key filter
limit: Maximum number of results
Returns:
List of semantic search results
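        Raises:
            EmbeddingError: If pgvector is unavailable or the query fails
        Example (illustrative sketch; the engine URL, embedding size, and
        stored rows are assumptions)::
            from sqlalchemy import create_engine
            repo = VectorRepository(create_engine("postgresql://localhost/appdb"))
            hits = repo.search_embeddings(query_embedding=[0.1] * 1536, limit=5)
            for hit in hits:
                print(hit.score, hit.source, hit.content[:80])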
"""
if not self.has_vector_extension():
raise EmbeddingError("pgvector extension not available")
try:
            # Build the query using the original schema from llmDatabaseRouter.
            # NOTE: table_name is interpolated into the SQL (identifiers cannot
            # be bound parameters), so callers must not pass untrusted values.
            sql_parts = [f"""
                SELECT
                    table_name,
                    pk_json,
                    snippet,
                    created_at,
                    (embedding <=> :question_embedding::vector) AS distance
                FROM {table_name}
                WHERE 1=1
            """]
            params = {'question_embedding': str(query_embedding)}
if table_filter:
sql_parts.append("AND table_name = :table_filter")
params['table_filter'] = table_filter
if fk_filter:
sql_parts.append("AND pk_json @> :fk_filter::jsonb")
params['fk_filter'] = json.dumps(fk_filter)
            # Order by similarity (lower cosine distance = more similar);
            # bind the limit instead of interpolating it into the SQL
            sql_parts.append("ORDER BY distance ASC LIMIT :limit")
            params['limit'] = limit
final_sql = " ".join(sql_parts)
with self.engine.connect() as conn:
result = conn.execute(text(final_sql), params)
rows = result.fetchall()
results = []
for row in rows:
                # pgvector's <=> operator returns cosine distance (range 0..2),
                # so 1 - distance yields a similarity-style score (higher = closer)
                similarity_score = 1.0 - float(row.distance) if row.distance is not None else 0.0
results.append(SemanticSearchResult(
content=row.snippet or "",
score=similarity_score,
source=row.table_name or "unknown",
metadata={
'pk_json': row.pk_json,
'created_at': row.created_at.isoformat() if row.created_at else None
}
))
logger.info(f"Vector search returned {len(results)} results")
return results
except SQLAlchemyError as e:
logger.error(f"Vector search failed: {e}")
raise EmbeddingError(f"Vector search query failed: {e}")
except Exception as e:
logger.error(f"Unexpected error in vector search: {e}")
raise EmbeddingError(f"Vector search failed: {e}")
def semantic_search_with_fallback(
self,
question: str,
query_embedding: Optional[List[float]] = None,
table_filter: Optional[str] = None,
fk_filter: Optional[Dict] = None,
limit: int = 20
) -> List[SemanticSearchResult]:
"""
Main semantic search method with automatic fallback to text search
Replaces the original semantic_rows() method
Args:
question: Original question text
query_embedding: Pre-computed embedding (if available)
table_filter: Optional table name filter
fk_filter: Optional foreign key filter
limit: Maximum number of results
Returns:
List of semantic search results
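        Example (illustrative sketch; `repo`, `embedding`, and the filters
        shown are assumptions)::
            # With a pre-computed embedding, vector search runs when pgvector exists
            results = repo.semantic_search_with_fallback(
                question="overdue invoices for ACME",
                query_embedding=embedding,
                table_filter="invoices",
                limit=10
            )
            # Without an embedding, the call degrades to ILIKE keyword matching
            results = repo.semantic_search_with_fallback("overdue invoices for ACME")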
"""
try:
# If we have an embedding and vector extension, use vector search
if query_embedding and self.has_vector_extension():
return self.search_embeddings(
query_embedding=query_embedding,
table_filter=table_filter,
fk_filter=fk_filter,
limit=limit
)
else:
# Fallback to text search
return self.text_search_fallback(
question=question,
table_filter=table_filter,
fk_filter=fk_filter,
limit=limit
)
except Exception as e:
logger.error(f"Semantic search failed, falling back to text search: {e}")
# Always fallback to text search if vector search fails
return self.text_search_fallback(
question=question,
table_filter=table_filter,
fk_filter=fk_filter,
limit=limit
)
def text_search_fallback(
self,
question: str,
table_filter: Optional[str] = None,
fk_filter: Optional[Dict] = None,
limit: int = 20
) -> List[SemanticSearchResult]:
"""
Fallback text search when vector search is not available
Extracted from llmDatabaseRouter._text_search_fallback()
Args:
question: Search question
table_filter: Optional table name filter
fk_filter: Optional foreign key filter
limit: Maximum number of results
Returns:
List of search results
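        For a question like "overdue ACME invoices" the generated predicate is
        roughly (illustrative; parameter names are internal)::
            AND (snippet ILIKE :query_0 OR snippet ILIKE :query_1 OR snippet ILIKE :query_2)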
"""
try:
sql_parts = ["SELECT table_name, pk_json, snippet, created_at FROM doc_embeddings WHERE 1=1"]
params = {}
if table_filter:
sql_parts.append("AND table_name = :table_filter")
params['table_filter'] = table_filter
if fk_filter:
sql_parts.append("AND pk_json @> :fk_filter::jsonb")
params['fk_filter'] = json.dumps(fk_filter)
# Improved text search with keyword extraction and broader matching
search_terms = self._extract_search_terms(question)
if search_terms:
# Build OR conditions for multiple search terms
search_conditions = []
for i, term in enumerate(search_terms):
param_name = f'query_{i}'
search_conditions.append(f"snippet ILIKE :{param_name}")
params[param_name] = f'%{term}%'
sql_parts.append(f"AND ({' OR '.join(search_conditions)})")
else:
# Fallback to original query
sql_parts.append("AND snippet ILIKE :query")
params['query'] = f'%{question}%'
sql_parts.append(f"ORDER BY created_at DESC LIMIT {limit}")
final_sql = " ".join(sql_parts)
with self.engine.connect() as conn:
result = conn.execute(text(final_sql), params)
rows = result.fetchall()
results = []
for row in rows:
results.append(SemanticSearchResult(
content=row.snippet or "",
                    score=0.5,  # flat placeholder score; ILIKE matches carry no distance metric
source=row.table_name or "unknown",
metadata={
'pk_json': row.pk_json,
'created_at': row.created_at.isoformat() if row.created_at else None
}
))
logger.info(f"Text search returned {len(results)} results")
return results
except SQLAlchemyError as e:
logger.error(f"Text search failed: {e}")
raise EmbeddingError(f"Text search query failed: {e}")
except Exception as e:
logger.error(f"Unexpected error in text search: {e}")
raise EmbeddingError(f"Text search failed: {e}")
def _extract_search_terms(self, question: str) -> List[str]:
"""
Extract meaningful search terms from a question
Extracted from llmDatabaseRouter._extract_search_terms()
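        Example::
            >>> repo._extract_search_terms("What are the top customers?")
            ['top', 'customers']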
"""
try:
# Remove common stop words and extract meaningful terms
stop_words = {
'the', 'is', 'are', 'was', 'were', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
'of', 'with', 'by', 'what', 'who', 'where', 'when', 'why', 'how', 'can', 'could', 'would',
'should', 'do', 'does', 'did', 'will', 'have', 'has', 'had', 'be', 'been', 'being'
}
# Extract words that are 3+ characters and not stop words
words = re.findall(r'\b[a-zA-Z]{3,}\b', question.lower())
meaningful_terms = [word for word in words if word not in stop_words]
# Limit to top 5 terms to avoid overly complex queries
return meaningful_terms[:5]
except Exception as e:
logger.warning(f"Could not extract search terms: {e}")
return [question.lower()]
def store_embedding(
self,
content: str,
embedding: List[float],
source: str,
metadata: Optional[Dict[str, Any]] = None,
table_name: str = "document_embeddings"
) -> str:
"""
Store a document embedding
Args:
content: Document content
embedding: Vector embedding
source: Source identifier
metadata: Additional metadata
table_name: Name of the embeddings table
Returns:
ID of the stored embedding
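        Example (illustrative sketch; the three-float vector stands in for a
        real model embedding)::
            embedding_id = repo.store_embedding(
                content="Quarterly revenue grew 12% year over year.",
                embedding=[0.02, -0.11, 0.35],
                source="reports",
                metadata={"report_id": 42}
            )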
"""
try:
sql = f"""
INSERT INTO {table_name} (content, embedding, source, metadata, created_at)
VALUES (:content, :embedding::vector, :source, :metadata, NOW())
RETURNING id
"""
params = {
'content': content,
'embedding': str(embedding),
'source': source,
'metadata': metadata or {}
}
with self.engine.connect() as conn:
result = conn.execute(text(sql), params)
embedding_id = result.fetchone()[0]
conn.commit()
logger.info(f"Stored embedding with ID: {embedding_id}")
return str(embedding_id)
except SQLAlchemyError as e:
logger.error(f"Failed to store embedding: {e}")
raise EmbeddingError(f"Could not store embedding: {e}")
def has_vector_extension(self) -> bool:
"""Check if pgvector extension is available (cached)"""
if self._has_vector_extension is None:
try:
with self.engine.connect() as conn:
result = conn.execute(text("""
SELECT EXISTS(
SELECT 1 FROM pg_extension WHERE extname = 'vector'
) as has_vector
"""))
self._has_vector_extension = result.fetchone().has_vector
except Exception as e:
logger.warning(f"Could not check vector extension: {e}")
self._has_vector_extension = False
return self._has_vector_extension
def get_embedding_stats(self, table_name: str = "document_embeddings") -> Dict[str, Any]:
"""Get statistics about stored embeddings"""
try:
sql = f"""
SELECT
COUNT(*) as total_embeddings,
COUNT(DISTINCT source) as unique_sources,
AVG(CHAR_LENGTH(content)) as avg_content_length
FROM {table_name}
"""
with self.engine.connect() as conn:
result = conn.execute(text(sql))
row = result.fetchone()
return {
'total_embeddings': row.total_embeddings or 0,
'unique_sources': row.unique_sources or 0,
'avg_content_length': float(row.avg_content_length) if row.avg_content_length else 0.0
}
except Exception as e:
logger.error(f"Could not get embedding stats: {e}")
return {
'total_embeddings': 0,
'unique_sources': 0,
'avg_content_length': 0.0,
'error': str(e)
}
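
# Minimal wiring sketch (illustrative: the DSN, table contents, and question
# below are assumptions, not part of this module's contract).
if __name__ == "__main__":
    from sqlalchemy import create_engine
    engine = create_engine("postgresql+psycopg2://localhost/appdb")
    repo = VectorRepository(engine)
    # Uses vector search when pgvector is installed, ILIKE fallback otherwise
    for hit in repo.semantic_search_with_fallback("recent customer complaints", limit=5):
        print(f"{hit.score:.2f} [{hit.source}] {hit.content[:60]}")
    print(repo.get_embedding_stats())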