import json
import logging
import uuid
from typing import Any
import numpy as np
from pydantic import BaseModel
# The Cassandra driver is an optional dependency: fail at import time with an
# actionable message instead of a bare "No module named 'cassandra'".
try:
    from cassandra.auth import PlainTextAuthProvider
    from cassandra.cluster import Cluster
except ImportError as exc:
    # Chain the original error so the real cause (missing package, broken
    # native extension, ...) stays visible in the traceback.
    raise ImportError(
        "Apache Cassandra vector store requires cassandra-driver. "
        "Please install it using 'pip install cassandra-driver'"
    ) from exc
from selfmemory.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
    """One record returned by the vector store's read operations.

    Every field accepts ``None``.
    """

    # Identifier of the stored row.
    id: str | None

    # Score attached by ``search`` (it stores ``1 - cosine similarity`` here);
    # ``get`` and ``list`` pass ``None``.
    score: float | None

    # Payload decoded from the JSON text stored next to the vector.
    payload: dict | None
class CassandraDB(VectorStoreBase):
    """Vector store that keeps embeddings in an Apache Cassandra table.

    Each record is one row ``(id text, vector list<float>, payload text)``
    where ``payload`` holds a JSON document.  Similarity search is computed
    client-side with NumPy over a full table scan, so the cost of ``search``
    grows linearly with the number of stored rows.
    """

    def __init__(
        self,
        contact_points: list[str],
        port: int = 9042,
        username: str | None = None,
        password: str | None = None,
        keyspace: str = "mem0",
        collection_name: str = "memories",
        embedding_model_dims: int = 1536,
        secure_connect_bundle: str | None = None,
        protocol_version: int = 4,
        load_balancing_policy: Any | None = None,
    ):
        """
        Initialize the Apache Cassandra vector store.

        Connects to the cluster, then creates the keyspace and the table if
        they do not exist yet.

        Args:
            contact_points (List[str]): List of contact point addresses (e.g., ['127.0.0.1'])
            port (int): Cassandra port (default: 9042)
            username (str, optional): Database username
            password (str, optional): Database password
            keyspace (str): Keyspace name (default: "mem0")
            collection_name (str): Table name (default: "memories")
            embedding_model_dims (int): Dimension of the embedding vector (default: 1536)
            secure_connect_bundle (str, optional): Path to secure connect bundle for Astra DB
            protocol_version (int): CQL protocol version (default: 4)
            load_balancing_policy (Any, optional): Custom load balancing policy
                (only applied when connecting through ``contact_points``)

        Raises:
            Exception: Whatever the driver raises when connecting or when
                creating the keyspace/table fails (logged, then re-raised).
        """
        self.contact_points = contact_points
        self.port = port
        self.username = username
        self.password = password
        self.keyspace = keyspace
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.secure_connect_bundle = secure_connect_bundle
        self.protocol_version = protocol_version
        self.load_balancing_policy = load_balancing_policy

        # Prepared statements keyed by their CQL text, so each distinct
        # statement is prepared on the server once instead of on every call.
        self._prepared_statements: dict[str, Any] = {}

        # Initialize connection
        self.cluster = None
        self.session = None
        self._setup_connection()

        # Create keyspace and table if they don't exist
        self._create_keyspace()
        self._create_table()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _prepare(self, query: str) -> Any:
        """Return the prepared statement for ``query``, preparing it on first use."""
        statement = self._prepared_statements.get(query)
        if statement is None:
            statement = self.session.prepare(query)
            self._prepared_statements[query] = statement
        return statement

    @staticmethod
    def _decode_payload(raw_payload: str | None) -> Any:
        """Decode a stored JSON payload; an empty or NULL column yields ``{}``.

        Raises:
            json.JSONDecodeError: If the stored text is not valid JSON.
        """
        return json.loads(raw_payload) if raw_payload else {}

    @staticmethod
    def _matches_filters(payload: Any, filters: dict) -> bool:
        """Return True when every ``filters`` item equals the payload's value for that key."""
        # A payload that decoded to something other than an object cannot
        # satisfy a key/value filter.
        if not isinstance(payload, dict):
            return False
        return all(payload.get(key) == value for key, value in filters.items())

    def _create_table_if_missing(self, table_name: str) -> None:
        """Issue the ``CREATE TABLE IF NOT EXISTS`` statement for ``table_name``."""
        # Vector stored as list<float> and payload as text (JSON)
        query = f"""
            CREATE TABLE IF NOT EXISTS {self.keyspace}.{table_name} (
                id text PRIMARY KEY,
                vector list<float>,
                payload text
            )
        """
        self.session.execute(query)

    def _setup_connection(self):
        """Setup Cassandra cluster connection.

        Raises:
            Exception: Any driver error; the partially created cluster is shut
                down before the error is re-raised.
        """
        try:
            # Setup authentication
            auth_provider = None
            if self.username and self.password:
                auth_provider = PlainTextAuthProvider(
                    username=self.username, password=self.password
                )

            if self.secure_connect_bundle:
                # Connect to Astra DB using secure connect bundle
                self.cluster = Cluster(
                    cloud={"secure_connect_bundle": self.secure_connect_bundle},
                    auth_provider=auth_provider,
                    protocol_version=self.protocol_version,
                )
            else:
                # Connect to standard Cassandra cluster
                cluster_kwargs = {
                    "contact_points": self.contact_points,
                    "port": self.port,
                    "protocol_version": self.protocol_version,
                }
                if auth_provider:
                    cluster_kwargs["auth_provider"] = auth_provider
                if self.load_balancing_policy:
                    cluster_kwargs["load_balancing_policy"] = self.load_balancing_policy
                self.cluster = Cluster(**cluster_kwargs)

            self.session = self.cluster.connect()
            logger.info("Successfully connected to Cassandra cluster")
        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}")
            # Do not leak a half-initialised cluster when connect() fails.
            if self.cluster is not None:
                try:
                    self.cluster.shutdown()
                except Exception as shutdown_error:
                    logger.debug(f"Cluster shutdown after failed connect: {shutdown_error}")
                self.cluster = None
                self.session = None
            raise

    def _create_keyspace(self):
        """Create keyspace if it doesn't exist and make it the session default."""
        try:
            # Use SimpleStrategy for single datacenter, NetworkTopologyStrategy for production
            query = f"""
                CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
                WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
            """
            self.session.execute(query)
            self.session.set_keyspace(self.keyspace)
            logger.info(f"Keyspace '{self.keyspace}' is ready")
        except Exception as e:
            logger.error(f"Failed to create keyspace: {e}")
            raise

    def _create_table(self):
        """Create table with vector column if it doesn't exist."""
        try:
            self._create_table_if_missing(self.collection_name)
            logger.info(f"Table '{self.collection_name}' is ready")
        except Exception as e:
            logger.error(f"Failed to create table: {e}")
            raise

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_col(
        self,
        name: str | None = None,
        vector_size: int | None = None,
        distance: str = "cosine",
    ):
        """
        Create a new collection (table in Cassandra).

        Args:
            name (str, optional): Collection name (uses self.collection_name if not provided)
            vector_size (int, optional): Vector dimension (uses self.embedding_model_dims
                if not provided); only reported in the log, the column is an
                unsized ``list<float>``
            distance (str): Distance metric; accepted but not used by this method

        Raises:
            Exception: Any driver error (logged, then re-raised).
        """
        table_name = name or self.collection_name
        dims = vector_size or self.embedding_model_dims
        try:
            self._create_table_if_missing(table_name)
            logger.info(
                f"Created collection '{table_name}' with vector dimension {dims}"
            )
        except Exception as e:
            logger.error(f"Failed to create collection: {e}")
            raise

    def insert(
        self,
        vectors: list[list[float]],
        payloads: list[dict] | None = None,
        ids: list[str] | None = None,
    ):
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert
            payloads (List[Dict], optional): List of payloads corresponding to vectors
                (defaults to an empty payload per vector)
            ids (List[str], optional): List of IDs corresponding to vectors
                (defaults to a random UUID4 string per vector)

        Raises:
            ValueError: If ``vectors``, ``payloads`` and ``ids`` differ in length.
            Exception: Any driver or JSON serialization error (logged, then re-raised).
        """
        logger.info(
            f"Inserting {len(vectors)} vectors into collection {self.collection_name}"
        )
        if payloads is None:
            payloads = [{} for _ in vectors]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        try:
            # A silent zip() truncation here would drop vectors without a
            # trace, so mismatched inputs are rejected up front.
            if not (len(vectors) == len(payloads) == len(ids)):
                raise ValueError(
                    "vectors, payloads and ids must have the same length "
                    f"(got {len(vectors)}, {len(payloads)}, {len(ids)})"
                )
            if not vectors:
                return

            query = f"""
                INSERT INTO {self.keyspace}.{self.collection_name} (id, vector, payload)
                VALUES (?, ?, ?)
            """
            prepared = self._prepare(query)
            for vector, payload, vec_id in zip(vectors, payloads, ids, strict=True):
                self.session.execute(prepared, (vec_id, vector, json.dumps(payload)))
        except Exception as e:
            logger.error(f"Failed to insert vectors: {e}")
            raise

    def search(
        self,
        query: str,
        vectors: list[float],
        limit: int = 5,
        filters: dict | None = None,
    ) -> list[OutputData]:
        """
        Search for similar vectors using cosine similarity.

        The whole table is read and scored in Python.  Rows without a vector
        are ignored, and so are rows whose payload is not valid JSON.

        Args:
            query (str): Query string (not used in vector search)
            vectors (List[float]): Query vector
            limit (int): Number of results to return; values <= 0 yield no results
            filters (Dict, optional): Exact-match key/value filters applied to
                the decoded payload

        Returns:
            List[OutputData]: Results ordered from closest to farthest.  ``score``
            is the cosine *distance* (``1 - cosine similarity``), so lower is closer.

        Raises:
            Exception: Any driver or NumPy error, e.g. a stored vector whose
                length differs from the query vector (logged, then re-raised).
        """
        try:
            if limit <= 0:
                return []

            # Fetch all vectors (in production, you'd want pagination or filtering)
            query_cql = f"""
                SELECT id, vector, payload
                FROM {self.keyspace}.{self.collection_name}
            """
            rows = self.session.execute(query_cql)

            query_vec = np.asarray(vectors, dtype=float)
            # Loop-invariant: the query norm does not depend on the row.
            query_norm = float(np.linalg.norm(query_vec))

            candidates = []
            for row in rows:
                if not row.vector:
                    continue

                # Filter first so rejected rows never pay for the vector maths.
                if filters:
                    try:
                        payload = self._decode_payload(row.payload)
                    except json.JSONDecodeError:
                        continue
                    if not self._matches_filters(payload, filters):
                        continue

                vec = np.asarray(row.vector, dtype=float)
                denominator = query_norm * float(np.linalg.norm(vec))
                # Cosine similarity is undefined for a zero-length vector;
                # treat it as "no similarity" instead of producing NaN, which
                # would make the sort order below meaningless.
                if denominator:
                    similarity = float(np.dot(query_vec, vec)) / denominator
                else:
                    similarity = 0.0
                candidates.append((1.0 - similarity, row.id, row.payload))

            # Sort by distance (stable, so ties keep scan order) and limit
            candidates.sort(key=lambda candidate: candidate[0])

            results = []
            for distance, row_id, raw_payload in candidates:
                if len(results) >= limit:
                    break
                try:
                    payload = self._decode_payload(raw_payload)
                except json.JSONDecodeError:
                    # Same policy as the filter branch: a corrupt row is
                    # skipped rather than failing the whole search.
                    logger.warning(f"Skipping row {row_id}: payload is not valid JSON")
                    continue
                results.append(
                    OutputData(id=row_id, score=float(distance), payload=payload)
                )
            return results
        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise

    def delete(self, vector_id: str):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete

        Raises:
            Exception: Any driver error (logged, then re-raised).
        """
        try:
            query = f"""
                DELETE FROM {self.keyspace}.{self.collection_name}
                WHERE id = ?
            """
            self.session.execute(self._prepare(query), (vector_id,))
            logger.info(f"Deleted vector with id: {vector_id}")
        except Exception as e:
            logger.error(f"Failed to delete vector: {e}")
            raise

    def update(
        self,
        vector_id: str,
        vector: list[float] | None = None,
        payload: dict | None = None,
    ):
        """
        Update a vector and its payload.

        Only the fields that are not ``None`` are written, in a single
        ``UPDATE`` statement.  When both are ``None`` nothing is sent.

        Args:
            vector_id (str): ID of the vector to update
            vector (List[float], optional): Updated vector
            payload (Dict, optional): Updated payload

        Raises:
            Exception: Any driver or JSON serialization error (logged, then re-raised).
        """
        try:
            assignments = []
            values = []
            if vector is not None:
                assignments.append("vector = ?")
                values.append(vector)
            if payload is not None:
                assignments.append("payload = ?")
                values.append(json.dumps(payload))

            if not assignments:
                logger.debug(f"Nothing to update for vector with id: {vector_id}")
                return

            # One statement for both columns: a single round trip, and the
            # row is never left with a new vector but a stale payload.
            set_clause = ", ".join(assignments)
            query = f"""
                UPDATE {self.keyspace}.{self.collection_name}
                SET {set_clause}
                WHERE id = ?
            """
            values.append(vector_id)
            self.session.execute(self._prepare(query), tuple(values))
            logger.info(f"Updated vector with id: {vector_id}")
        except Exception as e:
            logger.error(f"Failed to update vector: {e}")
            raise

    def get(self, vector_id: str) -> OutputData | None:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve

        Returns:
            OutputData: Retrieved record (``score`` is always ``None``), or
            ``None`` if the ID is not found or the lookup fails (failures are
            logged, not raised).
        """
        try:
            query = f"""
                SELECT id, vector, payload
                FROM {self.keyspace}.{self.collection_name}
                WHERE id = ?
            """
            row = self.session.execute(self._prepare(query), (vector_id,)).one()
            if not row:
                return None
            return OutputData(
                id=row.id,
                score=None,
                payload=self._decode_payload(row.payload),
            )
        except Exception as e:
            logger.error(f"Failed to get vector: {e}")
            return None

    def list_cols(self) -> list[str]:
        """
        List all collections (tables in the keyspace).

        Returns:
            List[str]: List of collection names; empty if the lookup fails
            (failures are logged, not raised).
        """
        try:
            # Bind the keyspace as a parameter instead of splicing it into a
            # CQL string literal.
            query = """
                SELECT table_name
                FROM system_schema.tables
                WHERE keyspace_name = %s
            """
            rows = self.session.execute(query, (self.keyspace,))
            return [row.table_name for row in rows]
        except Exception as e:
            logger.error(f"Failed to list collections: {e}")
            return []

    def delete_col(self):
        """Delete the collection (table).

        Raises:
            Exception: Any driver error (logged, then re-raised).
        """
        try:
            query = f"""
                DROP TABLE IF EXISTS {self.keyspace}.{self.collection_name}
            """
            self.session.execute(query)
            # Statements prepared against the dropped table are stale.
            self._prepared_statements.clear()
            logger.info(f"Deleted collection '{self.collection_name}'")
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")
            raise

    def col_info(self) -> dict[str, Any]:
        """
        Get information about the collection.

        Returns:
            Dict[str, Any]: ``name``, ``keyspace``, ``count`` and ``vector_dims``;
            an empty dict if the lookup fails (failures are logged, not raised).
        """
        try:
            # COUNT(*) is an exact count obtained by scanning the table.
            query = f"""
                SELECT COUNT(*) as count
                FROM {self.keyspace}.{self.collection_name}
            """
            row = self.session.execute(query).one()
            count = row.count if row else 0
            return {
                "name": self.collection_name,
                "keyspace": self.keyspace,
                "count": count,
                "vector_dims": self.embedding_model_dims,
            }
        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            return {}

    def list(
        self, filters: dict | None = None, limit: int = 100
    ) -> list[list[OutputData]]:
        """
        List all vectors in the collection.

        Args:
            filters (Dict, optional): Exact-match key/value filters applied to
                the decoded payload
            limit (int): Maximum number of records to return

        Returns:
            List[List[OutputData]]: A single-element list wrapping the matching
            records (``score`` is always ``None``).  ``[[]]`` is returned when
            ``limit`` is not positive or when the listing fails (failures are
            logged, not raised).
        """
        try:
            limit = int(limit)
            if limit <= 0:
                return [[]]

            # The vector column is not part of the result, so it is not fetched.
            query = f"""
                SELECT id, payload
                FROM {self.keyspace}.{self.collection_name}
            """
            if not filters:
                # Without filters the server can enforce the limit.  With
                # filters the limit must be applied *after* filtering,
                # otherwise matching rows beyond the first `limit` rows of the
                # table would never be returned.
                query += f" LIMIT {limit}"

            rows = self.session.execute(query)
            results = []
            for row in rows:
                try:
                    payload = self._decode_payload(row.payload)
                except json.JSONDecodeError:
                    # One corrupt row must not blank the whole listing.
                    logger.warning(f"Skipping row {row.id}: payload is not valid JSON")
                    continue

                # Apply filters if provided
                if filters and not self._matches_filters(payload, filters):
                    continue

                results.append(OutputData(id=row.id, score=None, payload=payload))
                if len(results) >= limit:
                    break
            return [results]
        except Exception as e:
            logger.error(f"Failed to list vectors: {e}")
            return [[]]

    def reset(self):
        """Reset the collection by truncating it.

        Raises:
            Exception: Any driver error (logged, then re-raised).
        """
        try:
            logger.warning(f"Resetting collection {self.collection_name}...")
            query = f"""
                TRUNCATE TABLE {self.keyspace}.{self.collection_name}
            """
            self.session.execute(query)
            logger.info(f"Collection '{self.collection_name}' has been reset")
        except Exception as e:
            logger.error(f"Failed to reset collection: {e}")
            raise

    def __del__(self):
        """Close the cluster connection when the object is deleted."""
        try:
            # getattr: __init__ may have failed before `cluster` was assigned.
            cluster = getattr(self, "cluster", None)
            if cluster:
                cluster.shutdown()
                logger.info("Cassandra cluster connection closed")
        except Exception:
            # Deliberate best effort: a finalizer must never raise, and at
            # interpreter shutdown logging itself may already be torn down.
            pass