"""Vector similarity search using cached embeddings.
This module provides semantic search functionality using OpenAI embeddings and
efficient cosine similarity computation. It uses the in-memory embedding cache
for fast batch similarity operations and supports metadata pre-filtering.
Key features:
- Query embedding generation via OpenAI API
- Metadata pre-filtering before vector search
- Cached embedding matrix for efficient batch cosine similarity
- Returns top-k results sorted by similarity score
Similarity scores are cosine similarity in range [0, 1] where 1 is most similar.
"""
import numpy as np
from dataclasses import dataclass
from typing import Any, Optional, List, Tuple
from .embedding_cache import get_embedding_cache
from .embeddings import generate_embeddings, get_openai_client
from .fts5 import FTSFilters # Reuse the same filter dataclass
from ..db.search_index import get_search_index_connection
@dataclass
class VectorResult:
    """A single hit from vector similarity search.

    Attributes:
        rowid: Message ROWID from chat.db.
        score: Cosine similarity against the query embedding; higher is
            better, in the [0, 1] range.
    """

    # Message ROWID from chat.db
    rowid: int
    # Cosine similarity score (higher is better, 0-1 range)
    score: float
def get_query_embedding(query: str) -> Optional[np.ndarray]:
    """Get embedding for query text from OpenAI, normalized to unit length.

    Args:
        query: Query text to embed.

    Returns:
        Normalized embedding array of shape (1536,), or None if:
        - Query is empty or whitespace-only
        - OpenAI API key is not set
        - API call fails or returns a degenerate (near-zero) vector
    """
    if not query or not query.strip():
        return None

    # Bail out early when no OpenAI client is available (e.g. missing API key).
    client = get_openai_client()
    if client is None:
        return None

    try:
        embeddings = generate_embeddings([query], client=client)
        # Truthiness covers both None and an empty sequence in one check.
        if not embeddings:
            return None
        # asarray is a no-op for ndarrays but also tolerates a plain-list
        # return from generate_embeddings, which would otherwise break the
        # division below.
        embedding = np.asarray(embeddings[0])

        # Normalize to unit length so downstream dot products are cosine
        # similarities.
        norm = float(np.linalg.norm(embedding))
        if norm < 1e-10:
            return None  # Degenerate (near-zero) vector cannot be normalized
        normalized: np.ndarray = embedding / norm
        return normalized
    except Exception:
        # Deliberate best-effort: any API/network failure degrades to
        # "no embedding" rather than crashing the search path.
        return None
def get_filtered_rowids(filters: Optional[FTSFilters]) -> Optional[List[int]]:
    """Resolve metadata filters to the list of matching message rowids.

    Args:
        filters: Optional metadata filters to apply.

    Returns:
        List of rowids matching the filters, or None if no filters
        (meaning all rowids are candidates).
    """
    if filters is None:
        return None

    # Pair each filter value with its SQL condition; only values that are
    # set (not None) contribute to the WHERE clause.
    criteria = [
        (filters.sender, "sender = ?"),
        (filters.chat_id, "chat_id = ?"),
        (filters.after_date, "date_coredata >= ?"),
        (filters.before_date, "date_coredata <= ?"),
        (filters.service, "service = ?"),
    ]
    active = [(value, clause) for value, clause in criteria if value is not None]

    # A filters object with every field unset is equivalent to no filters.
    if not active:
        return None

    where_clause = " AND ".join(clause for _, clause in active)
    params: List[Any] = [value for value, _ in active]
    sql = f"SELECT rowid FROM message_index WHERE {where_clause}"

    with get_search_index_connection() as conn:
        rows = conn.execute(sql, params).fetchall()
    return [row[0] for row in rows]
def vector_search(
    query: str,
    filters: Optional[FTSFilters] = None,
    limit: int = 100
) -> Tuple[List[VectorResult], int]:
    """Perform vector similarity search over cached message embeddings.

    Pipeline:
    1. Embed the query via OpenAI.
    2. Resolve optional metadata filters into a set of candidate rowids.
    3. Rank cached embeddings by cosine similarity to the query.
    4. Return the top hits, sorted best-first.

    Args:
        query: Search query text
        filters: Optional metadata filters (sender, chat_id, date range, service)
        limit: Max results to return (default: 100)

    Returns:
        Tuple of (results, total_matching_count) where:
        - results: List of VectorResult objects sorted by score descending
        - total_matching_count: Total number of results before limit applied

    Example:
        >>> # Simple semantic search
        >>> results, total = vector_search("meeting tomorrow")
        >>> # Semantic search with filters
        >>> filters = FTSFilters(sender="+1234567890", after_date=726000000000000)
        >>> results, total = vector_search("dinner plans", filters=filters)
    """
    # Without a usable query embedding (empty query, missing API key, or
    # API failure) there is nothing to search against.
    embedded_query = get_query_embedding(query)
    if embedded_query is None:
        return [], 0

    cache = get_embedding_cache()
    if cache.size == 0:
        # No cached embeddings yet — nothing can match.
        return [], 0

    # Resolve metadata filters into an allow-set of rowids, if any.
    candidate_rowids: Optional[set[int]] = None
    if filters is not None:
        matched = get_filtered_rowids(filters)
        if matched is not None:
            candidate_rowids = set(matched)
            if not candidate_rowids:
                # Filters matched no messages at all.
                return [], 0

    # Over-fetch when filtering so enough hits survive the allow-set pass;
    # 10x is a heuristic that covers typical filter selectivity. Cap at the
    # cache size to avoid pointless computation.
    fetch_k = limit if candidate_rowids is None else min(limit * 10, cache.size)

    ranked = cache.cosine_similarity(embedded_query, top_k=fetch_k)

    if candidate_rowids is None:
        kept = ranked
    else:
        # NOTE: with an extremely selective filter, matches beyond the
        # over-fetched window can be missed; this is the accepted trade-off
        # for keeping the similarity pass bounded.
        kept = [
            (rowid, score)
            for rowid, score in ranked
            if rowid in candidate_rowids
        ]

    total_count = len(kept)
    results = [VectorResult(rowid=rid, score=sim) for rid, sim in kept[:limit]]
    return results, total_count