"""
Memory Server - Main persistent memory server
Integrates ChromaDB + Embeddings + Hierarchical Compression
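
Typical usage (sketch; assumes an ``EmbeddingService`` whose ``encode`` method
returns one embedding vector per input text):

    server = MemoryServer(embedding_service=embedding_service)
    doc_id = server.store("some code or note", metadata={"file": "app.py"})
    hits = server.retrieve("related query", top_k=5)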
"""
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Optional
logger = logging.getLogger(__name__)
class MemoryServer:
"""
Persistent memory server with RAG and hierarchical compression
Stores and retrieves knowledge efficiently
"""
def __init__(
self,
embedding_service,
db_path: str = "./chroma_db",
collection_name: str = "continuo_memory",
hierarchical_compressor=None,
):
"""
Args:
embedding_service: EmbeddingService instance
db_path: Path for ChromaDB persistence
collection_name: Collection name
hierarchical_compressor: HierarchicalCompressor instance
"""
self.embedding_service = embedding_service
self.db_path = db_path
self.collection_name = collection_name
self.compressor = hierarchical_compressor
self._client = None
self._collection = None # type: ignore[assignment]
self._initialize_db()
def _initialize_db(self):
"""Initialize ChromaDB"""
try:
import chromadb
from chromadb.config import Settings
self._client = chromadb.PersistentClient(
path=self.db_path,
settings=Settings(
anonymized_telemetry=False,
allow_reset=True,
),
)
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Persistent context storage"},
)
        except ImportError as exc:
            raise ImportError(
                "chromadb not installed. Run: pip install chromadb"
            ) from exc
def store(
self,
text: str,
metadata: Optional[dict[str, Any]] = None,
doc_id: Optional[str] = None,
level: str = "N0",
) -> str:
"""
Store text/code in persistent memory
Args:
text: Content to store
metadata: Additional metadata
doc_id: Custom ID (generates UUID if None)
level: Hierarchical level (N0, N1, N2)
Returns:
Stored document ID
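
        Example (illustrative; ``server`` is an initialized MemoryServer):
            doc_id = server.store(
                "def add(a, b): return a + b",
                metadata={"file": "math_utils.py"},
                level="N0",
            )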
"""
if not text or not text.strip():
raise ValueError("Empty text cannot be stored")
doc_id = doc_id or str(uuid.uuid4())
        metadata = dict(metadata) if metadata else {}
# Add control metadata
metadata.update(
{
"timestamp": datetime.utcnow().isoformat(),
"level": level,
"access_count": 0,
}
)
# Generate embedding
embedding = self.embedding_service.encode(text)[0]
        # Store in ChromaDB
        if self._collection is None:
            raise RuntimeError("Memory collection is not initialized")
        self._collection.add(
            ids=[doc_id],
            documents=[text],
            embeddings=[embedding],
            metadatas=[metadata],
        )
        return doc_id
def retrieve(
self,
query: str,
top_k: int = 5,
level_filter: Optional[str] = None,
metadata_filter: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
"""
Retrieve similar documents via semantic search
Args:
query: Search query
top_k: Number of results
level_filter: Filter by level (N0, N1, N2)
metadata_filter: Additional filters
Returns:
Dict with documents, metadatas, distances, ids
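
        Example (illustrative):
            hits = server.retrieve("function that adds numbers", top_k=3,
                                   level_filter="N0")
            for doc, meta in zip(hits["documents"], hits["metadatas"]):
                print(meta.get("file"), doc[:80])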
"""
# Generate query embedding
query_embedding = self.embedding_service.encode(query)[0]
        # Prepare filters (ChromaDB expects a single condition or an explicit $and)
        conditions: list[dict[str, Any]] = []
        if level_filter:
            conditions.append({"level": level_filter})
        if metadata_filter:
            conditions.extend({key: value} for key, value in metadata_filter.items())
        where_filter: Optional[dict[str, Any]] = None
        if len(conditions) == 1:
            where_filter = conditions[0]
        elif len(conditions) > 1:
            where_filter = {"$and": conditions}
# Search in ChromaDB
if self._collection is None:
return {
"documents": [],
"metadatas": [],
"distances": [],
"ids": [],
}
results = self._collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
            where=where_filter,
)
# Increment access counter
if results["ids"] and results["ids"][0]:
self._increment_access_counts(results["ids"][0])
# Apply hierarchical compression if configured
if self.compressor and results["documents"]:
results = self._apply_hierarchical_retrieval(results, query_embedding)
return {
"documents": results["documents"][0] if results["documents"] else [],
"metadatas": results["metadatas"][0] if results["metadatas"] else [],
"distances": results["distances"][0] if results["distances"] else [],
"ids": results["ids"][0] if results["ids"] else [],
}
def _increment_access_counts(self, doc_ids: list[str]):
"""Increment access counter for documents"""
if self._collection is None:
return
for doc_id in doc_ids:
try:
result = self._collection.get(ids=[doc_id])
if result["metadatas"]:
metadata = result["metadatas"][0]
metadata["access_count"] = metadata.get("access_count", 0) + 1
self._collection.update(
ids=[doc_id],
metadatas=[metadata],
)
            except Exception as exc:
                logger.warning(
                    "Failed to increment access_count for %s: %s", doc_id, exc
                )
def _apply_hierarchical_retrieval(
self, results: dict[str, Any], query_embedding: list[float]
) -> dict[str, Any]:
"""
Apply hierarchical compression to results
Prioritizes most relevant items in working set
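        Assumes the configured compressor exposes
        ``select_working_set(items, query_embedding)`` and returns a subset of
        ``items`` (same dict shape, possibly reordered or truncated).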
"""
if not results["ids"] or not results["ids"][0]:
return results
# Rebuild items with embeddings
items = []
for i, doc_id in enumerate(results["ids"][0]):
items.append(
{
"id": doc_id,
"text": results["documents"][0][i],
"metadata": results["metadatas"][0][i],
"distance": results["distances"][0][i],
"embedding": query_embedding, # Simplification
}
)
# Select optimized working set
working_set = self.compressor.select_working_set(items, query_embedding)
# Rebuild response format
return {
"documents": [[item["text"] for item in working_set]],
"metadatas": [[item["metadata"] for item in working_set]],
"distances": [[item["distance"] for item in working_set]],
"ids": [[item["id"] for item in working_set]],
}
def get_stats(self) -> dict[str, Any]:
"""Return memory statistics"""
if self._collection is None:
return {
"total_documents": 0,
"by_level": {
"N0_chunks": 0,
"N1_micro_summaries": 0,
"N2_meta_summaries": 0,
},
"db_path": self.db_path,
"collection": self.collection_name,
}
total_docs = self._collection.count()
# Count by level
n0_count = len(self._collection.get(where={"level": "N0"}).get("ids", []))
n1_count = len(self._collection.get(where={"level": "N1"}).get("ids", []))
n2_count = len(self._collection.get(where={"level": "N2"}).get("ids", []))
return {
"total_documents": total_docs,
"by_level": {
"N0_chunks": n0_count,
"N1_micro_summaries": n1_count,
"N2_meta_summaries": n2_count,
},
"db_path": self.db_path,
"collection": self.collection_name,
}
    def reset(self):
        """Reset the entire memory (use with caution!)"""
        if self._client is None:
            return
        self._client.delete_collection(self.collection_name)
        self._collection = self._client.create_collection(
            name=self.collection_name,
            metadata={"description": "Persistent context storage"},
        )
        logger.warning("Memory reset!")
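

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal end-to-end run, assuming chromadb is installed. _StubEmbeddingService
# is a hypothetical stand-in for the project's real EmbeddingService; the server
# only requires an ``encode`` method that returns one vector per input text.
if __name__ == "__main__":
    import hashlib

    class _StubEmbeddingService:
        """Deterministic toy embedder: hashes text into a small float vector."""

        def encode(self, texts):
            if isinstance(texts, str):
                texts = [texts]
            vectors = []
            for text in texts:
                digest = hashlib.sha256(text.encode("utf-8")).digest()
                vectors.append([byte / 255.0 for byte in digest[:32]])
            return vectors

    server = MemoryServer(embedding_service=_StubEmbeddingService(), db_path="./demo_db")
    stored_id = server.store("def add(a, b): return a + b", metadata={"file": "math.py"})
    hits = server.retrieve("function that adds two numbers", top_k=1)
    print("stored:", stored_id)
    print("retrieved:", hits["documents"])
    print("stats:", server.get_stats())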