import json
import logging
import uuid
from datetime import datetime
import numpy as np
import pytz
import valkey
from pydantic import BaseModel
from valkey.exceptions import ResponseError
from selfmemory.memory.utils import extract_json
from selfmemory.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
# Default fields for the Valkey index
DEFAULT_FIELDS = [
{"name": "memory_id", "type": "tag"},
{"name": "hash", "type": "tag"},
{"name": "agent_id", "type": "tag"},
{"name": "run_id", "type": "tag"},
{"name": "user_id", "type": "tag"},
{
"name": "memory",
"type": "tag",
}, # Using TAG instead of TEXT for Valkey compatibility
{
"name": "metadata",
"type": "tag",
}, # Using TAG instead of TEXT for Valkey compatibility
{"name": "created_at", "type": "numeric"},
{"name": "updated_at", "type": "numeric"},
{
"name": "embedding",
"type": "vector",
"attrs": {
"distance_metric": "cosine",
"algorithm": "flat",
"datatype": "float32",
},
},
]
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
class OutputData(BaseModel):
id: str
    score: float | None = None  # None when a result carries no similarity score (e.g., direct get)
payload: dict
class ValkeyDB(VectorStoreBase):
def __init__(
self,
valkey_url: str,
collection_name: str,
embedding_model_dims: int,
timezone: str = "UTC",
index_type: str = "hnsw",
hnsw_m: int = 16,
hnsw_ef_construction: int = 200,
hnsw_ef_runtime: int = 10,
):
"""
Initialize the Valkey vector store.
Args:
valkey_url (str): Valkey URL.
collection_name (str): Collection name.
embedding_model_dims (int): Embedding model dimensions.
timezone (str, optional): Timezone for timestamps. Defaults to "UTC".
index_type (str, optional): Index type ('hnsw' or 'flat'). Defaults to "hnsw".
hnsw_m (int, optional): HNSW M parameter (connections per node). Defaults to 16.
hnsw_ef_construction (int, optional): HNSW ef_construction parameter. Defaults to 200.
hnsw_ef_runtime (int, optional): HNSW ef_runtime parameter. Defaults to 10.
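        Example (a minimal sketch; the URL and values shown are illustrative,
        not defaults):
            >>> store = ValkeyDB(
            ...     valkey_url="valkey://localhost:6379",
            ...     collection_name="memories",
            ...     embedding_model_dims=384,
            ...     index_type="hnsw",
            ... )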
"""
self.embedding_model_dims = embedding_model_dims
self.collection_name = collection_name
self.prefix = f"mem0:{collection_name}"
self.timezone = timezone
self.index_type = index_type.lower()
self.hnsw_m = hnsw_m
self.hnsw_ef_construction = hnsw_ef_construction
self.hnsw_ef_runtime = hnsw_ef_runtime
# Validate index type
if self.index_type not in ["hnsw", "flat"]:
raise ValueError(
f"Invalid index_type: {index_type}. Must be 'hnsw' or 'flat'"
)
# Connect to Valkey
try:
self.client = valkey.from_url(valkey_url)
logger.debug(f"Successfully connected to Valkey at {valkey_url}")
except Exception as e:
logger.exception(f"Failed to connect to Valkey at {valkey_url}: {e}")
raise
# Create the index schema
self._create_index(embedding_model_dims)
def _build_index_schema(
self, collection_name, embedding_dims, distance_metric, prefix
):
"""
Build the FT.CREATE command for index creation.
Args:
collection_name (str): Name of the collection/index
embedding_dims (int): Vector embedding dimensions
distance_metric (str): Distance metric (e.g., "COSINE", "L2", "IP")
prefix (str): Key prefix for the index
Returns:
list: Complete FT.CREATE command as list of arguments
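        Example (FLAT index over 4-dim embeddings; the TAG/NUMERIC fields are
        elided with "..." for brevity):
            ["FT.CREATE", "memories", "ON", "HASH", "PREFIX", "1", "mem0:memories",
             "SCHEMA", "memory_id", "TAG", ..., "embedding", "VECTOR", "FLAT", "6",
             "TYPE", "FLOAT32", "DIM", "4", "DISTANCE_METRIC", "COSINE"]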
"""
# Build the vector field configuration based on index type
if self.index_type == "hnsw":
vector_config = [
"embedding",
"VECTOR",
"HNSW",
"12", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric, M, m, EF_CONSTRUCTION, ef_construction, EF_RUNTIME, ef_runtime
"TYPE",
"FLOAT32",
"DIM",
str(embedding_dims),
"DISTANCE_METRIC",
distance_metric,
"M",
str(self.hnsw_m),
"EF_CONSTRUCTION",
str(self.hnsw_ef_construction),
"EF_RUNTIME",
str(self.hnsw_ef_runtime),
]
elif self.index_type == "flat":
vector_config = [
"embedding",
"VECTOR",
"FLAT",
"6", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric
"TYPE",
"FLOAT32",
"DIM",
str(embedding_dims),
"DISTANCE_METRIC",
distance_metric,
]
else:
# This should never happen due to constructor validation, but be defensive
raise ValueError(
f"Unsupported index_type: {self.index_type}. Must be 'hnsw' or 'flat'"
)
# Build the complete command (comma is default separator for TAG fields)
cmd = [
"FT.CREATE",
collection_name,
"ON",
"HASH",
"PREFIX",
"1",
prefix,
"SCHEMA",
"memory_id",
"TAG",
"hash",
"TAG",
"agent_id",
"TAG",
"run_id",
"TAG",
"user_id",
"TAG",
"memory",
"TAG",
"metadata",
"TAG",
"created_at",
"NUMERIC",
"updated_at",
"NUMERIC",
] + vector_config
return cmd
def _create_index(self, embedding_model_dims):
"""
Create the search index with the specified schema.
Args:
embedding_model_dims (int): Dimensions for the vector embeddings.
Raises:
ValueError: If the search module is not available.
Exception: For other errors during index creation.
"""
# Check if the search module is available
try:
# Try to execute a search command
self.client.execute_command("FT._LIST")
except ResponseError as e:
if "unknown command" in str(e).lower():
                raise ValueError(
                    "Valkey search module is not available. Please ensure Valkey is running with the search module enabled. "
                    "The search module can be loaded using the --loadmodule option with the valkey-search library. "
                    "For installation and setup instructions, refer to the Valkey Search documentation."
                ) from e
logger.exception(f"Error checking search module: {e}")
raise
# Check if the index already exists
try:
self.client.ft(self.collection_name).info()
return
except ResponseError as e:
if "not found" not in str(e).lower():
logger.exception(f"Error checking index existence: {e}")
raise
# Build and execute the index creation command
cmd = self._build_index_schema(
self.collection_name,
embedding_model_dims,
"COSINE", # Fixed distance metric for initialization
self.prefix,
)
try:
self.client.execute_command(*cmd)
logger.info(
f"Successfully created {self.index_type.upper()} index {self.collection_name}"
)
except Exception as e:
logger.exception(f"Error creating index {self.collection_name}: {e}")
raise
def create_col(self, name=None, vector_size=None, distance=None):
"""
Create a new collection (index) in Valkey.
Args:
name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.
Returns:
The created index object.
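        Example (hypothetical collection name; passing ``name`` also switches
        the instance over to the new collection):
            >>> store.create_col(name="memories_v2", vector_size=384)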
"""
# Use provided parameters or fall back to instance attributes
collection_name = name or self.collection_name
embedding_dims = vector_size or self.embedding_model_dims
distance_metric = distance or "COSINE"
prefix = f"mem0:{collection_name}"
# Try to drop the index if it exists (cleanup before creation)
self._drop_index(collection_name, log_level="silent")
# Build and execute the index creation command
cmd = self._build_index_schema(
collection_name,
embedding_dims,
distance_metric, # Configurable distance metric
prefix,
)
try:
self.client.execute_command(*cmd)
logger.info(
f"Successfully created {self.index_type.upper()} index {collection_name}"
)
# Update instance attributes if creating a new collection
if name:
self.collection_name = collection_name
self.prefix = prefix
return self.client.ft(collection_name)
except Exception as e:
logger.exception(f"Error creating collection {collection_name}: {e}")
raise
def insert(self, vectors: list, payloads: list = None, ids: list = None):
"""
Insert vectors and their payloads into the index.
Args:
vectors (list): List of vectors to insert.
payloads (list, optional): List of payloads corresponding to the vectors.
ids (list, optional): List of IDs for the vectors.
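        Example (illustrative values; extra payload keys are serialized into
        the ``metadata`` field):
            >>> store.insert(
            ...     vectors=[[0.1] * store.embedding_model_dims],
            ...     payloads=[{"data": "User likes green tea", "user_id": "u1"}],
            ...     ids=["mem-001"],
            ... )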
"""
        # Default to empty payloads and generated IDs so zip() below does not fail on None
        if payloads is None:
            payloads = [{} for _ in vectors]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        for vector, payload, id in zip(vectors, payloads, ids, strict=False):
            try:
                # Create the key for the hash
                key = f"{self.prefix}:{id}"
                # A missing "data" or "hash" field falls back to a default below
# Ensure created_at is present
if "created_at" not in payload:
payload["created_at"] = datetime.now(
pytz.timezone(self.timezone)
).isoformat()
# Prepare the hash data
hash_data = {
"memory_id": id,
"hash": payload.get(
"hash", f"hash_{id}"
), # Use a default hash if not provided
"memory": payload.get(
"data", f"data_{id}"
), # Use a default data if not provided
"created_at": int(
datetime.fromisoformat(payload["created_at"]).timestamp()
),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
hash_data[field] = payload[field]
# Add metadata
hash_data["metadata"] = json.dumps(
{k: v for k, v in payload.items() if k not in excluded_keys}
)
# Store in Valkey
self.client.hset(key, mapping=hash_data)
logger.debug(f"Successfully inserted vector with ID {id}")
except KeyError as e:
logger.error(
f"Error inserting vector with ID {id}: Missing required field {e}"
)
except Exception as e:
logger.exception(f"Error inserting vector with ID {id}: {e}")
raise
def _build_search_query(self, knn_part, filters=None):
"""
Build a search query string with filters.
Args:
knn_part (str): The KNN part of the query.
filters (dict, optional): Filters to apply to the search. Each key-value pair
becomes a tag filter (@key:{value}). None values are ignored.
Values are used as-is (no validation) - wildcards, lists, etc. are
passed through literally to Valkey search. Multiple filters are
combined with AND logic (space-separated).
Returns:
str: The complete search query string in format "filter_expr =>[KNN...]"
or "*=>[KNN...]" if no valid filters.
"""
# No filters, just use the KNN search
        if not filters or not any(value is not None for value in filters.values()):
return f"*=>{knn_part}"
# Build filter expression
filter_parts = []
for key, value in filters.items():
if value is not None:
# Use the correct filter syntax for Valkey
filter_parts.append(f"@{key}:{{{value}}}")
# No valid filter parts
if not filter_parts:
return f"*=>{knn_part}"
# Combine filter parts with proper syntax
filter_expr = " ".join(filter_parts)
return f"{filter_expr} =>{knn_part}"
def _execute_search(self, query, params):
"""
Execute a search query.
Args:
query (str): The search query to execute.
params (dict): The query parameters.
Returns:
The search results.
"""
try:
return self.client.ft(self.collection_name).search(
query, query_params=params
)
except ResponseError as e:
logger.error(f"Search failed with query '{query}': {e}")
raise
def _process_search_results(self, results):
"""
Process search results into OutputData objects.
Args:
results: The search results from Valkey.
Returns:
list: List of OutputData objects.
"""
memory_results = []
for doc in results.docs:
# Extract the score
score = float(doc.vector_score) if hasattr(doc, "vector_score") else None
# Create the payload
payload = {
"hash": doc.hash,
"data": doc.memory,
"created_at": self._format_timestamp(
int(doc.created_at), self.timezone
),
}
# Add updated_at if available
if hasattr(doc, "updated_at"):
payload["updated_at"] = self._format_timestamp(
int(doc.updated_at), self.timezone
)
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if hasattr(doc, field):
payload[field] = getattr(doc, field)
# Add metadata
if hasattr(doc, "metadata"):
try:
metadata = json.loads(extract_json(doc.metadata))
payload.update(metadata)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Failed to parse metadata: {e}")
# Create the result
memory_results.append(
OutputData(id=doc.memory_id, score=score, payload=payload)
)
return memory_results
def search(
self,
query: str,
vectors: list,
limit: int = 5,
filters: dict = None,
ef_runtime: int = None,
):
"""
Search for similar vectors in the index.
Args:
            query (str): The search query text (unused; retrieval is vector-only).
vectors (list): The vector to search for.
limit (int, optional): Maximum number of results to return. Defaults to 5.
filters (dict, optional): Filters to apply to the search. Defaults to None.
ef_runtime (int, optional): HNSW ef_runtime parameter for this query. Only used with HNSW index. Defaults to None.
Returns:
list: List of OutputData objects.
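        Example (illustrative query vector and filter):
            >>> hits = store.search(
            ...     query="tea preferences",
            ...     vectors=[0.1] * store.embedding_model_dims,
            ...     limit=3,
            ...     filters={"user_id": "u1"},
            ... )
            >>> [(hit.id, hit.score) for hit in hits]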
"""
# Convert the vector to bytes
vector_bytes = np.array(vectors, dtype=np.float32).tobytes()
# Build the KNN part with optional EF_RUNTIME for HNSW
if self.index_type == "hnsw" and ef_runtime is not None:
knn_part = f"[KNN {limit} @embedding $vec_param EF_RUNTIME {ef_runtime} AS vector_score]"
else:
# For FLAT indexes or when ef_runtime is None, use basic KNN
knn_part = f"[KNN {limit} @embedding $vec_param AS vector_score]"
# Build the complete query
q = self._build_search_query(knn_part, filters)
# Log the query for debugging (only in debug mode)
logger.debug(f"Valkey search query: {q}")
# Set up the query parameters
params = {"vec_param": vector_bytes}
# Execute the search
results = self._execute_search(q, params)
# Process the results
return self._process_search_results(results)
def delete(self, vector_id):
"""
Delete a vector from the index.
Args:
vector_id (str): ID of the vector to delete.
"""
try:
key = f"{self.prefix}:{vector_id}"
self.client.delete(key)
logger.debug(f"Successfully deleted vector with ID {vector_id}")
except Exception as e:
logger.exception(f"Error deleting vector with ID {vector_id}: {e}")
raise
def update(self, vector_id=None, vector=None, payload=None):
"""
Update a vector in the index.
Args:
vector_id (str): ID of the vector to update.
vector (list, optional): New vector data.
payload (dict, optional): New payload data.
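        Example (illustrative; with the HSET semantics used here, omitting
        ``vector`` leaves the stored embedding unchanged):
            >>> store.update(
            ...     vector_id="mem-001",
            ...     payload={"data": "User prefers oolong tea", "user_id": "u1"},
            ... )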
"""
        try:
            key = f"{self.prefix}:{vector_id}"
            # Guard against a missing payload; absent fields fall back to defaults below
            payload = payload or {}
# Ensure created_at is present
if "created_at" not in payload:
payload["created_at"] = datetime.now(
pytz.timezone(self.timezone)
).isoformat()
# Prepare the hash data
hash_data = {
"memory_id": vector_id,
"hash": payload.get(
"hash", f"hash_{vector_id}"
), # Use a default hash if not provided
"memory": payload.get(
"data", f"data_{vector_id}"
), # Use a default data if not provided
"created_at": int(
datetime.fromisoformat(payload["created_at"]).timestamp()
),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}
# Add updated_at if available
if "updated_at" in payload:
hash_data["updated_at"] = int(
datetime.fromisoformat(payload["updated_at"]).timestamp()
)
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
hash_data[field] = payload[field]
# Add metadata
hash_data["metadata"] = json.dumps(
{k: v for k, v in payload.items() if k not in excluded_keys}
)
# Update in Valkey
self.client.hset(key, mapping=hash_data)
logger.debug(f"Successfully updated vector with ID {vector_id}")
except KeyError as e:
logger.error(
f"Error updating vector with ID {vector_id}: Missing required field {e}"
)
except Exception as e:
logger.exception(f"Error updating vector with ID {vector_id}: {e}")
raise
def _format_timestamp(self, timestamp, timezone=None):
"""
Format a timestamp with the specified timezone.
Args:
timestamp (int): The timestamp to format.
timezone (str, optional): The timezone to use. Defaults to UTC.
Returns:
str: The formatted timestamp.
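        Example:
            >>> self._format_timestamp(0, "UTC")
            '1970-01-01T00:00:00.000000+00:00'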
"""
# Use UTC as default timezone if not specified
tz = pytz.timezone(timezone or "UTC")
return datetime.fromtimestamp(timestamp, tz=tz).isoformat(
timespec="microseconds"
)
def _process_document_fields(self, result, vector_id):
"""
Process document fields from a Valkey hash result.
Args:
result (dict): The hash result from Valkey.
vector_id (str): The vector ID.
        Returns:
            tuple: ``(payload, memory_id)``, where ``payload`` (dict) is the
                processed payload and ``memory_id`` (str) is the memory ID.
"""
# Create the payload with error handling
payload = {}
# Convert bytes to string for text fields
for k in result:
if k not in ["embedding"]:
if isinstance(result[k], bytes):
try:
result[k] = result[k].decode("utf-8")
except UnicodeDecodeError:
# If decoding fails, keep the bytes
pass
# Add required fields with error handling
for field in ["hash", "memory", "created_at"]:
if field in result:
if field == "created_at":
try:
payload[field] = self._format_timestamp(
int(result[field]), self.timezone
)
except (ValueError, TypeError):
payload[field] = result[field]
else:
payload[field] = result[field]
else:
# Use default values for missing fields
if field == "hash" or field == "memory":
payload[field] = "unknown"
elif field == "created_at":
payload[field] = self._format_timestamp(
int(datetime.now(tz=pytz.timezone(self.timezone)).timestamp()),
self.timezone,
)
# Rename memory to data for consistency
if "memory" in payload:
payload["data"] = payload.pop("memory")
# Add updated_at if available
if "updated_at" in result:
try:
payload["updated_at"] = self._format_timestamp(
int(result["updated_at"]), self.timezone
)
except (ValueError, TypeError):
payload["updated_at"] = result["updated_at"]
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in result:
payload[field] = result[field]
# Add metadata
if "metadata" in result:
try:
metadata = json.loads(extract_json(result["metadata"]))
payload.update(metadata)
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse metadata: {result.get('metadata')}")
# Use memory_id from result if available, otherwise use vector_id
memory_id = result.get("memory_id", vector_id)
return payload, memory_id
def _convert_bytes(self, data):
"""Convert bytes data back to string"""
if isinstance(data, bytes):
try:
return data.decode("utf-8")
except UnicodeDecodeError:
return data
if isinstance(data, dict):
return {
self._convert_bytes(key): self._convert_bytes(value)
for key, value in data.items()
}
if isinstance(data, list):
return [self._convert_bytes(item) for item in data]
if isinstance(data, tuple):
return tuple(self._convert_bytes(item) for item in data)
return data
def get(self, vector_id):
"""
Get a vector by ID.
Args:
vector_id (str): ID of the vector to get.
Returns:
OutputData: The retrieved vector.
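        Raises:
            KeyError: If no vector with the given ID exists.
        Example (illustrative ID):
            >>> item = store.get("mem-001")
            >>> item.id, item.payload.get("data")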
"""
try:
key = f"{self.prefix}:{vector_id}"
result = self.client.hgetall(key)
if not result:
raise KeyError(f"Vector with ID {vector_id} not found")
# Convert bytes keys/values to strings
result = self._convert_bytes(result)
logger.debug(f"Retrieved result keys: {result.keys()}")
# Process the document fields
payload, memory_id = self._process_document_fields(result, vector_id)
return OutputData(id=memory_id, payload=payload, score=0.0)
except KeyError:
raise
except Exception as e:
logger.exception(f"Error getting vector with ID {vector_id}: {e}")
raise
def list_cols(self):
"""
List all collections (indices) in Valkey.
Returns:
list: List of collection names.
"""
try:
# Use the FT._LIST command to list all indices
return self.client.execute_command("FT._LIST")
except Exception as e:
logger.exception(f"Error listing collections: {e}")
raise
def _drop_index(self, collection_name, log_level="error"):
"""
Drop an index by name using the documented FT.DROPINDEX command.
Args:
collection_name (str): Name of the index to drop.
log_level (str): Logging level for missing index ("silent", "info", "error").
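        Returns:
            bool: True if the index was dropped, False if it did not exist.
        Example (silent cleanup before re-creating an index):
            >>> self._drop_index("memories_v2", log_level="silent")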
"""
try:
self.client.execute_command("FT.DROPINDEX", collection_name)
logger.info(f"Successfully deleted index {collection_name}")
return True
except ResponseError as e:
if "Unknown index name" in str(e):
# Index doesn't exist - handle based on context
if log_level == "silent":
pass # No logging in situations where this is expected such as initial index creation
elif log_level == "info":
logger.info(
f"Index {collection_name} doesn't exist, skipping deletion"
)
return False
# Real error - always log and raise
logger.error(f"Error deleting index {collection_name}: {e}")
raise
except Exception as e:
# Non-ResponseError exceptions - always log and raise
logger.error(f"Error deleting index {collection_name}: {e}")
raise
def delete_col(self):
"""
Delete the current collection (index).
"""
return self._drop_index(self.collection_name, log_level="info")
def col_info(self, name=None):
"""
Get information about a collection (index).
Args:
name (str, optional): Name of the collection. Defaults to None, which uses the current collection_name.
Returns:
dict: Information about the collection.
"""
try:
collection_name = name or self.collection_name
return self.client.ft(collection_name).info()
except Exception as e:
logger.exception(
f"Error getting collection info for {collection_name}: {e}"
)
raise
def reset(self):
"""
Reset the index by deleting and recreating it.
"""
try:
collection_name = self.collection_name
logger.warning(f"Resetting index {collection_name}...")
# Delete the index
self.delete_col()
# Recreate the index
self._create_index(self.embedding_model_dims)
return True
except Exception as e:
logger.exception(f"Error resetting index {self.collection_name}: {e}")
raise
def _build_list_query(self, filters=None):
"""
Build a query for listing vectors.
Args:
filters (dict, optional): Filters to apply to the list. Each key-value pair
becomes a tag filter (@key:{value}). None values are ignored.
Values are used as-is (no validation) - wildcards, lists, etc. are
passed through literally to Valkey search.
Returns:
str: The query string. Returns "*" if no valid filters provided.
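        Example (multiple filters are ANDed with a space separator):
            >>> self._build_list_query({"user_id": "u1", "agent_id": "a1"})
            '@user_id:{u1} @agent_id:{a1}'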
"""
# Default query
q = "*"
# Add filters if provided
        if filters and any(value is not None for value in filters.values()):
filter_conditions = []
for key, value in filters.items():
if value is not None:
filter_conditions.append(f"@{key}:{{{value}}}")
if filter_conditions:
q = " ".join(filter_conditions)
return q
def list(self, filters: dict = None, limit: int = None) -> list:
"""
        List recently created memories from the vector store.
Args:
filters (dict, optional): Filters to apply to the list. Each key-value pair
becomes a tag filter (@key:{value}). None values are ignored.
Values are used as-is without validation - wildcards, special characters,
lists, etc. are passed through literally to Valkey search.
Multiple filters are combined with AND logic.
limit (int, optional): Maximum number of results to return. Defaults to 1000
if not specified.
Returns:
list: Nested list format [[MemoryResult(), ...]] matching Redis implementation.
Each MemoryResult contains id and payload with hash, data, timestamps, etc.
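        Example (illustrative filter; note the nested-list return shape):
            >>> [results] = store.list(filters={"user_id": "u1"}, limit=10)
            >>> [r.id for r in results]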
"""
try:
            # Valkey search only supports KNN queries, so emulate listing with a
            # zero-vector search and a large K; the similarity scores this produces
            # are meaningless and are dropped when building the results below
            dummy_vector = [0.0] * self.embedding_model_dims
search_limit = limit if limit is not None else 1000 # Large default
# Use the existing search method which handles filters properly
search_results = self.search(
"", dummy_vector, limit=search_limit, filters=filters
)
# Convert search results to list format (match Redis format)
class MemoryResult:
def __init__(self, id: str, payload: dict, score: float = None):
self.id = id
self.payload = payload
self.score = score
memory_results = []
for result in search_results:
# Create payload in the expected format
payload = {
"hash": result.payload.get("hash", ""),
"data": result.payload.get("data", ""),
"created_at": result.payload.get("created_at"),
"updated_at": result.payload.get("updated_at"),
}
# Add metadata (exclude system fields)
for key, value in result.payload.items():
if key not in ["data", "hash", "created_at", "updated_at"]:
payload[key] = value
# Create MemoryResult object (matching Redis format)
memory_results.append(MemoryResult(id=result.id, payload=payload))
# Return nested list format like Redis
return [memory_results]
except Exception as e:
logger.exception(f"Error in list method: {e}")
return [[]] # Return empty result on error