"""Knowledge Graph Memory Server (optimized_memory_mcp_server)."""
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse

import aiosqlite

from .interfaces import Entity, Relation
from .exceptions import EntityNotFoundError, EntityAlreadyExistsError

logger = logging.getLogger(__name__)


def _sanitize_input(value: str) -> str:
    """Escape single quotes for literal SQL interpolation.

    NOTE(review): every query in this module now uses ``?`` parameter
    binding, which makes this escaping unnecessary — and, when applied to
    bound parameters (as earlier revisions did), actively harmful: it
    corrupts any value containing a single quote. Kept only for backward
    compatibility with external callers.
    """
    return value.replace("'", "''")


class OptimizedSQLiteManager:
    """Async SQLite-backed store for a knowledge graph of entities/relations.

    Maintains a small connection pool (WAL journal mode) and exposes
    batch CRUD operations. Observations are persisted as a comma-joined
    string, so individual observations must not contain commas
    (TODO: consider storing them as JSON instead).
    """

    def __init__(self, database_url: str, echo: bool = False):
        """Initialize the manager with a database path extracted from a URL.

        Args:
            database_url: URL whose path component names the SQLite file
                (absolute paths are used verbatim; relative paths with
                directories are resolved and their parents created).
            echo: Stored for callers; not otherwise used here.

        Raises:
            ValueError: If the URL has no path component.
        """
        parsed_url = urlparse(database_url)
        if not parsed_url.path:
            raise ValueError("Database path not specified in URL")
        if parsed_url.path.startswith('/'):
            # Absolute path: keep as-is.
            self.db_path = parsed_url.path
        else:
            path = parsed_url.path.lstrip('/')
            if '/' in path:
                # Relative path with directories: resolve and ensure parents exist.
                self.db_path = str(Path(path).absolute())
                os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
            else:
                # Simple filename in the current directory.
                self.db_path = path
        self.echo = echo
        self._pool: List[aiosqlite.Connection] = []
        self._pool_size = 5
        # Caps the number of simultaneously checked-out connections.
        self._pool_semaphore = asyncio.Semaphore(self._pool_size)

    @asynccontextmanager
    async def _get_connection(self) -> AsyncGenerator[aiosqlite.Connection, None]:
        """Yield a pooled connection, creating one (WAL mode) if the pool is empty.

        The connection is returned to the pool even if the caller raises.
        """
        async with self._pool_semaphore:
            if not self._pool:
                conn = await aiosqlite.connect(self.db_path)
                # WAL permits concurrent readers alongside one writer;
                # synchronous=NORMAL is the usual safe/fast pairing with WAL.
                await conn.execute("PRAGMA journal_mode=WAL")
                await conn.execute("PRAGMA synchronous=NORMAL")
                conn.row_factory = aiosqlite.Row
            else:
                conn = self._pool.pop()
            try:
                yield conn
            finally:
                self._pool.append(conn)

    @asynccontextmanager
    async def _transaction(self, conn: aiosqlite.Connection):
        """Run the enclosed block in a transaction; roll back on any exception."""
        try:
            await conn.execute("BEGIN")
            yield
            await conn.commit()
        except Exception:
            await conn.rollback()
            raise

    async def initialize(self):
        """Create the database schema (tables and indices) if it does not exist."""
        async with self._get_connection() as conn:
            async with self._transaction(conn):
                await conn.execute("""
                    CREATE TABLE IF NOT EXISTS entities (
                        name TEXT PRIMARY KEY,
                        entity_type TEXT NOT NULL,
                        observations TEXT NOT NULL
                    )
                """)
                await conn.execute("""
                    CREATE TABLE IF NOT EXISTS relations (
                        from_entity TEXT NOT NULL,
                        to_entity TEXT NOT NULL,
                        relation_type TEXT NOT NULL,
                        PRIMARY KEY (from_entity, to_entity, relation_type),
                        FOREIGN KEY (from_entity) REFERENCES entities(name),
                        FOREIGN KEY (to_entity) REFERENCES entities(name)
                    )
                """)
                # Indices for type filtering and relation traversal.
                await conn.execute("CREATE INDEX IF NOT EXISTS idx_entity_type ON entities(entity_type)")
                await conn.execute("CREATE INDEX IF NOT EXISTS idx_from_entity ON relations(from_entity)")
                await conn.execute("CREATE INDEX IF NOT EXISTS idx_to_entity ON relations(to_entity)")

    async def cleanup(self):
        """Close and discard every pooled connection."""
        for conn in self._pool:
            await conn.close()
        self._pool.clear()

    async def create_entities(self, entities: List[Dict[str, Any]],
                              batch_size: int = 1000) -> List[Dict[str, Any]]:
        """Create multiple new entities using batch processing.

        Args:
            entities: Serialized entities (see ``Entity.from_dict``).
            batch_size: Rows per INSERT batch.

        Returns:
            The created entities in serialized (dict) form.

        Raises:
            EntityAlreadyExistsError: If any entity name already exists;
                the whole call is rolled back.
        """
        created_entities = []
        async with self._get_connection() as conn:
            async with self._transaction(conn):
                for i in range(0, len(entities), batch_size):
                    batch = entities[i:i + batch_size]
                    entity_objects = [Entity.from_dict(e) for e in batch]
                    # Validate before insertion so the transaction stays atomic.
                    for entity in entity_objects:
                        # Parameter binding handles quoting — no manual escaping,
                        # which would corrupt names containing single quotes.
                        cursor = await conn.execute(
                            "SELECT 1 FROM entities WHERE name = ?",
                            (entity.name,)
                        )
                        if await cursor.fetchone():
                            raise EntityAlreadyExistsError(entity.name)
                    # NOTE: observations are comma-joined; observations that
                    # themselves contain commas will not round-trip intact.
                    await conn.executemany(
                        "INSERT INTO entities (name, entity_type, observations) VALUES (?, ?, ?)",
                        [(e.name, e.entityType, ','.join(e.observations))
                         for e in entity_objects]
                    )
                    created_entities.extend([e.to_dict() for e in entity_objects])
        return created_entities

    async def create_relations(self, relations: List[Dict[str, Any]],
                               batch_size: int = 1000) -> List[Dict[str, Any]]:
        """Create multiple new relations using batch processing.

        Args:
            relations: Serialized relations (see ``Relation.from_dict``).
            batch_size: Rows per INSERT batch.

        Returns:
            The requested relations in serialized form (duplicates of
            existing relations are silently skipped by ON CONFLICT).

        Raises:
            EntityNotFoundError: If either endpoint of any relation does
                not exist; the whole call is rolled back.
        """
        created_relations = []
        async with self._get_connection() as conn:
            async with self._transaction(conn):
                for i in range(0, len(relations), batch_size):
                    batch = relations[i:i + batch_size]
                    relation_objects = [Relation.from_dict(r) for r in batch]
                    # Verify both endpoints exist before the batch insert.
                    for relation in relation_objects:
                        cursor = await conn.execute(
                            "SELECT 1 FROM entities WHERE name = ?",
                            (relation.from_,)
                        )
                        if not await cursor.fetchone():
                            raise EntityNotFoundError(relation.from_)
                        cursor = await conn.execute(
                            "SELECT 1 FROM entities WHERE name = ?",
                            (relation.to,)
                        )
                        if not await cursor.fetchone():
                            raise EntityNotFoundError(relation.to)
                    await conn.executemany(
                        """
                        INSERT INTO relations (from_entity, to_entity, relation_type)
                        VALUES (?, ?, ?)
                        ON CONFLICT DO NOTHING
                        """,
                        [(r.from_, r.to, r.relationType) for r in relation_objects]
                    )
                    created_relations.extend([r.to_dict() for r in relation_objects])
        return created_relations

    async def read_graph(self) -> Dict[str, List[Dict[str, Any]]]:
        """Read the entire graph.

        Returns:
            ``{"entities": [...], "relations": [...]}`` with every row in
            serialized (dict) form.
        """
        async with self._get_connection() as conn:
            cursor = await conn.execute("SELECT * FROM entities")
            rows = await cursor.fetchall()
            entities = []
            for row in rows:
                entity = Entity(
                    name=row['name'],
                    entityType=row['entity_type'],
                    observations=row['observations'].split(',') if row['observations'] else []
                )
                entities.append(entity.to_dict())

            cursor = await conn.execute("SELECT * FROM relations")
            rows = await cursor.fetchall()
            relations = []
            for row in rows:
                relation = Relation(
                    from_=row['from_entity'],
                    to=row['to_entity'],
                    relationType=row['relation_type']
                )
                relations.append(relation.to_dict())

            return {"entities": entities, "relations": relations}

    async def add_observations(self, observations: List[Dict[str, Any]],
                               batch_size: int = 1000) -> Dict[str, List[str]]:
        """Append new observations to existing entities using batch processing.

        Args:
            observations: Dicts with keys ``entityName`` and ``contents``.
            batch_size: Items processed per batch.

        Returns:
            Mapping of entity name to the observations added to it.

        Raises:
            EntityNotFoundError: If any named entity does not exist;
                the whole call is rolled back.
        """
        added_observations = {}
        async with self._get_connection() as conn:
            async with self._transaction(conn):
                for i in range(0, len(observations), batch_size):
                    batch = observations[i:i + batch_size]
                    for obs in batch:
                        entity_name = obs["entityName"]
                        new_contents = obs["contents"]
                        cursor = await conn.execute(
                            "SELECT observations FROM entities WHERE name = ?",
                            (entity_name,)
                        )
                        result = await cursor.fetchone()
                        if not result:
                            raise EntityNotFoundError(entity_name)
                        # Merge existing comma-joined observations with the new ones.
                        current_obs = result['observations'].split(',') if result['observations'] else []
                        current_obs.extend(new_contents)
                        await conn.execute(
                            "UPDATE entities SET observations = ? WHERE name = ?",
                            (','.join(current_obs), entity_name)
                        )
                        added_observations[entity_name] = new_contents
        return added_observations

    async def delete_entities(self, entityNames: List[str],
                              batch_size: int = 1000) -> None:
        """Remove entities and all relations that touch them, in batches."""
        async with self._get_connection() as conn:
            async with self._transaction(conn):
                for i in range(0, len(entityNames), batch_size):
                    batch = entityNames[i:i + batch_size]
                    # Dynamic placeholder list; the values themselves are bound.
                    placeholders = ','.join(['?' for _ in batch])
                    await conn.execute(
                        f"""
                        DELETE FROM relations
                        WHERE from_entity IN ({placeholders})
                        OR to_entity IN ({placeholders})
                        """,
                        batch * 2
                    )
                    await conn.execute(
                        f"DELETE FROM entities WHERE name IN ({placeholders})",
                        batch
                    )

    async def delete_observations(self, deletions: List[Dict[str, Any]],
                                  batch_size: int = 1000) -> None:
        """Remove specific observations from entities, in batches.

        Entities that do not exist are silently skipped (matches the
        original best-effort behavior).
        """
        async with self._get_connection() as conn:
            async with self._transaction(conn):
                for i in range(0, len(deletions), batch_size):
                    batch = deletions[i:i + batch_size]
                    for deletion in batch:
                        entity_name = deletion["entityName"]
                        to_delete = set(deletion["observations"])
                        cursor = await conn.execute(
                            "SELECT observations FROM entities WHERE name = ?",
                            (entity_name,)
                        )
                        result = await cursor.fetchone()
                        if result:
                            current_obs = result['observations'].split(',') if result['observations'] else []
                            updated_obs = [obs for obs in current_obs if obs not in to_delete]
                            await conn.execute(
                                "UPDATE entities SET observations = ? WHERE name = ?",
                                (','.join(updated_obs), entity_name)
                            )

    async def delete_relations(self, relations: List[Dict[str, Any]],
                               batch_size: int = 1000) -> None:
        """Remove specific relations from the graph using batch processing."""
        async with self._get_connection() as conn:
            async with self._transaction(conn):
                for i in range(0, len(relations), batch_size):
                    batch = relations[i:i + batch_size]
                    relation_objects = [Relation.from_dict(r) for r in batch]
                    await conn.executemany(
                        """
                        DELETE FROM relations
                        WHERE from_entity = ?
                        AND to_entity = ?
                        AND relation_type = ?
                        """,
                        [(r.from_, r.to, r.relationType) for r in relation_objects]
                    )

    async def open_nodes(self, names: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Retrieve the named entities plus relations *between* them.

        Returns:
            ``{"entities": [...], "relations": [...]}``; relations whose
            other endpoint is outside ``names`` are excluded.
        """
        async with self._get_connection() as conn:
            placeholders = ','.join('?' * len(names))
            cursor = await conn.execute(
                f"SELECT * FROM entities WHERE name IN ({placeholders})",
                names
            )
            rows = await cursor.fetchall()
            entities = []
            for row in rows:
                entity = Entity(
                    name=row['name'],
                    entityType=row['entity_type'],
                    observations=row['observations'].split(',') if row['observations'] else []
                )
                entities.append(entity.to_dict())

            cursor = await conn.execute(
                f"""
                SELECT * FROM relations
                WHERE from_entity IN ({placeholders})
                AND to_entity IN ({placeholders})
                """,
                names * 2
            )
            rows = await cursor.fetchall()
            relations = []
            for row in rows:
                relation = Relation(
                    from_=row['from_entity'],
                    to=row['to_entity'],
                    relationType=row['relation_type']
                )
                relations.append(relation.to_dict())

            return {"entities": entities, "relations": relations}

    async def backup_database(self) -> str:
        """Create a backup of the database using VACUUM INTO.

        Returns:
            str: Path to the backup file.
        """
        backup_path = f"{self.db_path}.backup"
        async with self._get_connection() as conn:
            # Bind the path as a parameter: an f-string-built statement
            # breaks (or injects) when the path contains a quote.
            await conn.execute("VACUUM INTO ?", (backup_path,))
        return backup_path

    async def search_nodes(self, query: str) -> Dict[str, List[Dict[str, Any]]]:
        """Search entities by substring and return matches plus their relations.

        Matches ``query`` against name, type, and the raw observations
        column. NOTE(review): LIKE wildcards (``%``/``_``) in ``query``
        are not escaped, so they act as wildcards — confirm this is intended.

        Raises:
            ValueError: If ``query`` is empty.
        """
        if not query:
            raise ValueError("Search query cannot be empty")

        async with self._get_connection() as conn:
            # Parameter binding makes manual quote-escaping unnecessary.
            search_pattern = f"%{query}%"
            cursor = await conn.execute(
                """
                SELECT * FROM entities
                WHERE name LIKE ?
                OR entity_type LIKE ?
                OR observations LIKE ?
                """,
                (search_pattern, search_pattern, search_pattern)
            )
            rows = await cursor.fetchall()
            entities = []
            entity_names = set()
            for row in rows:
                entity = Entity(
                    name=row['name'],
                    entityType=row['entity_type'],
                    observations=row['observations'].split(',') if row['observations'] else []
                )
                entities.append(entity.to_dict())
                entity_names.add(entity.name)

            if entity_names:
                # Only relations whose BOTH endpoints matched are returned.
                placeholders = ','.join('?' * len(entity_names))
                cursor = await conn.execute(
                    f"""
                    SELECT * FROM relations
                    WHERE from_entity IN ({placeholders})
                    AND to_entity IN ({placeholders})
                    """,
                    list(entity_names) * 2
                )
                rows = await cursor.fetchall()
                relations = []
                for row in rows:
                    relation = Relation(
                        from_=row['from_entity'],
                        to=row['to_entity'],
                        relationType=row['relation_type']
                    )
                    relations.append(relation.to_dict())
            else:
                relations = []

            return {"entities": entities, "relations": relations}