import json
import logging
import uuid
from datetime import date, datetime
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import (
ColumnInfo,
ColumnTypeName,
DataSourceFormat,
PrimaryKeyConstraint,
TableConstraint,
TableType,
)
from databricks.sdk.service.vectorsearch import (
    DeltaSyncVectorIndexSpecRequest,
    DirectAccessVectorIndexSpec,
    EmbeddingSourceColumn,
    EmbeddingVectorColumn,
    EndpointType,
    PipelineType,
    VectorIndexType,
)
from pydantic import BaseModel
from selfmemory.memory.utils import extract_json
from selfmemory.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class MemoryResult(BaseModel):
id: str | None = None
score: float | None = None
payload: dict | None = None
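# Payload keys that are stored in dedicated table columns or managed internally;
# update() skips these when building dynamic SET clauses from a payload.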
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
class Databricks(VectorStoreBase):
def __init__(
self,
workspace_url: str,
access_token: str | None = None,
client_id: str | None = None,
client_secret: str | None = None,
azure_client_id: str | None = None,
azure_client_secret: str | None = None,
        endpoint_name: str | None = None,
        catalog: str | None = None,
        schema: str | None = None,
        table_name: str | None = None,
collection_name: str = "mem0",
index_type: str = "DELTA_SYNC",
embedding_model_endpoint_name: str | None = None,
embedding_dimension: int = 1536,
endpoint_type: str = "STANDARD",
pipeline_type: str = "TRIGGERED",
warehouse_name: str | None = None,
query_type: str = "ANN",
):
"""
Initialize the Databricks Vector Search vector store.
Args:
workspace_url (str): Databricks workspace URL.
access_token (str, optional): Personal access token for authentication.
client_id (str, optional): Service principal client ID for authentication.
client_secret (str, optional): Service principal client secret for authentication.
azure_client_id (str, optional): Azure AD application client ID (for Azure Databricks).
azure_client_secret (str, optional): Azure AD application client secret (for Azure Databricks).
endpoint_name (str): Vector search endpoint name.
catalog (str): Unity Catalog catalog name.
schema (str): Unity Catalog schema name.
table_name (str): Source Delta table name.
            collection_name (str, optional): Vector search index name (default: "mem0").
index_type (str, optional): Index type, either "DELTA_SYNC" or "DIRECT_ACCESS" (default: "DELTA_SYNC").
embedding_model_endpoint_name (str, optional): Embedding model endpoint for Databricks-computed embeddings.
embedding_dimension (int, optional): Vector embedding dimensions (default: 1536).
endpoint_type (str, optional): Endpoint type, either "STANDARD" or "STORAGE_OPTIMIZED" (default: "STANDARD").
pipeline_type (str, optional): Sync pipeline type, either "TRIGGERED" or "CONTINUOUS" (default: "TRIGGERED").
warehouse_name (str, optional): Databricks SQL warehouse Name (if using SQL warehouse).
query_type (str, optional): Query type, either "ANN" or "HYBRID" (default: "ANN").
"""
# Basic identifiers
self.workspace_url = workspace_url
self.endpoint_name = endpoint_name
self.catalog = catalog
self.schema = schema
self.table_name = table_name
self.fully_qualified_table_name = (
f"{self.catalog}.{self.schema}.{self.table_name}"
)
self.index_name = collection_name
self.fully_qualified_index_name = (
f"{self.catalog}.{self.schema}.{self.index_name}"
)
# Configuration
        # Normalize string options to SDK enums so enum comparisons and request
        # serialization work as expected.
        self.index_type = (
            VectorIndexType(index_type) if isinstance(index_type, str) else index_type
        )
        self.embedding_model_endpoint_name = embedding_model_endpoint_name
        self.embedding_dimension = embedding_dimension
        self.endpoint_type = (
            EndpointType(endpoint_type) if isinstance(endpoint_type, str) else endpoint_type
        )
        self.pipeline_type = (
            PipelineType(pipeline_type) if isinstance(pipeline_type, str) else pipeline_type
        )
        self.query_type = query_type
# Schema
self.columns = [
ColumnInfo(
name="memory_id",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
nullable=False,
comment="Primary key",
position=0,
),
ColumnInfo(
name="hash",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="Hash of the memory content",
position=1,
),
ColumnInfo(
name="agent_id",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="ID of the agent",
position=2,
),
ColumnInfo(
name="run_id",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="ID of the run",
position=3,
),
ColumnInfo(
name="user_id",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="ID of the user",
position=4,
),
ColumnInfo(
name="memory",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="Memory content",
position=5,
),
ColumnInfo(
name="metadata",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="Additional metadata",
position=6,
),
ColumnInfo(
name="created_at",
type_name=ColumnTypeName.TIMESTAMP,
type_text="timestamp",
type_json='{"type":"timestamp"}',
comment="Creation timestamp",
position=7,
),
ColumnInfo(
name="updated_at",
type_name=ColumnTypeName.TIMESTAMP,
type_text="timestamp",
type_json='{"type":"timestamp"}',
comment="Last update timestamp",
position=8,
),
]
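        # DIRECT_ACCESS indexes store self-managed embedding vectors, so the source
        # schema needs an explicit embedding column; DELTA_SYNC instead computes
        # embeddings from the "memory" text column via the configured model endpoint.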
if self.index_type == VectorIndexType.DIRECT_ACCESS:
self.columns.append(
ColumnInfo(
name="embedding",
type_name=ColumnTypeName.ARRAY,
type_text="array<float>",
type_json='{"type":"array","element":"float","element_nullable":false}',
nullable=True,
comment="Embedding vector",
position=9,
)
)
self.column_names = [col.name for col in self.columns]
# Initialize Databricks workspace client
client_config = {}
if client_id and client_secret:
client_config.update(
{
"host": workspace_url,
"client_id": client_id,
"client_secret": client_secret,
}
)
elif azure_client_id and azure_client_secret:
client_config.update(
{
"host": workspace_url,
"azure_client_id": azure_client_id,
"azure_client_secret": azure_client_secret,
}
)
elif access_token:
client_config.update({"host": workspace_url, "token": access_token})
else:
# Try automatic authentication
client_config["host"] = workspace_url
try:
self.client = WorkspaceClient(**client_config)
logger.info("Initialized Databricks workspace client")
except Exception as e:
logger.error(f"Failed to initialize Databricks workspace client: {e}")
raise
# Get the warehouse ID by name
self.warehouse_id = next(
(w.id for w in self.client.warehouses.list() if w.name == warehouse_name),
None,
)
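        # warehouse_id remains None when no warehouse matches warehouse_name; the
        # SQL-backed operations (insert/update/delete/reset) need a valid warehouse.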
# Initialize endpoint (required in Databricks)
self._ensure_endpoint_exists()
# Check if index exists and create if needed
collections = self.list_cols()
if self.fully_qualified_index_name not in collections:
self.create_col()
def _ensure_endpoint_exists(self):
"""Ensure the vector search endpoint exists, create if it doesn't."""
try:
self.client.vector_search_endpoints.get_endpoint(
endpoint_name=self.endpoint_name
)
logger.info(f"Vector search endpoint '{self.endpoint_name}' already exists")
except Exception:
# Endpoint doesn't exist, create it
try:
logger.info(
f"Creating vector search endpoint '{self.endpoint_name}' with type '{self.endpoint_type}'"
)
self.client.vector_search_endpoints.create_endpoint_and_wait(
name=self.endpoint_name, endpoint_type=self.endpoint_type
)
logger.info(
f"Successfully created vector search endpoint '{self.endpoint_name}'"
)
except Exception as e:
logger.error(
f"Failed to create vector search endpoint '{self.endpoint_name}': {e}"
)
raise
def _ensure_source_table_exists(self):
"""Ensure the source Delta table exists with the proper schema."""
check = self.client.tables.exists(self.fully_qualified_table_name)
if check.table_exists:
logger.info(
f"Source table '{self.fully_qualified_table_name}' already exists"
)
else:
logger.info(
f"Source table '{self.fully_qualified_table_name}' does not exist, creating it..."
)
self.client.tables.create(
name=self.table_name,
catalog_name=self.catalog,
schema_name=self.schema,
table_type=TableType.MANAGED,
data_source_format=DataSourceFormat.DELTA,
storage_location=None, # Use default storage location
columns=self.columns,
properties={"delta.enableChangeDataFeed": "true"},
)
logger.info(
f"Successfully created source table '{self.fully_qualified_table_name}'"
)
            self.client.table_constraints.create(
                full_name_arg=self.fully_qualified_table_name,
                constraint=TableConstraint(
                    primary_key_constraint=PrimaryKeyConstraint(
                        name=f"pk_{self.table_name}",  # Name of the primary key constraint
                        child_columns=["memory_id"],  # Columns that make up the primary key
                    )
                ),
            )
logger.info(
f"Successfully created primary key constraint on 'memory_id' for table '{self.fully_qualified_table_name}'"
)
def create_col(self, name=None, vector_size=None, distance=None):
"""
Create a new collection (index).
Args:
            name (str, optional): Index name override (currently unused; the configured collection name is used).
            vector_size (int, optional): Vector dimension size. Defaults to the configured embedding dimension.
            distance (str, optional): Distance metric (not applicable to Databricks Vector Search).
Returns:
The index object.
"""
# Determine index configuration
embedding_dims = vector_size or self.embedding_dimension
embedding_source_columns = [
EmbeddingSourceColumn(
name="memory",
embedding_model_endpoint_name=self.embedding_model_endpoint_name,
)
]
logger.info(f"Creating vector search index '{self.fully_qualified_index_name}'")
# First, ensure the source Delta table exists
self._ensure_source_table_exists()
if self.index_type not in [
VectorIndexType.DELTA_SYNC,
VectorIndexType.DIRECT_ACCESS,
]:
raise ValueError(
"index_type must be either 'DELTA_SYNC' or 'DIRECT_ACCESS'"
)
try:
if self.index_type == VectorIndexType.DELTA_SYNC:
index = self.client.vector_search_indexes.create_index(
name=self.fully_qualified_index_name,
endpoint_name=self.endpoint_name,
primary_key="memory_id",
index_type=self.index_type,
delta_sync_index_spec=DeltaSyncVectorIndexSpecRequest(
source_table=self.fully_qualified_table_name,
pipeline_type=self.pipeline_type,
columns_to_sync=self.column_names,
embedding_source_columns=embedding_source_columns,
),
)
logger.info(
f"Successfully created vector search index '{self.fully_qualified_index_name}' with DELTA_SYNC type"
)
return index
if self.index_type == VectorIndexType.DIRECT_ACCESS:
index = self.client.vector_search_indexes.create_index(
name=self.fully_qualified_index_name,
endpoint_name=self.endpoint_name,
primary_key="memory_id",
index_type=self.index_type,
direct_access_index_spec=DirectAccessVectorIndexSpec(
embedding_source_columns=embedding_source_columns,
embedding_vector_columns=[
EmbeddingVectorColumn(
name="embedding", embedding_dimension=embedding_dims
)
],
),
)
logger.info(
f"Successfully created vector search index '{self.fully_qualified_index_name}' with DIRECT_ACCESS type"
)
return index
        except Exception as e:
            logger.error(
                f"Failed to create {self.index_type} index '{self.fully_qualified_index_name}': {e}"
            )
            raise
def _format_sql_value(self, v):
"""
Format a Python value into a safe SQL literal for Databricks.
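        A few illustrative mappings:
            None          -> NULL
            True          -> TRUE
            [1.0, 2.0]    -> array(1.0, 2.0)
            "it's"        -> 'it''s'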
"""
if v is None:
return "NULL"
if isinstance(v, bool):
return "TRUE" if v else "FALSE"
if isinstance(v, (int, float)):
return str(v)
if isinstance(v, (datetime, date)):
return f"'{v.isoformat()}'"
if isinstance(v, list):
# Render arrays (assume numeric or string elements)
elems = []
for x in v:
if x is None:
elems.append("NULL")
elif isinstance(x, (int, float)):
elems.append(str(x))
else:
s = str(x).replace("'", "''")
elems.append(f"'{s}'")
return f"array({', '.join(elems)})"
if isinstance(v, dict):
try:
s = json.dumps(v)
except Exception:
s = str(v)
s = s.replace("'", "''")
return f"'{s}'"
# Fallback: treat as string
s = str(v).replace("'", "''")
return f"'{s}'"
def insert(self, vectors: list, payloads: list = None, ids: list = None):
"""
Insert vectors into the index.
Args:
vectors (List[List[float]]): List of vectors to insert.
payloads (List[Dict], optional): List of payloads corresponding to vectors.
ids (List[str], optional): List of IDs corresponding to vectors.
"""
# Determine the number of items to process
num_items = len(payloads) if payloads else len(vectors) if vectors else 0
value_tuples = []
for i in range(num_items):
values = []
for col in self.columns:
if col.name == "memory_id":
val = ids[i] if ids and i < len(ids) else str(uuid.uuid4())
elif col.name == "embedding":
val = vectors[i] if vectors and i < len(vectors) else []
elif col.name == "memory":
val = (
payloads[i].get("data")
if payloads and i < len(payloads)
else None
)
else:
val = (
payloads[i].get(col.name)
if payloads and i < len(payloads)
else None
)
values.append(val)
formatted = [self._format_sql_value(v) for v in values]
value_tuples.append(f"({', '.join(formatted)})")
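        # Emit a single multi-row INSERT so the whole batch is written in one statement.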
insert_sql = f"INSERT INTO {self.fully_qualified_table_name} ({', '.join(self.column_names)}) VALUES {', '.join(value_tuples)}"
# Execute the insert
try:
response = self.client.statement_execution.execute_statement(
statement=insert_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
)
if response.status.state.value == "SUCCEEDED":
logger.info(
f"Successfully inserted {num_items} items into Delta table {self.fully_qualified_table_name}"
)
return
logger.error(f"Failed to insert items: {response.status.error}")
raise Exception(f"Insert operation failed: {response.status.error}")
except Exception as e:
logger.error(f"Insert operation failed: {e}")
raise
def search(
self, query: str, vectors: list, limit: int = 5, filters: dict = None
) -> list[MemoryResult]:
"""
Search for similar vectors or text using the Databricks Vector Search index.
Args:
query (str): Search query text (for text-based search).
vectors (list): Query vector (for vector-based search).
limit (int): Maximum number of results.
filters (dict): Filters to apply.
Returns:
List of MemoryResult objects.
"""
try:
filters_json = json.dumps(filters) if filters else None
# Choose query type
if self.index_type == VectorIndexType.DELTA_SYNC and query:
# Text-based search
sdk_results = self.client.vector_search_indexes.query_index(
index_name=self.fully_qualified_index_name,
columns=self.column_names,
query_text=query,
num_results=limit,
query_type=self.query_type,
filters_json=filters_json,
)
elif self.index_type == VectorIndexType.DIRECT_ACCESS and vectors:
# Vector-based search
sdk_results = self.client.vector_search_indexes.query_index(
index_name=self.fully_qualified_index_name,
columns=self.column_names,
query_vector=vectors,
num_results=limit,
query_type=self.query_type,
filters_json=filters_json,
)
else:
raise ValueError(
"Must provide query text for DELTA_SYNC or vectors for DIRECT_ACCESS."
)
# Parse results
result_data = (
sdk_results.result if hasattr(sdk_results, "result") else sdk_results
)
data_array = (
result_data.data_array
if getattr(result_data, "data_array", None)
else []
)
memory_results = []
for row in data_array:
# Map columns to values
row_dict = (
dict(zip(self.column_names, row, strict=False))
if isinstance(row, (list, tuple))
else row
)
score = row_dict.get("score") or (
row[-1]
if isinstance(row, (list, tuple))
and len(row) > len(self.column_names)
else None
)
payload = {k: row_dict.get(k) for k in self.column_names}
payload["data"] = payload.get("memory", "")
memory_id = row_dict.get("memory_id") or row_dict.get("id")
memory_results.append(
MemoryResult(id=memory_id, score=score, payload=payload)
)
return memory_results
except Exception as e:
logger.error(f"Search failed: {e}")
raise
def delete(self, vector_id):
"""
Delete a vector by ID from the Delta table.
Args:
vector_id (str): ID of the vector to delete.
"""
try:
logger.info(
f"Deleting vector with ID {vector_id} from Delta table {self.fully_qualified_table_name}"
)
delete_sql = f"DELETE FROM {self.fully_qualified_table_name} WHERE memory_id = '{vector_id}'"
response = self.client.statement_execution.execute_statement(
statement=delete_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
)
if response.status.state.value == "SUCCEEDED":
logger.info(f"Successfully deleted vector with ID {vector_id}")
else:
logger.error(
f"Failed to delete vector with ID {vector_id}: {response.status.error}"
)
except Exception as e:
logger.error(f"Delete operation failed for vector ID {vector_id}: {e}")
raise
def update(self, vector_id=None, vector=None, payload=None):
"""
Update a vector and its payload in the Delta table.
Args:
vector_id (str): ID of the vector to update.
vector (list, optional): New vector values.
payload (dict, optional): New payload data.
"""
update_sql = f"UPDATE {self.fully_qualified_table_name} SET "
set_clauses = []
if not vector_id:
logger.error("vector_id is required for update operation")
return
        if vector is not None:
            if not isinstance(vector, list):
                logger.error("vector must be a list of float values")
                return
            set_clauses.append(f"embedding = {self._format_sql_value(vector)}")
        if payload:
            if not isinstance(payload, dict):
                logger.error("payload must be a dictionary")
                return
            for key, value in payload.items():
                if key not in excluded_keys:
                    set_clauses.append(f"{key} = {self._format_sql_value(value)}")
if not set_clauses:
logger.error("No fields to update")
return
update_sql += ", ".join(set_clauses)
update_sql += f" WHERE memory_id = '{vector_id}'"
try:
logger.info(
f"Updating vector with ID {vector_id} in Delta table {self.fully_qualified_table_name}"
)
response = self.client.statement_execution.execute_statement(
statement=update_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
)
if response.status.state.value == "SUCCEEDED":
logger.info(f"Successfully updated vector with ID {vector_id}")
else:
logger.error(
f"Failed to update vector with ID {vector_id}: {response.status.error}"
)
except Exception as e:
logger.error(f"Update operation failed for vector ID {vector_id}: {e}")
raise
def get(self, vector_id) -> MemoryResult:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
MemoryResult: The retrieved vector.
"""
try:
# Use query with ID filter to retrieve the specific vector
filters = {"memory_id": vector_id}
filters_json = json.dumps(filters)
results = self.client.vector_search_indexes.query_index(
index_name=self.fully_qualified_index_name,
columns=self.column_names,
query_text=" ", # Empty query, rely on filters
num_results=1,
query_type=self.query_type,
filters_json=filters_json,
)
# Process results
result_data = results.result if hasattr(results, "result") else results
data_array = (
result_data.data_array if hasattr(result_data, "data_array") else []
)
if not data_array:
raise KeyError(f"Vector with ID {vector_id} not found")
result = data_array[0]
            columns = (
                [col.name for col in results.manifest.columns]
                if results.manifest and results.manifest.columns
                else self.column_names
            )
row_data = dict(zip(columns, result, strict=False))
# Build payload following the standard schema
payload = {
"hash": row_data.get("hash", "unknown"),
"data": row_data.get("memory", row_data.get("data", "unknown")),
"created_at": row_data.get("created_at"),
}
# Add updated_at if available
if "updated_at" in row_data:
payload["updated_at"] = row_data.get("updated_at")
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in row_data:
payload[field] = row_data[field]
# Add metadata
if "metadata" in row_data and row_data.get("metadata"):
try:
metadata = json.loads(extract_json(row_data["metadata"]))
payload.update(metadata)
except (json.JSONDecodeError, TypeError):
logger.warning(
f"Failed to parse metadata: {row_data.get('metadata')}"
)
            memory_id = row_data.get("memory_id", vector_id)
return MemoryResult(id=memory_id, payload=payload)
except Exception as e:
logger.error(f"Failed to get vector with ID {vector_id}: {e}")
raise
def list_cols(self) -> list[str]:
"""
List all collections (indexes).
Returns:
List of index names.
"""
try:
indexes = self.client.vector_search_indexes.list_indexes(
endpoint_name=self.endpoint_name
)
return [idx.name for idx in indexes]
except Exception as e:
logger.error(f"Failed to list collections: {e}")
raise
def delete_col(self):
"""
Delete the current collection (index).
"""
try:
# Try fully qualified first
try:
self.client.vector_search_indexes.delete_index(
index_name=self.fully_qualified_index_name
)
logger.info(
f"Successfully deleted index '{self.fully_qualified_index_name}'"
)
except Exception:
self.client.vector_search_indexes.delete_index(
index_name=self.index_name
)
logger.info(
f"Successfully deleted index '{self.index_name}' (short name)"
)
except Exception as e:
logger.error(f"Failed to delete index '{self.index_name}': {e}")
raise
def col_info(self, name=None):
"""
Get information about a collection (index).
Args:
name (str, optional): Index name. Defaults to current index.
Returns:
Dict: Index information.
"""
try:
            index_name = name or self.fully_qualified_index_name
index = self.client.vector_search_indexes.get_index(index_name=index_name)
return {"name": index.name, "fields": self.columns}
except Exception as e:
logger.error(
f"Failed to get info for index '{name or self.index_name}': {e}"
)
raise
def list(self, filters: dict = None, limit: int = None) -> list[MemoryResult]:
"""
        List recently created memories from the vector store.
Args:
filters (dict, optional): Filters to apply.
limit (int, optional): Maximum number of results.
Returns:
            A list containing a single list of MemoryResult objects.
"""
try:
filters_json = json.dumps(filters) if filters else None
num_results = limit or 100
columns = self.column_names
sdk_results = self.client.vector_search_indexes.query_index(
index_name=self.fully_qualified_index_name,
columns=columns,
query_text=" ",
num_results=num_results,
query_type=self.query_type,
filters_json=filters_json,
)
result_data = (
sdk_results.result if hasattr(sdk_results, "result") else sdk_results
)
data_array = (
result_data.data_array if hasattr(result_data, "data_array") else []
)
memory_results = []
for row in data_array:
row_dict = (
dict(zip(columns, row, strict=False))
if isinstance(row, (list, tuple))
else row
)
payload = {k: row_dict.get(k) for k in columns}
# Parse metadata if present
if "metadata" in payload and payload["metadata"]:
try:
payload.update(json.loads(payload["metadata"]))
except Exception:
pass
memory_id = row_dict.get("memory_id") or row_dict.get("id")
payload["data"] = payload["memory"]
memory_results.append(MemoryResult(id=memory_id, payload=payload))
return [memory_results]
except Exception as e:
logger.error(f"Failed to list memories: {e}")
return []
def reset(self):
"""Reset the vector search index and underlying source table.
This will attempt to delete the existing index (both fully qualified and short name forms
for robustness), drop the backing Delta table, recreate the table with the expected schema,
and finally recreate the index. Use with caution as all existing data will be removed.
"""
fq_index = self.fully_qualified_index_name
logger.warning(f"Resetting Databricks vector search index '{fq_index}'...")
try:
# Try deleting via fully qualified name first
try:
self.client.vector_search_indexes.delete_index(index_name=fq_index)
logger.info(f"Deleted index '{fq_index}'")
except Exception as e_fq:
logger.debug(
f"Failed deleting fully qualified index name '{fq_index}': {e_fq}. Trying short name..."
)
try:
# Fallback to existing helper which may use short name
self.delete_col()
except Exception as e_short:
logger.debug(
f"Failed deleting short index name '{self.index_name}': {e_short}"
)
# Drop the backing table (if it exists)
try:
drop_sql = f"DROP TABLE IF EXISTS {self.fully_qualified_table_name}"
resp = self.client.statement_execution.execute_statement(
statement=drop_sql,
warehouse_id=self.warehouse_id,
wait_timeout="30s",
)
                state = getattr(resp.status, "state", None)
                if state is not None and state.value == "SUCCEEDED":
logger.info(f"Dropped table '{self.fully_qualified_table_name}'")
else:
logger.warning(
f"Attempted to drop table '{self.fully_qualified_table_name}' but state was {getattr(resp.status, 'state', 'UNKNOWN')}: {getattr(resp.status, 'error', None)}"
)
except Exception as e_drop:
logger.warning(
f"Failed to drop table '{self.fully_qualified_table_name}': {e_drop}"
)
# Recreate table & index
self._ensure_source_table_exists()
self.create_col()
logger.info(f"Successfully reset index '{fq_index}'")
except Exception as e:
logger.error(f"Error resetting index '{fq_index}': {e}")
raise
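

# Minimal usage sketch (illustrative only): the workspace URL, endpoint, catalog,
# schema, table, warehouse, and embedding endpoint names below are placeholders,
# not values shipped with this module; adapt them to your workspace.
if __name__ == "__main__":
    store = Databricks(
        workspace_url="https://example.cloud.databricks.com",  # placeholder URL
        endpoint_name="memory-endpoint",  # placeholder vector search endpoint
        catalog="main",
        schema="memories",
        table_name="memory_source",
        collection_name="mem0",
        embedding_model_endpoint_name="databricks-gte-large-en",  # placeholder endpoint
        warehouse_name="Serverless Starter Warehouse",  # placeholder warehouse
    )
    # With the default DELTA_SYNC index, embeddings are computed server-side,
    # so no client-side vectors are required for insert or search.
    store.insert(
        vectors=[],
        payloads=[{"data": "User prefers morning meetings", "user_id": "u-123"}],
    )
    print(store.search(query="meeting preferences", vectors=[], limit=3))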