Memory MCP Server

import asyncio
import json
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, cast

import aiofiles
from thefuzz import fuzz

from ..exceptions import EntityNotFoundError, FileAccessError
from ..interfaces import (
    BatchOperation,
    BatchOperationType,
    BatchResult,
    Entity,
    KnowledgeGraph,
    Relation,
    SearchOptions,
)
from .base import Backend


@dataclass
class SearchResult:
    entity: Entity
    score: float


class ReentrantLock:
    def __init__(self):
        self._lock = asyncio.Lock()
        self._owner = None
        self._count = 0

    async def acquire(self):
        current = asyncio.current_task()
        if self._owner == current:
            self._count += 1
            return
        await self._lock.acquire()
        self._owner = current
        self._count = 1

    def release(self):
        current = asyncio.current_task()
        if self._owner != current:
            raise RuntimeError("Lock not owned by current task")
        self._count -= 1
        if self._count == 0:
            self._owner = None
            self._lock.release()

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_val, tb):
        self.release()


class JsonlBackend(Backend):
    def __init__(self, memory_path: Path, cache_ttl: int = 60):
        self.memory_path = memory_path
        self.cache_ttl = cache_ttl
        self._cache: Optional[KnowledgeGraph] = None
        self._cache_timestamp: float = 0.0
        self._cache_file_mtime: float = 0.0
        self._dirty = False
        self._write_lock = ReentrantLock()
        self._lock = asyncio.Lock()
        # Transaction support: when a transaction is active, we work on separate copies.
        self._transaction_cache: Optional[KnowledgeGraph] = None
        self._transaction_indices: Optional[Dict[str, Any]] = None
        self._in_transaction = False
        self._indices: Dict[str, Any] = {
            "entity_names": {},
            "entity_types": defaultdict(list),
            "relations_from": defaultdict(list),
            "relations_to": defaultdict(list),
            "relation_keys": set(),
            "observation_index": defaultdict(set),
        }

    async def initialize(self) -> None:
        self.memory_path.parent.mkdir(parents=True, exist_ok=True)
        if self.memory_path.exists() and self.memory_path.is_dir():
            raise FileAccessError(f"Path {self.memory_path} is a directory")

    async def close(self) -> None:
        await self.flush()

    def _build_indices(self, graph: KnowledgeGraph) -> None:
        # Build indices for faster lookups.
        entity_names: Dict[str, Entity] = {}
        entity_types: Dict[str, List[Entity]] = defaultdict(list)
        relations_from: Dict[str, List[Relation]] = defaultdict(list)
        relations_to: Dict[str, List[Relation]] = defaultdict(list)
        relation_keys: Set[Tuple[str, str, str]] = set()

        for entity in graph.entities:
            entity_names[entity.name] = entity
            entity_types[entity.entityType].append(entity)

        for relation in graph.relations:
            relations_from[relation.from_].append(relation)
            relations_to[relation.to].append(relation)
            relation_keys.add((relation.from_, relation.to, relation.relationType))

        self._indices["entity_names"] = entity_names
        self._indices["entity_types"] = entity_types
        self._indices["relations_from"] = relations_from
        self._indices["relations_to"] = relations_to
        self._indices["relation_keys"] = relation_keys

        # Build the observation index.
        observation_index = cast(
            Dict[str, Set[str]], self._indices["observation_index"]
        )
        observation_index.clear()
        for entity in graph.entities:
            for obs in entity.observations:
                for word in obs.lower().split():
                    observation_index[word].add(entity.name)

    async def _check_cache(self) -> KnowledgeGraph:
        # During a transaction, always use the transaction snapshot.
        if self._in_transaction:
            return self._transaction_cache  # type: ignore

        current_time = time.monotonic()
        file_mtime = (
            self.memory_path.stat().st_mtime if self.memory_path.exists() else 0
        )
        needs_refresh = (
            self._cache is None
            or (current_time - self._cache_timestamp > self.cache_ttl)
            or self._dirty
            or (file_mtime > self._cache_file_mtime)
        )
        if needs_refresh:
            async with self._lock:
                current_time = time.monotonic()
                file_mtime = (
                    self.memory_path.stat().st_mtime
                    if self.memory_path.exists()
                    else 0
                )
                needs_refresh = (
                    self._cache is None
                    or (current_time - self._cache_timestamp > self.cache_ttl)
                    or self._dirty
                    or (file_mtime > self._cache_file_mtime)
                )
                if needs_refresh:
                    try:
                        graph = await self._load_graph_from_file()
                        self._cache = graph
                        self._cache_timestamp = current_time
                        self._cache_file_mtime = file_mtime
                        self._build_indices(graph)
                        self._dirty = False
                    except FileAccessError:
                        raise
                    except Exception as e:
                        raise FileAccessError(
                            f"Error loading graph: {str(e)}"
                        ) from e
        return cast(KnowledgeGraph, self._cache)

    async def _load_graph_from_file(self) -> KnowledgeGraph:
        if not self.memory_path.exists():
            return KnowledgeGraph(entities=[], relations=[])
        graph = KnowledgeGraph(entities=[], relations=[])
        try:
            async with aiofiles.open(
                self.memory_path, mode="r", encoding="utf-8"
            ) as f:
                async for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        item = json.loads(line)
                        if item["type"] == "entity":
                            graph.entities.append(
                                Entity(
                                    name=item["name"],
                                    entityType=item["entityType"],
                                    observations=item["observations"],
                                )
                            )
                        elif item["type"] == "relation":
                            graph.relations.append(
                                Relation(
                                    from_=item["from"],
                                    to=item["to"],
                                    relationType=item["relationType"],
                                )
                            )
                    except json.JSONDecodeError as e:
                        raise FileAccessError(
                            f"Error loading graph: {str(e)}"
                        ) from e
                    except KeyError as e:
                        raise FileAccessError(
                            f"Error loading graph: Missing required key {str(e)}"
                        ) from e
            return graph
        except Exception as err:
            raise FileAccessError(f"Error reading file: {str(err)}") from err

    async def _save_graph(self, graph: KnowledgeGraph) -> None:
        # This function writes to disk. Note that during a transaction, it is
        # only called on commit.
        temp_path = self.memory_path.with_suffix(".tmp")
        buffer_size = 1000  # Buffer size (number of lines)
        try:
            async with aiofiles.open(temp_path, mode="w", encoding="utf-8") as f:
                buffer = []
                # Write entities.
                for entity in graph.entities:
                    line = json.dumps(
                        {
                            "type": "entity",
                            "name": entity.name,
                            "entityType": entity.entityType,
                            "observations": entity.observations,
                        }
                    )
                    buffer.append(line)
                    if len(buffer) >= buffer_size:
                        await f.write("\n".join(buffer) + "\n")
                        buffer = []
                if buffer:
                    await f.write("\n".join(buffer) + "\n")
                    buffer = []
                # Write relations.
                for relation in graph.relations:
                    line = json.dumps(
                        {
                            "type": "relation",
                            "from": relation.from_,
                            "to": relation.to,
                            "relationType": relation.relationType,
                        }
                    )
                    buffer.append(line)
                    if len(buffer) >= buffer_size:
                        await f.write("\n".join(buffer) + "\n")
                        buffer = []
                if buffer:
                    await f.write("\n".join(buffer) + "\n")
            temp_path.replace(self.memory_path)
        except Exception as err:
            raise FileAccessError(f"Error saving file: {str(err)}") from err
        finally:
            if temp_path.exists():
                try:
                    temp_path.unlink()
                except Exception:
                    pass

    async def _get_current_state(self) -> Tuple[KnowledgeGraph, Dict[str, Any]]:
        # Returns the active graph and indices. If a transaction is in progress,
        # return the transaction copies; otherwise, return the persistent ones.
        if self._in_transaction:
            return self._transaction_cache, self._transaction_indices  # type: ignore
        else:
            graph = await self._check_cache()
            return graph, self._indices

    async def create_entities(self, entities: List[Entity]) -> List[Entity]:
        async with self._write_lock:
            graph, indices = await self._get_current_state()
            existing_entities = cast(Dict[str, Entity], indices["entity_names"])
            new_entities = []
            for entity in entities:
                if not entity.name or not entity.entityType:
                    raise ValueError(f"Invalid entity: {entity}")
                if entity.name not in existing_entities:
                    new_entities.append(entity)
                    existing_entities[entity.name] = entity
                    cast(
                        Dict[str, List[Entity]], indices["entity_types"]
                    ).setdefault(entity.entityType, []).append(entity)
            if new_entities:
                graph.entities.extend(new_entities)
                # If not in a transaction, immediately persist the change.
                if not self._in_transaction:
                    self._dirty = True
                    await self._save_graph(graph)
                    self._dirty = False
                    self._cache_timestamp = time.monotonic()
            return new_entities

    async def delete_entities(self, entity_names: List[str]) -> List[str]:
        if not entity_names:
            return []
        async with self._write_lock:
            graph, indices = await self._get_current_state()
            existing_entities = cast(Dict[str, Entity], indices["entity_names"])
            deleted_names = []
            relation_keys = cast(
                Set[Tuple[str, str, str]], indices["relation_keys"]
            )
            for name in entity_names:
                if name in existing_entities:
                    entity = existing_entities.pop(name)
                    entity_type_list = cast(
                        Dict[str, List[Entity]], indices["entity_types"]
                    ).get(entity.entityType, [])
                    if entity in entity_type_list:
                        entity_type_list.remove(entity)
                    # Remove associated relations.
                    relations_from = cast(
                        Dict[str, List[Relation]], indices["relations_from"]
                    ).get(name, [])
                    relations_to = cast(
                        Dict[str, List[Relation]], indices["relations_to"]
                    ).get(name, [])
                    relations_to_remove = relations_from + relations_to
                    for relation in relations_to_remove:
                        if relation in graph.relations:
                            graph.relations.remove(relation)
                        relation_keys.discard(
                            (relation.from_, relation.to, relation.relationType)
                        )
                        if relation in cast(
                            Dict[str, List[Relation]], indices["relations_from"]
                        ).get(relation.from_, []):
                            cast(
                                Dict[str, List[Relation]], indices["relations_from"]
                            )[relation.from_].remove(relation)
                        if relation in cast(
                            Dict[str, List[Relation]], indices["relations_to"]
                        ).get(relation.to, []):
                            cast(
                                Dict[str, List[Relation]], indices["relations_to"]
                            )[relation.to].remove(relation)
                    deleted_names.append(name)
            if deleted_names:
                graph.entities = [
                    e for e in graph.entities if e.name not in deleted_names
                ]
                if not self._in_transaction:
                    self._dirty = True
                    await self._save_graph(graph)
                    self._dirty = False
                    self._cache_timestamp = time.monotonic()
            return deleted_names

    async def create_relations(self, relations: List[Relation]) -> List[Relation]:
        async with self._write_lock:
            graph, indices = await self._get_current_state()
            existing_entities = cast(Dict[str, Entity], indices["entity_names"])
            relation_keys = cast(
                Set[Tuple[str, str, str]], indices["relation_keys"]
            )
            new_relations = []
            for relation in relations:
                if not relation.from_ or not relation.to or not relation.relationType:
                    raise ValueError(f"Invalid relation: {relation}")
                if relation.from_ not in existing_entities:
                    raise EntityNotFoundError(f"Entity not found: {relation.from_}")
                if relation.to not in existing_entities:
                    raise EntityNotFoundError(f"Entity not found: {relation.to}")
                key = (relation.from_, relation.to, relation.relationType)
                if key not in relation_keys:
                    new_relations.append(relation)
                    relation_keys.add(key)
                    cast(
                        Dict[str, List[Relation]], indices["relations_from"]
                    ).setdefault(relation.from_, []).append(relation)
                    cast(
                        Dict[str, List[Relation]], indices["relations_to"]
                    ).setdefault(relation.to, []).append(relation)
            if new_relations:
                graph.relations.extend(new_relations)
                if not self._in_transaction:
                    self._dirty = True
                    await self._save_graph(graph)
                    self._dirty = False
                    self._cache_timestamp = time.monotonic()
            return new_relations

    async def delete_relations(self, from_: str, to: str) -> None:
        async with self._write_lock:
            graph, indices = await self._get_current_state()
            existing_entities = cast(Dict[str, Entity], indices["entity_names"])
            if from_ not in existing_entities:
                raise EntityNotFoundError(f"Entity not found: {from_}")
            if to not in existing_entities:
                raise EntityNotFoundError(f"Entity not found: {to}")
            relations_from = cast(
                Dict[str, List[Relation]], indices["relations_from"]
            ).get(from_, [])
            relations_to_remove = [rel for rel in relations_from if rel.to == to]
            if relations_to_remove:
                graph.relations = [
                    rel for rel in graph.relations if rel not in relations_to_remove
                ]
                relation_keys = cast(
                    Set[Tuple[str, str, str]], indices["relation_keys"]
                )
                for rel in relations_to_remove:
                    relation_keys.discard((rel.from_, rel.to, rel.relationType))
                    if rel in cast(
                        Dict[str, List[Relation]], indices["relations_from"]
                    ).get(from_, []):
                        cast(
                            Dict[str, List[Relation]], indices["relations_from"]
                        )[from_].remove(rel)
                    if rel in cast(
                        Dict[str, List[Relation]], indices["relations_to"]
                    ).get(to, []):
                        cast(
                            Dict[str, List[Relation]], indices["relations_to"]
                        )[to].remove(rel)
                if not self._in_transaction:
                    self._dirty = True
                    await self._save_graph(graph)
                    self._dirty = False
                    self._cache_timestamp = time.monotonic()

    async def read_graph(self) -> KnowledgeGraph:
        return await self._check_cache()

    async def flush(self) -> None:
        async with self._write_lock:
            # During a transaction, disk is not touched until commit.
            if self._dirty and not self._in_transaction:
                graph = await self._check_cache()
                await self._save_graph(graph)
                self._dirty = False
                self._cache_timestamp = time.monotonic()

    async def search_nodes(
        self, query: str, options: Optional[SearchOptions] = None
    ) -> KnowledgeGraph:
        """
        Search for entities and relations matching the query.

        If options is provided and options.fuzzy is True, fuzzy matching is used
        with weights and threshold. Otherwise, a simple case-insensitive substring
        search is performed. Relations are returned only if both endpoints are in
        the set of matched entities.
        """
        graph = await self._check_cache()
        matched_entities = []
        if options is not None and options.fuzzy:
            # Use provided weights or default to 1.0 if not provided.
            weights = (
                options.weights
                if options.weights is not None
                else {"name": 1.0, "type": 1.0, "observations": 1.0}
            )
            q = query.strip()
            for entity in graph.entities:
                # Compute robust scores for each field.
                name_score = fuzz.WRatio(q, entity.name)
                type_score = fuzz.WRatio(q, entity.entityType)
                obs_score = 0
                if entity.observations:
                    # For each observation, take the best between WRatio and
                    # partial_ratio.
                    scores = [
                        max(fuzz.WRatio(q, obs), fuzz.partial_ratio(q, obs))
                        for obs in entity.observations
                    ]
                    obs_score = max(scores) if scores else 0
                total_score = (
                    name_score * weights.get("name", 1.0)
                    + type_score * weights.get("type", 1.0)
                    + obs_score * weights.get("observations", 1.0)
                )
                if total_score >= options.threshold:
                    matched_entities.append(entity)
        else:
            q = query.lower()
            for entity in graph.entities:
                if (
                    q in entity.name.lower()
                    or q in entity.entityType.lower()
                    or any(q in obs.lower() for obs in entity.observations)
                ):
                    matched_entities.append(entity)
        matched_names = {entity.name for entity in matched_entities}
        matched_relations = [
            rel
            for rel in graph.relations
            if rel.from_ in matched_names and rel.to in matched_names
        ]
        return KnowledgeGraph(entities=matched_entities, relations=matched_relations)

    async def add_observations(
        self, entity_name: str, observations: List[str]
    ) -> None:
        if not observations:
            raise ValueError("Observations list cannot be empty")
        async with self._write_lock:
            graph, indices = await self._get_current_state()
            existing_entities = cast(Dict[str, Entity], indices["entity_names"])
            if entity_name not in existing_entities:
                raise EntityNotFoundError(f"Entity not found: {entity_name}")
            entity = existing_entities[entity_name]
            updated_entity = Entity(
                name=entity.name,
                entityType=entity.entityType,
                observations=list(entity.observations) + observations,
            )
            graph.entities = [
                updated_entity if e.name == entity_name else e
                for e in graph.entities
            ]
            existing_entities[entity_name] = updated_entity
            entity_types = cast(Dict[str, List[Entity]], indices["entity_types"])
            if entity_name in [
                e.name for e in entity_types.get(updated_entity.entityType, [])
            ]:
                entity_types[updated_entity.entityType] = [
                    updated_entity if e.name == entity_name else e
                    for e in entity_types[updated_entity.entityType]
                ]
            if not self._in_transaction:
                self._dirty = True
                await self._save_graph(graph)
                self._dirty = False
                self._cache_timestamp = time.monotonic()

    async def add_batch_observations(
        self, observations_map: Dict[str, List[str]]
    ) -> None:
        if not observations_map:
            raise ValueError("Observations map cannot be empty")
        async with self._write_lock:
            graph, indices = await self._get_current_state()
            existing_entities = cast(Dict[str, Entity], indices["entity_names"])
            entity_types = cast(Dict[str, List[Entity]], indices["entity_types"])
            missing_entities = [
                name for name in observations_map if name not in existing_entities
            ]
            if missing_entities:
                raise EntityNotFoundError(
                    f"Entities not found: {', '.join(missing_entities)}"
                )
            updated_entities = {}
            for entity_name, observations in observations_map.items():
                if not observations:
                    continue
                entity = existing_entities[entity_name]
                updated_entity = Entity(
                    name=entity.name,
                    entityType=entity.entityType,
                    observations=list(entity.observations) + observations,
                )
                updated_entities[entity_name] = updated_entity
            if updated_entities:
                graph.entities = [
                    updated_entities.get(e.name, e) for e in graph.entities
                ]
                for updated_entity in updated_entities.values():
                    existing_entities[updated_entity.name] = updated_entity
                    et_list = entity_types.get(updated_entity.entityType, [])
                    for i, e in enumerate(et_list):
                        if e.name == updated_entity.name:
                            et_list[i] = updated_entity
                            break
                if not self._in_transaction:
                    self._dirty = True
                    await self._save_graph(graph)
                    self._dirty = False
                    self._cache_timestamp = time.monotonic()

    #
    # Transaction Methods
    #

    async def begin_transaction(self) -> None:
        async with self._write_lock:
            if self._in_transaction:
                raise ValueError("Transaction already in progress")
            graph = await self._check_cache()
            # Make deep (shallow for immutable entities) copies of state.
            self._transaction_cache = KnowledgeGraph(
                entities=list(graph.entities), relations=list(graph.relations)
            )
            self._transaction_indices = {
                "entity_names": dict(self._indices["entity_names"]),
                "entity_types": defaultdict(
                    list,
                    {k: list(v) for k, v in self._indices["entity_types"].items()},
                ),
                "relations_from": defaultdict(
                    list,
                    {k: list(v) for k, v in self._indices["relations_from"].items()},
                ),
                "relations_to": defaultdict(
                    list,
                    {k: list(v) for k, v in self._indices["relations_to"].items()},
                ),
                "relation_keys": set(self._indices["relation_keys"]),
                "observation_index": defaultdict(
                    set,
                    {
                        k: set(v)
                        for k, v in self._indices["observation_index"].items()
                    },
                ),
            }
            self._in_transaction = True

    async def rollback_transaction(self) -> None:
        async with self._write_lock:
            if not self._in_transaction:
                raise ValueError("No transaction in progress")
            # Discard the transaction state; since disk writes were deferred,
            # the file remains unchanged.
            self._transaction_cache = None
            self._transaction_indices = None
            self._in_transaction = False

    async def commit_transaction(self) -> None:
        async with self._write_lock:
            if not self._in_transaction:
                raise ValueError("No transaction in progress")
            # Persist the transaction state to disk.
            await self._save_graph(cast(KnowledgeGraph, self._transaction_cache))
            # Update the persistent state with the transaction snapshot.
            self._cache = self._transaction_cache
            self._indices = self._transaction_indices  # type: ignore
            self._transaction_cache = None
            self._transaction_indices = None
            self._in_transaction = False
            self._dirty = False
            self._cache_timestamp = time.monotonic()

    async def execute_batch(self, operations: List[BatchOperation]) -> BatchResult:
        if not operations:
            return BatchResult(
                success=True,
                operations_completed=0,
                failed_operations=[],
            )
        async with self._write_lock:
            try:
                # Start a transaction so that no disk writes occur until commit.
                await self.begin_transaction()
                completed = 0
                failed_ops: List[Tuple[BatchOperation, str]] = []
                # Execute each operation.
                for operation in operations:
                    try:
                        if (
                            operation.operation_type
                            == BatchOperationType.CREATE_ENTITIES
                        ):
                            await self.create_entities(operation.data["entities"])
                        elif (
                            operation.operation_type
                            == BatchOperationType.DELETE_ENTITIES
                        ):
                            await self.delete_entities(
                                operation.data["entity_names"]
                            )
                        elif (
                            operation.operation_type
                            == BatchOperationType.CREATE_RELATIONS
                        ):
                            await self.create_relations(operation.data["relations"])
                        elif (
                            operation.operation_type
                            == BatchOperationType.DELETE_RELATIONS
                        ):
                            await self.delete_relations(
                                operation.data["from_"], operation.data["to"]
                            )
                        elif (
                            operation.operation_type
                            == BatchOperationType.ADD_OBSERVATIONS
                        ):
                            await self.add_batch_observations(
                                operation.data["observations_map"]
                            )
                        else:
                            raise ValueError(
                                f"Unknown operation type: {operation.operation_type}"
                            )
                        completed += 1
                    except Exception as e:
                        failed_ops.append((operation, str(e)))
                        if not operation.data.get("allow_partial", False):
                            # On failure, rollback and return.
                            await self.rollback_transaction()
                            return BatchResult(
                                success=False,
                                operations_completed=completed,
                                failed_operations=failed_ops,
                                error_message=f"Operation failed: {str(e)}",
                            )
                # Commit the transaction (persisting all changes) or report
                # partial success.
                await self.commit_transaction()
                if failed_ops:
                    return BatchResult(
                        success=True,
                        operations_completed=completed,
                        failed_operations=failed_ops,
                        error_message="Some operations failed",
                    )
                else:
                    return BatchResult(
                        success=True,
                        operations_completed=completed,
                        failed_operations=[],
                    )
            except Exception as e:
                if self._in_transaction:
                    await self.rollback_transaction()
                return BatchResult(
                    success=False,
                    operations_completed=0,
                    failed_operations=[],
                    error_message=f"Batch execution failed: {str(e)}",
                )
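

# Usage sketch (illustrative only, not part of the backend above): one way a
# caller might drive JsonlBackend through the transaction API, so that all
# staged changes hit disk in a single atomic write on commit. The file name,
# entity names, observations, and relation type below are made-up examples;
# Entity, Relation, and JsonlBackend are the definitions from this module and
# ..interfaces.
async def _example_usage() -> None:
    backend = JsonlBackend(Path("memory.jsonl"))  # hypothetical path
    await backend.initialize()
    await backend.begin_transaction()
    try:
        # Stage changes in memory; nothing is written to disk yet.
        await backend.create_entities(
            [
                Entity(
                    name="Ada Lovelace",
                    entityType="person",
                    observations=["wrote the first published program"],
                ),
                Entity(
                    name="Analytical Engine",
                    entityType="machine",
                    observations=[],
                ),
            ]
        )
        await backend.create_relations(
            [
                Relation(
                    from_="Ada Lovelace",
                    to="Analytical Engine",
                    relationType="programmed",
                )
            ]
        )
        # Persist everything at once via _save_graph's temp-file-and-replace.
        await backend.commit_transaction()
    except Exception:
        # Discard the in-memory snapshot; the on-disk file remains unchanged.
        await backend.rollback_transaction()
        raise
    finally:
        await backend.close()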