"""In-memory embedding cache for fast vector search.
This module provides an efficient in-memory cache of normalized embeddings
for fast cosine similarity computation. The cache loads all embeddings from
the search index database on first access and supports incremental updates.
Key optimizations:
- Pre-normalized embeddings at load time (cosine similarity = dot product)
- NumPy matrix operations for efficient similarity computation
- np.argpartition for O(n) top-k selection instead of O(n log n) sort
- Incremental append without full reload
- Validity checking via last_embedded_rowid
Memory usage: ~200MB for 33k messages (33k x 1536 dims x 4 bytes float32)
"""
import numpy as np
from ..db.search_index import get_search_index_connection, get_sync_metadata
class EmbeddingCache:
"""In-memory cache of normalized embeddings for fast cosine similarity."""
# Embedding dimension for text-embedding-3-small
EMBEDDING_DIM = 1536
def __init__(self, db_path: str | None = None):
"""Initialize the embedding cache.
Args:
db_path: Optional custom path to search index database.
If None, uses default path.
"""
self._embeddings: np.ndarray | None = None # Shape: (N, 1536), normalized
self._rowids: np.ndarray | None = None # Shape: (N,), int64
self._last_embedded_rowid: int = 0
self._db_path = db_path
def load(self) -> None:
"""Load all embeddings from database and normalize them.
Loads embeddings from the message_index table where embedding is not NULL
and not empty (length > 0). Converts BLOBs to numpy arrays and normalizes
them for fast cosine similarity via dot product.
"""
with get_search_index_connection(self._db_path) as conn:
# Get the current last_embedded_rowid from the database
last_embedded_str = get_sync_metadata(conn, "last_embedded_rowid")
self._last_embedded_rowid = (
int(last_embedded_str) if last_embedded_str else 0
)
# Query all embeddings (excluding empty blobs for skipped messages)
cursor = conn.execute(
"""
SELECT rowid, embedding
FROM message_index
WHERE embedding IS NOT NULL AND LENGTH(embedding) > 0
ORDER BY rowid ASC
"""
)
rowids: list[int] = []
embeddings: list[np.ndarray] = []
for row in cursor:
rowid = row["rowid"]
embedding_blob = row["embedding"]
# Convert BLOB to numpy array
embedding = np.frombuffer(embedding_blob, dtype=np.float32)
# Verify dimension
if len(embedding) != self.EMBEDDING_DIM:
# Skip invalid embeddings
continue
rowids.append(rowid)
embeddings.append(embedding)
if embeddings:
# Stack into matrix: shape (N, 1536)
self._embeddings = np.stack(embeddings, axis=0)
# Normalize embeddings for fast cosine similarity
# cosine_similarity(a, b) = a·b / (||a|| * ||b||)
# With pre-normalized vectors: cosine_similarity = a·b
norms = np.linalg.norm(self._embeddings, axis=1, keepdims=True)
# Avoid division by zero
norms = np.maximum(norms, 1e-10)
self._embeddings = self._embeddings / norms
self._rowids = np.array(rowids, dtype=np.int64)
else:
# No embeddings yet - initialize empty arrays
self._embeddings = np.zeros((0, self.EMBEDDING_DIM), dtype=np.float32)
self._rowids = np.array([], dtype=np.int64)
def is_valid(self) -> bool:
"""Check if cache is still valid (last_embedded_rowid matches DB).
Returns:
True if cache is valid (no new embeddings in DB), False otherwise.
"""
if self._embeddings is None:
return False
with get_search_index_connection(self._db_path) as conn:
db_last_embedded_str = get_sync_metadata(conn, "last_embedded_rowid")
db_last_embedded = int(db_last_embedded_str) if db_last_embedded_str else 0
return self._last_embedded_rowid == db_last_embedded
def invalidate(self) -> None:
"""Clear the cache, will reload on next access."""
self._embeddings = None
self._rowids = None
self._last_embedded_rowid = 0
def append_embeddings(
self, new_rowids: list[int], new_embeddings: np.ndarray
) -> None:
"""Append new embeddings without full reload.
Pre-normalizes the new embeddings and appends them to the existing
cache. Updates the last_embedded_rowid to the maximum of the new rowids.
Args:
new_rowids: List of rowids for the new embeddings.
new_embeddings: NumPy array of shape (M, 1536) with new embeddings.
"""
if len(new_rowids) == 0:
return
# Ensure we have loaded data first
if self._embeddings is None:
self.load()
return # load() will have included these embeddings
# Normalize new embeddings
norms = np.linalg.norm(new_embeddings, axis=1, keepdims=True)
norms = np.maximum(norms, 1e-10)
normalized_new = new_embeddings / norms
# Append to existing arrays
self._embeddings = np.vstack([self._embeddings, normalized_new])
self._rowids = np.concatenate(
[self._rowids, np.array(new_rowids, dtype=np.int64)]
)
# Update last_embedded_rowid
self._last_embedded_rowid = max(self._last_embedded_rowid, max(new_rowids))
def cosine_similarity(
self, query_embedding: np.ndarray, top_k: int = 100
) -> list[tuple[int, float]]:
"""Compute cosine similarity with all cached embeddings.
Uses dot product since embeddings are pre-normalized. Uses np.argpartition
for efficient O(n) top-k selection instead of full O(n log n) sort.
Args:
query_embedding: Query vector of shape (1536,). Will be normalized
if not already normalized.
top_k: Number of top results to return (default: 100).
Returns:
List of (rowid, similarity_score) tuples, sorted by score descending.
"""
# Ensure cache is loaded
embeddings = self.embeddings
rowids = self.rowids
if len(embeddings) == 0:
return []
# Normalize query embedding
query_norm = np.linalg.norm(query_embedding)
if query_norm < 1e-10:
return [] # Invalid query embedding
normalized_query = query_embedding / query_norm
# Compute cosine similarity via dot product (since embeddings are normalized)
# Shape: (N,)
similarities = embeddings @ normalized_query
# Efficient top-k selection using argpartition
# argpartition is O(n) vs O(n log n) for full sort
k = min(top_k, len(similarities))
if k == len(similarities):
# If we want all results, just sort
top_indices = np.argsort(similarities)[::-1]
else:
# Use argpartition for efficient top-k
# This partitions the array so that the k largest elements are at the end
partition_indices = np.argpartition(similarities, -k)[-k:]
# Sort only the top-k elements
top_indices = partition_indices[
np.argsort(similarities[partition_indices])[::-1]
]
# Build result list
results: list[tuple[int, float]] = []
for idx in top_indices:
rowid = int(rowids[idx])
score = float(similarities[idx])
results.append((rowid, score))
return results
@property
def embeddings(self) -> np.ndarray:
"""Get embeddings matrix, loading if needed.
Returns:
NumPy array of shape (N, 1536) with normalized embeddings.
"""
if self._embeddings is None or not self.is_valid():
self.load()
assert self._embeddings is not None # For type checker
return self._embeddings
@property
def rowids(self) -> np.ndarray:
"""Get rowid array, loading if needed.
Returns:
NumPy array of shape (N,) with rowids corresponding to embeddings.
"""
if self._rowids is None or not self.is_valid():
self.load()
assert self._rowids is not None # For type checker
return self._rowids
@property
def size(self) -> int:
"""Get the number of cached embeddings.
Triggers a cache load if not already loaded.
Returns:
Number of embeddings in the cache.
"""
return len(self.embeddings)
@property
def memory_usage_mb(self) -> float:
"""Get approximate memory usage of the cache in megabytes.
Returns:
Memory usage in MB, 0 if not loaded.
"""
if self._embeddings is None:
return 0.0
# Embeddings: N x 1536 x 4 bytes (float32)
# Rowids: N x 8 bytes (int64)
n = len(self._embeddings)
embeddings_bytes = n * self.EMBEDDING_DIM * 4
rowids_bytes = n * 8
return (embeddings_bytes + rowids_bytes) / (1024 * 1024)
# Module-wide singleton instance: None until get_embedding_cache() creates it,
# and set back to None by reset_embedding_cache().
_cache: EmbeddingCache | None = None
def get_embedding_cache(db_path: str | None = None) -> EmbeddingCache:
    """Return the shared EmbeddingCache, creating it on first use.

    The first call constructs the singleton; every later call hands back
    that same object.

    Args:
        db_path: Optional custom path to search index database.
            Only consulted by the call that creates the instance.

    Returns:
        Global EmbeddingCache instance.
    """
    global _cache

    instance = _cache
    if instance is None:
        # First call: build the singleton and remember it module-wide
        instance = EmbeddingCache(db_path)
        _cache = instance
    return instance
def reset_embedding_cache() -> None:
    """Drop the global embedding cache instance.

    Any existing instance is invalidated first, then the module-level
    reference is cleared. Mainly intended for tests that need a clean state.
    """
    global _cache

    existing = _cache
    if existing is not None:
        existing.invalidate()
    _cache = None