"""
PostgreSQL vector client for semantic search functionality.
This module provides persistent vector storage using PostgreSQL with the
pgvector extension, replacing ChromaDB for better performance and scalability.
"""
import logging
import psycopg2
from psycopg2.extras import RealDictCursor, execute_values
from psycopg2.pool import ThreadedConnectionPool
import pgvector.psycopg2
from typing import Dict, List, Any, Optional
import json
from dataclasses import dataclass
import threading
logger = logging.getLogger(__name__)
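# The queries in this module assume a table roughly like the sketch below.
# This DDL is illustrative, not the authoritative schema: column types and
# the vector dimension must match your deployment and embedding model.
#
#   CREATE EXTENSION IF NOT EXISTS vector;
#
#   CREATE TABLE zotero_embeddings (
#       item_key              TEXT PRIMARY KEY,  -- ON CONFLICT target in upserts
#       item_type             TEXT NOT NULL,
#       title                 TEXT,
#       content               TEXT,
#       content_hash          TEXT,
#       embedding             VECTOR(1536),      -- dimension is illustrative
#       embedding_model       TEXT,
#       embedding_provider    TEXT,
#       metadata              JSONB,
#       content_type          TEXT DEFAULT 'metadata',
#       parent_item_key       TEXT,
#       parent_attachment_key TEXT,
#       chunk_index           INTEGER DEFAULT 0,
#       chunk_total           INTEGER DEFAULT 1,
#       created_at            TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
#       updated_at            TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
#   );
#
#   -- vacuum_and_reindex() expects this index name; HNSW shown, IVFFlat also works:
#   CREATE INDEX idx_zotero_embedding_cosine
#       ON zotero_embeddings USING hnsw (embedding vector_cosine_ops);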
@dataclass
class EmbeddingItem:
"""Data structure for embedding items."""
item_key: str
item_type: str
title: str
content: str
content_hash: str
embedding: List[float]
embedding_model: str
embedding_provider: str
metadata: Dict[str, Any]
content_type: str = 'metadata'
parent_item_key: Optional[str] = None
parent_attachment_key: Optional[str] = None
chunk_index: int = 0
chunk_total: int = 1
@dataclass
class SearchResult:
"""Search result data structure."""
item_key: str
similarity_score: float
title: str
content: str
metadata: Dict[str, Any]
embedding_model: str
embedding_provider: str
zotero_item: Optional[Dict[str, Any]] = None
class PostgreSQLVectorClient:
"""PostgreSQL-based vector database client with pg-vector support."""
def __init__(self, config: Dict[str, Any]):
"""
Initialize PostgreSQL vector client.
        Args:
            config: Database configuration dict with keys:
                - host, port, database, username, password (required)
                - schema, pool_size, max_overflow, connection_timeout (optional)
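
        Example (a sketch; connection values are illustrative):
            client = PostgreSQLVectorClient({
                "host": "localhost",
                "port": 5432,
                "database": "zotero",
                "username": "zotero",
                "password": "secret",
            })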
"""
self.config = config
self.host = config["host"]
self.port = config["port"]
self.database = config["database"]
self.username = config["username"]
self.password = config["password"]
self.schema = config.get("schema", "public")
# Connection pooling
self.pool_size = config.get("pool_size", 5)
self.max_overflow = config.get("max_overflow", 10)
self.connection_timeout = config.get("connection_timeout", 30)
self._connection_pool = None
self._pool_lock = threading.Lock()
# Initialize connection pool
self._initialize_connection_pool()
    def _get_connection_string(self) -> str:
        """Build the libpq conninfo string for this client."""
        # Apply the configured schema via search_path so the unqualified
        # table references in this module resolve against it.
        return (
            f"host={self.host} "
            f"port={self.port} "
            f"dbname={self.database} "
            f"user={self.username} "
            f"password={self.password} "
            f"connect_timeout={self.connection_timeout} "
            f"options='-c search_path={self.schema}'"
        )
def _initialize_connection_pool(self) -> None:
"""Initialize connection pool."""
try:
conn_string = self._get_connection_string()
self._connection_pool = ThreadedConnectionPool(
minconn=1,
maxconn=self.pool_size + self.max_overflow,
dsn=conn_string
)
logger.info("PostgreSQL connection pool initialized")
except Exception as e:
logger.error(f"Error initializing connection pool: {e}")
raise
def get_connection(self) -> psycopg2.extensions.connection:
"""Get a connection from the pool."""
if not self._connection_pool:
raise RuntimeError("Connection pool not initialized")
try:
conn = self._connection_pool.getconn()
# Register vector extension for this connection
pgvector.psycopg2.register_vector(conn)
return conn
except Exception as e:
logger.error(f"Error getting connection from pool: {e}")
raise
def return_connection(self, conn: psycopg2.extensions.connection) -> None:
"""Return a connection to the pool."""
if self._connection_pool:
self._connection_pool.putconn(conn)
def close_pool(self) -> None:
"""Close the connection pool."""
if self._connection_pool:
self._connection_pool.closeall()
self._connection_pool = None
    def test_connection(self) -> bool:
        """Test database connection."""
        try:
            conn = self.get_connection()
        except Exception as e:
            logger.error(f"Connection test failed: {e}")
            return False
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT 1")
            result = cursor.fetchone()
            cursor.close()
            return result is not None
        except Exception as e:
            logger.error(f"Connection test failed: {e}")
            return False
        finally:
            # Always return the connection so a failed test cannot leak it
            self.return_connection(conn)
def upsert_embeddings(self, items: List[EmbeddingItem]) -> Dict[str, int]:
"""
Insert or update embeddings in batch.
Args:
items: List of EmbeddingItem objects
Returns:
Statistics dict with counts of added/updated items
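
        Example (a sketch; key, model, and dimension are illustrative):
            item = EmbeddingItem(
                item_key="ABCD2345",
                item_type="journalArticle",
                title="Example paper",
                content="Abstract text...",
                content_hash="<sha256 of content>",
                embedding=[0.1] * 1536,  # must match the table's vector dimension
                embedding_model="text-embedding-3-small",
                embedding_provider="openai",
                metadata={"year": 2024},
            )
            stats = client.upsert_embeddings([item])
            # -> {"added": 1, "updated": 0, "errors": 0}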
"""
if not items:
return {"added": 0, "updated": 0, "errors": 0}
conn = self.get_connection()
stats = {"added": 0, "updated": 0, "errors": 0}
try:
cursor = conn.cursor()
# Prepare data for batch insert
values = []
for item in items:
values.append((
item.item_key,
item.item_type,
item.title,
item.content,
item.content_hash,
item.embedding,
item.embedding_model,
item.embedding_provider,
json.dumps(item.metadata),
item.content_type,
item.parent_item_key,
item.parent_attachment_key,
item.chunk_index,
item.chunk_total
))
# Use ON CONFLICT to handle upserts
upsert_query = """
INSERT INTO zotero_embeddings
(item_key, item_type, title, content, content_hash, embedding,
embedding_model, embedding_provider, metadata, content_type,
parent_item_key, parent_attachment_key, chunk_index, chunk_total)
VALUES %s
ON CONFLICT (item_key) DO UPDATE SET
item_type = EXCLUDED.item_type,
title = EXCLUDED.title,
content = EXCLUDED.content,
content_hash = EXCLUDED.content_hash,
embedding = EXCLUDED.embedding,
embedding_model = EXCLUDED.embedding_model,
embedding_provider = EXCLUDED.embedding_provider,
metadata = EXCLUDED.metadata,
content_type = EXCLUDED.content_type,
parent_item_key = EXCLUDED.parent_item_key,
parent_attachment_key = EXCLUDED.parent_attachment_key,
chunk_index = EXCLUDED.chunk_index,
chunk_total = EXCLUDED.chunk_total,
updated_at = CURRENT_TIMESTAMP
RETURNING (xmax = 0) AS inserted
"""
# Execute batch upsert
results = execute_values(
cursor, upsert_query, values,
template=None, page_size=100, fetch=True
)
# Count insertions vs updates
if results:
for result in results:
try:
if result and len(result) > 0:
if result[0]: # xmax = 0 means insert
stats["added"] += 1
else: # xmax > 0 means update
stats["updated"] += 1
except (IndexError, TypeError) as e:
logger.warning(f"Could not parse upsert result: {result}, error: {e}")
# Assume it was an update if we can't determine
stats["updated"] += 1
conn.commit()
cursor.close()
logger.info(f"Upserted {len(items)} embeddings: "
f"{stats['added']} added, {stats['updated']} updated")
except Exception as e:
conn.rollback()
stats["errors"] = len(items)
logger.error(f"Error upserting embeddings: {e}")
raise
finally:
self.return_connection(conn)
return stats
def search_similar(self,
query_vector: List[float],
limit: int = 10,
filters: Optional[Dict[str, Any]] = None,
similarity_threshold: float = 0.0,
similarity_metric: str = "cosine") -> List[SearchResult]:
"""
Perform vector similarity search with optional filters.
Args:
query_vector: Query embedding vector
limit: Maximum number of results
filters: Metadata filter conditions (JSONB queries)
            similarity_threshold: Minimum converted similarity score; most
                useful with 'cosine', where scores fall in [-1, 1]
            similarity_metric: Similarity metric ('cosine', 'l2', or 'inner_product')
Returns:
List of SearchResult objects with similarity scores
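
        Example (a sketch; embed() is a hypothetical helper returning a
        vector of the stored dimension):
            results = client.search_similar(
                query_vector=embed("neural networks"),
                limit=5,
                filters={"item_type": "journalArticle"},
                similarity_threshold=0.3,
            )
            for r in results:
                print(f"{r.similarity_score:.3f} {r.title}")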
"""
        if not isinstance(query_vector, list) or not query_vector:
            logger.warning("Empty or invalid query vector provided to search_similar")
            return []
        # An all-zero vector usually signals an upstream embedding failure
        if all(x == 0.0 for x in query_vector):
            logger.warning("Query vector contains all zeros - may produce poor results")
logger.debug(f"Performing similarity search with vector of length {len(query_vector)}, limit {limit}")
conn = self.get_connection()
try:
cursor = conn.cursor(cursor_factory=RealDictCursor)
            # Choose the pgvector operator for the metric. The operators return
            # distances (and <#> returns the *negated* inner product), so each
            # branch converts to a "higher is more similar" score.
            if similarity_metric == "cosine":
                similarity_op = "<=>"
                similarity_expr = "(1 - (embedding <=> %s::vector))"  # 1 - cosine distance
            elif similarity_metric == "l2":
                similarity_op = "<->"
                similarity_expr = "(-(embedding <-> %s::vector))"  # negative L2 distance
            else:  # inner_product
                similarity_op = "<#>"
                similarity_expr = "(-(embedding <#> %s::vector))"  # undo pgvector's negation
# Build WHERE clause for filters
where_conditions = []
# Start params with query_vector for similarity_expr in SELECT
params = [query_vector]
if filters:
for key, value in filters.items():
if key == "item_type":
where_conditions.append("item_type = %s")
params.append(value)
elif key == "embedding_model":
where_conditions.append("embedding_model = %s")
params.append(value)
elif key == "embedding_provider":
where_conditions.append("embedding_provider = %s")
params.append(value)
else:
# Handle JSONB metadata filters
if isinstance(value, str):
where_conditions.append("metadata->%s = %s")
params.extend([key, json.dumps(value)])
elif isinstance(value, (list, dict)):
where_conditions.append("metadata->%s @> %s")
params.extend([key, json.dumps(value)])
# Add similarity threshold filter - this adds another query_vector parameter
if similarity_threshold > 0.0:
where_conditions.append(f"{similarity_expr} >= %s")
params.extend([query_vector, similarity_threshold]) # WHERE clause needs its own query_vector
where_clause = ""
if where_conditions:
where_clause = "WHERE " + " AND ".join(where_conditions)
            # Build the complete query. psycopg2 fills %s placeholders in textual
            # order: the SELECT similarity expression, any WHERE filters (plus the
            # threshold's vector and value when threshold > 0), the ORDER BY
            # vector, and finally LIMIT -- which matches how params was built.
query = f"""
SELECT
item_key,
item_type,
title,
content,
metadata,
embedding_model,
embedding_provider,
{similarity_expr} AS similarity_score
FROM zotero_embeddings
{where_clause}
ORDER BY embedding {similarity_op} %s::vector
LIMIT %s
"""
# Add query vector for ORDER BY and limit to params
params.extend([query_vector, limit])
            logger.debug(f"SQL Query: {query}")
            logger.debug(f"Parameter count: {len(params)}")
            logger.debug(f"Query vector length: {len(query_vector)}")
cursor.execute(query, params)
results = cursor.fetchall()
logger.debug(f"Database query returned {len(results) if results else 0} rows")
# Convert to SearchResult objects
search_results = []
for i, row in enumerate(results):
try:
if not row:
logger.warning(f"Row {i} is None or empty, skipping")
continue
# Validate required fields
required_fields = ['item_key', 'similarity_score', 'title', 'content', 'metadata', 'embedding_model', 'embedding_provider']
missing_fields = [field for field in required_fields if field not in row]
if missing_fields:
logger.error(f"Row {i} missing required fields: {missing_fields}")
continue
result = SearchResult(
item_key=row['item_key'],
similarity_score=float(row['similarity_score']) if row['similarity_score'] is not None else 0.0,
title=row['title'] or '',
content=row['content'] or '',
metadata=row['metadata'] or {},
embedding_model=row['embedding_model'] or '',
embedding_provider=row['embedding_provider'] or ''
)
search_results.append(result)
except (KeyError, ValueError, TypeError) as e:
logger.error(f"Error processing search result row {i}: {e}")
logger.debug(f"Problematic row: {row}")
continue
except Exception as e:
logger.error(f"Unexpected error processing row {i}: {e}")
continue
cursor.close()
logger.debug(f"Successfully created {len(search_results)} SearchResult objects")
return search_results
except Exception as e:
logger.error(f"Error performing similarity search: {e}")
raise
finally:
self.return_connection(conn)
def delete_items(self, item_keys: List[str]) -> int:
"""
Delete items by Zotero keys.
Args:
item_keys: List of Zotero item keys to delete
Returns:
Number of items deleted
"""
if not item_keys:
return 0
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"DELETE FROM zotero_embeddings WHERE item_key = ANY(%s)",
(item_keys,)
)
deleted_count = cursor.rowcount
conn.commit()
cursor.close()
logger.info(f"Deleted {deleted_count} items")
return deleted_count
except Exception as e:
conn.rollback()
logger.error(f"Error deleting items: {e}")
raise
finally:
self.return_connection(conn)
def get_item_count(self) -> int:
"""Get total number of indexed items."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM zotero_embeddings")
result = cursor.fetchone()
cursor.close()
if result and len(result) > 0:
return result[0]
else:
logger.warning("Could not get item count, returning 0")
return 0
except Exception as e:
logger.error(f"Error getting item count: {e}")
return 0
finally:
self.return_connection(conn)
def check_item_exists(self, item_key: str, content_hash: Optional[str] = None) -> bool:
"""
Check if item exists, optionally with specific content hash.
Args:
item_key: Zotero item key
content_hash: Optional content hash to check for changes
Returns:
True if item exists (and content hash matches if provided)
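
        Example (a sketch; new_hash is computed by the caller):
            if not client.check_item_exists("ABCD2345", content_hash=new_hash):
                ...  # content is new or changed: re-embed and upsert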
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
if content_hash:
cursor.execute(
"SELECT 1 FROM zotero_embeddings WHERE item_key = %s AND content_hash = %s",
(item_key, content_hash)
)
else:
cursor.execute(
"SELECT 1 FROM zotero_embeddings WHERE item_key = %s",
(item_key,)
)
exists = cursor.fetchone() is not None
cursor.close()
return exists
except Exception as e:
logger.error(f"Error checking item existence: {e}")
return False
finally:
self.return_connection(conn)
def get_embedding_info(self, item_key: str) -> Optional[Dict[str, Any]]:
"""
Get embedding information for a specific item.
Args:
item_key: Zotero item key
Returns:
Dict with embedding info or None if not found
"""
conn = self.get_connection()
try:
cursor = conn.cursor(cursor_factory=RealDictCursor)
cursor.execute("""
SELECT item_key, item_type, title, content_hash,
embedding_model, embedding_provider, created_at, updated_at
FROM zotero_embeddings
WHERE item_key = %s
""", (item_key,))
result = cursor.fetchone()
cursor.close()
return dict(result) if result else None
except Exception as e:
logger.error(f"Error getting embedding info: {e}")
return None
finally:
self.return_connection(conn)
def get_database_status(self) -> Dict[str, Any]:
"""Get database status and statistics."""
conn = self.get_connection()
try:
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Basic statistics
cursor.execute("""
SELECT
COUNT(*) as total_items,
COUNT(DISTINCT item_type) as unique_types,
COUNT(DISTINCT embedding_model) as unique_models,
COUNT(DISTINCT embedding_provider) as unique_providers
FROM zotero_embeddings
""")
stats = cursor.fetchone()
# Model/provider breakdown
cursor.execute("""
SELECT embedding_provider, embedding_model, COUNT(*) as count
FROM zotero_embeddings
GROUP BY embedding_provider, embedding_model
ORDER BY count DESC
""")
model_breakdown = cursor.fetchall()
# Recent activity
cursor.execute("""
SELECT
DATE_TRUNC('day', created_at) as date,
COUNT(*) as items_added
FROM zotero_embeddings
WHERE created_at > CURRENT_DATE - INTERVAL '7 days'
GROUP BY DATE_TRUNC('day', created_at)
ORDER BY date DESC
""")
recent_activity = cursor.fetchall()
cursor.close()
return {
"total_items": stats['total_items'],
"unique_types": stats['unique_types'],
"unique_models": stats['unique_models'],
"unique_providers": stats['unique_providers'],
"model_breakdown": [dict(row) for row in model_breakdown],
"recent_activity": [dict(row) for row in recent_activity],
"connection_info": {
"host": self.host,
"port": self.port,
"database": self.database,
"pool_size": self.pool_size
}
}
except Exception as e:
logger.error(f"Error getting database status: {e}")
return {"error": str(e)}
finally:
self.return_connection(conn)
def vacuum_and_reindex(self) -> Dict[str, Any]:
"""Optimize database performance."""
conn = self.get_connection()
try:
# Note: VACUUM cannot run inside a transaction
conn.autocommit = True
cursor = conn.cursor()
# Vacuum and analyze
cursor.execute("VACUUM ANALYZE zotero_embeddings")
            # Reindex the vector index (name must match the deployed schema)
            cursor.execute("REINDEX INDEX idx_zotero_embedding_cosine")
cursor.close()
logger.info("Database optimization completed")
return {"status": "completed", "message": "Database optimized successfully"}
except Exception as e:
logger.error(f"Error during database optimization: {e}")
return {"status": "error", "message": str(e)}
finally:
conn.autocommit = False
self.return_connection(conn)
def create_vector_client(config: Dict[str, Any]) -> PostgreSQLVectorClient:
"""
Create a PostgreSQL vector client from configuration.
Args:
config: Database configuration dictionary
Returns:
Configured PostgreSQLVectorClient instance
"""
return PostgreSQLVectorClient(config)
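

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the library API. Connection
    # values are illustrative; the target database must have the pgvector
    # extension and the zotero_embeddings table installed.
    logging.basicConfig(level=logging.INFO)
    demo_config = {
        "host": "localhost",
        "port": 5432,
        "database": "zotero",
        "username": "zotero",
        "password": "secret",
    }
    client = create_vector_client(demo_config)
    try:
        print("connection ok:", client.test_connection())
        print("indexed items:", client.get_item_count())
    finally:
        client.close_pool()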