"""
Vector Search Module
Provides semantic search functionality using vector embeddings for document similarity.
Extracted from the deprecated LLMDatabaseRouter for integration into the main MCP system.
"""
import json
import logging
import re
from typing import Any, Dict, List, Optional
from sqlalchemy import bindparam, text
from sqlalchemy.engine import Engine
from postgres_integration import PostgreSQLIntegration
logger = logging.getLogger(__name__)
class VectorSearchEngine:
"""Handles vector-based semantic search operations on document embeddings"""
def __init__(self, engine: Engine, db_key: str = "default"):
self.engine = engine
self.db_key = db_key
self.postgres = PostgreSQLIntegration(engine, db_key)
def semantic_search(self, question: str, table_filter: Optional[str] = None,
fk_filter: Optional[Dict] = None, k: int = 10) -> List[Dict[str, Any]]:
"""Search document embeddings using vector similarity"""
try:
# First, generate an embedding for the user's question
question_embedding = self._generate_embedding(question)
if not question_embedding:
# Fallback to text search if embedding generation fails
return self._text_search_fallback(question, table_filter, fk_filter, k)
with self.engine.connect() as connection:
# Use vector similarity search with cosine distance
sql_parts = ["""
SELECT
table_name,
pk_json,
snippet,
created_at,
(embedding <=> :question_embedding::vector) AS distance
FROM doc_embeddings
WHERE 1=1
"""]
params = {'question_embedding': question_embedding}
if table_filter:
sql_parts.append("AND table_name = :table_filter")
params['table_filter'] = table_filter
if fk_filter:
sql_parts.append("AND pk_json @> :fk_filter::jsonb")
params['fk_filter'] = json.dumps(fk_filter)
# Order by similarity (lower distance = more similar)
sql_parts.append("ORDER BY distance ASC LIMIT :k")
params['k'] = k
final_sql = " ".join(sql_parts)
result = connection.execute(text(final_sql), params)
rows = []
for row in result:
row_dict = dict(row._mapping)
# Remove the distance from the final result (internal use only)
row_dict.pop('distance', None)
rows.append(row_dict)
return rows
except Exception as e:
logger.error(f"Error in vector search for {self.db_key}: {e}")
# Fallback to text search
return self._text_search_fallback(question, table_filter, fk_filter, k)
def _generate_embedding(self, text_input: str) -> Optional[str]:
"""Generate embedding for text using available methods"""
try:
# First check if the embedding column exists in this database
with self.engine.connect() as connection:
embedding_col_exists = connection.execute(text("""
SELECT EXISTS (
SELECT FROM information_schema.columns
WHERE table_name = 'doc_embeddings'
AND column_name = 'embedding'
AND table_schema = 'public'
);
""")).scalar()
if not embedding_col_exists:
# No embedding column, can't do vector search
return None
            # Method 1: Try to reuse an existing, textually similar embedding as a proxy.
            # This is a simplified approach; in production you would integrate with an
            # actual embedding service or model (OpenAI, Anthropic, sentence-transformers,
            # etc.). See the illustrative _generate_embedding_via_service sketch below.
with self.engine.connect() as connection:
# Look for existing embeddings with similar text patterns
similar_result = connection.execute(text("""
SELECT embedding FROM doc_embeddings
WHERE snippet ILIKE :pattern
LIMIT 1
"""), {'pattern': f'%{text_input[:50]}%'})
row = similar_result.fetchone()
if row and row.embedding:
return str(row.embedding)
except Exception as e:
logger.warning(f"Embedding generation failed: {e}")
return None
return None
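    # Illustrative sketch (not wired into semantic_search): how a real query embedding
    # could be produced instead of the proxy lookup above. It assumes the optional
    # `sentence-transformers` package is installed and that the vectors stored in
    # doc_embeddings.embedding share the chosen model's dimensionality; both are
    # assumptions, not guarantees of this codebase.
    def _generate_embedding_via_service(self, text_input: str) -> Optional[str]:
        """Sketch: generate a query embedding with a local sentence-transformers model"""
        try:
            from sentence_transformers import SentenceTransformer  # assumed optional dependency
            model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model choice
            vector = model.encode(text_input).tolist()
            # pgvector accepts the '[v1, v2, ...]' text literal used by the ::vector cast
            return "[" + ", ".join(str(v) for v in vector) + "]"
        except Exception as e:
            logger.warning(f"External embedding generation failed for {self.db_key}: {e}")
            return None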
def _text_search_fallback(self, question: str, table_filter: Optional[str] = None,
fk_filter: Optional[Dict] = None, k: int = 10) -> List[Dict[str, Any]]:
"""Fallback to text search when vector search is not available"""
try:
with self.engine.connect() as connection:
sql_parts = ["SELECT table_name, pk_json, snippet, created_at FROM doc_embeddings WHERE 1=1"]
params = {}
if table_filter:
sql_parts.append("AND table_name = :table_filter")
params['table_filter'] = table_filter
if fk_filter:
sql_parts.append("AND pk_json @> :fk_filter::jsonb")
params['fk_filter'] = json.dumps(fk_filter)
# Improved text search with keyword extraction and broader matching
search_terms = self._extract_search_terms(question)
if search_terms:
# Build OR conditions for multiple search terms
term_conditions = []
for i, term in enumerate(search_terms[:5]): # Limit to 5 terms
param_name = f'search_term_{i}'
term_conditions.append(f"snippet ILIKE :{param_name}")
params[param_name] = f'%{term}%'
if term_conditions:
sql_parts.append(f"AND ({' OR '.join(term_conditions)})")
# Order by creation date as a proxy for relevance
sql_parts.append("ORDER BY created_at DESC LIMIT :k")
params['k'] = k
final_sql = " ".join(sql_parts)
result = connection.execute(text(final_sql), params)
rows = []
for row in result:
rows.append(dict(row._mapping))
return rows
except Exception as e:
logger.error(f"Error in text search fallback for {self.db_key}: {e}")
return []
def _extract_search_terms(self, question: str) -> List[str]:
"""Extract meaningful search terms from a question"""
        # Simple keyword extraction; in production you might want more sophisticated NLP
# Remove common stop words and extract meaningful terms
stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
'should', 'may', 'might', 'must', 'can', 'what', 'when', 'where',
'why', 'how', 'who', 'which', 'that', 'this', 'these', 'those'}
# Extract words (3+ characters) and filter out stop words
words = re.findall(r'\b[a-zA-Z]{3,}\b', question.lower())
meaningful_words = [word for word in words if word not in stop_words]
return meaningful_words[:10] # Return top 10 terms
def populate_sample_docs(self, table_configs: Optional[List[Dict]] = None) -> Dict[str, Any]:
"""Populate document embeddings with sample data from existing tables"""
try:
# Default configuration for common table patterns
if not table_configs:
table_configs = [
{
'table': 'descriptions',
'pk_cols': ['description_id', 'trip_id'],
'content_col': 'note',
'join_info': """
LEFT JOIN trips t ON t.trip_id = d.trip_id
LEFT JOIN personnel p ON p.personnel_id = d.personnel_id
""",
'snippet_template': "Trip Note — {author_name} — {start_date} to {end_date} — {note}"
},
{
'table': 'personnel',
'pk_cols': ['personnel_id'],
'content_col': 'name',
'snippet_template': "Personnel — {name} — Position: {position} — Age: {age}"
}
]
with self.engine.begin() as connection:
                # Clear existing doc embeddings for these tables (an expanding bind
                # parameter avoids interpolating table names into the SQL string)
                table_names = [config['table'] for config in table_configs]
                delete_stmt = text("DELETE FROM doc_embeddings WHERE table_name IN :table_names").bindparams(
                    bindparam("table_names", expanding=True)
                )
                connection.execute(delete_stmt, {"table_names": table_names})
doc_count = 0
for config in table_configs:
try:
table = config['table']
pk_cols = config['pk_cols']
                        # Build a dynamic query based on the configuration; only add joined
                        # columns when they are configured, to avoid a dangling comma in the SELECT
                        if 'join_info' in config:
                            extra_cols = config.get('extra_cols', '')
                            select_cols = f"d.*, {extra_cols}" if extra_cols else "d.*"
                            query = f"""
                                SELECT {select_cols}
                                FROM {table} d
                                {config['join_info']}
                                ORDER BY d.created_at DESC
                                LIMIT 100
                            """
                        else:
                            query = f"SELECT * FROM {table} LIMIT 100"
result = connection.execute(text(query))
for row in result:
row_dict = dict(row._mapping)
# Generate snippet using template
snippet = self._generate_snippet(config['snippet_template'], row_dict)
# Build primary key JSON
pk_json = {col: str(row_dict.get(col, '')) for col in pk_cols}
connection.execute(text("""
INSERT INTO doc_embeddings (table_name, pk_json, snippet, created_at)
VALUES (:table_name, :pk_json, :snippet, :created_at)
"""), {
'table_name': table,
'pk_json': json.dumps(pk_json),
'snippet': snippet,
'created_at': row_dict.get('created_at')
})
doc_count += 1
except Exception as e:
logger.warning(f"Could not populate docs from table {config['table']}: {e}")
continue
return {
"success": True,
"message": f"Populated {doc_count} document embeddings for {self.db_key}",
"documents_created": doc_count
}
except Exception as e:
logger.error(f"Error populating sample docs for {self.db_key}: {e}")
return {
"success": False,
"error": str(e),
"documents_created": 0
}
def _generate_snippet(self, template: str, row_data: Dict) -> str:
"""Generate a snippet using template and row data with safe defaults"""
try:
# Provide safe defaults for missing keys
safe_data = {}
for key, value in row_data.items():
safe_data[key] = value if value is not None else 'Unknown'
# Truncate long content
if 'note' in safe_data and len(str(safe_data['note'])) > 400:
safe_data['note'] = str(safe_data['note'])[:400] + '...'
return template.format(**safe_data)
        except KeyError as e:
            # Fallback if the template references a key missing from the row
            logger.warning(f"Snippet template missing key {e}; using generic snippet")
            return f"Document from {row_data.get('table_name', 'unknown table')}"