"""Isolated RAG Memory System for Murder Mystery MCP Server.
Key improvements over basic RAG:
1. PER-SUSPECT ISOLATION - Each suspect has their own index partition
2. QUERY CACHING - Don't re-search if no new data since last query
3. UPDATE TRACKING - Know when each partition was last updated
4. NO BLEEDING - Strict isolation between suspect conversations
Architecture:
- Separate index per suspect (no cross-contamination)
- Shared clue index (clues can reference multiple suspects)
- Cache layer with invalidation on new data
"""
import hashlib
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
logger = logging.getLogger(__name__)
# Lazy imports
_faiss_available = None
_embeddings = None
_vectorstore_class = None
def _check_faiss():
"""Check if FAISS is available."""
global _faiss_available, _embeddings, _vectorstore_class
if _faiss_available is not None:
return _faiss_available
try:
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
_vectorstore_class = FAISS
_embeddings = OpenAIEmbeddings
_faiss_available = True
logger.info("[MEMORY] FAISS available")
except ImportError as e:
_faiss_available = False
logger.warning("[MEMORY] FAISS not available: %s", e)
return _faiss_available
@dataclass
class CacheEntry:
    """A cached search result.

    Stored in GameMemory._cache keyed by query_hash; served only while
    index_version still equals the owning partition's current version.
    """
    query_hash: str  # the cache key this entry is stored under
    results: List[Tuple[str, Dict]]  # (text, metadata) pairs from the search
    created_at: datetime  # when the search was performed
    index_version: int  # Version when this cache was created
@dataclass
class IndexPartition:
    """A single partition of the index (per-suspect or shared)."""
    name: str  # partition key (normalized suspect name, or "clues")
    vectorstore: Any = None  # FAISS vectorstore holding this partition's embeddings
    documents: List[Dict] = field(default_factory=list)  # raw {"text", "metadata"} records, in insertion order
    version: int = 0  # Incremented on every update
    last_updated: Optional[datetime] = None  # timestamp of the most recent add
    doc_count: int = 0  # number of real (non-placeholder) documents added
class GameMemory:
    """Isolated RAG memory with per-suspect partitions and caching.

    Structure:
    - _suspect_partitions: Dict[suspect_key -> IndexPartition]
    - _clue_partition: IndexPartition for all clues
    - _cache: Dict[cache_key -> CacheEntry]

    Guarantees:
    - Searching for suspect X ONLY returns X's statements
    - Cache entries are version-checked (never served stale) AND evicted
      when new data is added to their partition
    - Cross-references use a separate search path with explicit context
    """
    def __init__(self):
        # Annotations are strings so the forward references to
        # IndexPartition/CacheEntry are not evaluated at runtime.
        self._suspect_partitions: "Dict[str, IndexPartition]" = {}
        self._clue_partition: "Optional[IndexPartition]" = None
        self._cache: "Dict[str, CacheEntry]" = {}
        self._embeddings = None
        self._initialized = False
        # Stats for debugging
        self._cache_hits = 0
        self._cache_misses = 0
        self._total_searches = 0

    @property
    def is_available(self) -> bool:
        """True once FAISS is importable AND initialize() has succeeded."""
        return _check_faiss() and self._initialized

    def initialize(self) -> bool:
        """Initialize the memory system (embeddings + shared clue partition).

        Returns:
            True on success, False if FAISS is unavailable or setup failed.
        """
        if not _check_faiss():
            return False
        try:
            self._embeddings = _embeddings(model="text-embedding-3-small")
            self._clue_partition = self._create_partition("clues")
            self._initialized = True
            logger.info("[MEMORY] Initialized successfully")
            return True
        except Exception as e:
            logger.error("[MEMORY] Initialization failed: %s", e)
            return False

    @staticmethod
    def _suspect_key(suspect: str) -> str:
        """Normalize a suspect name into a partition key.

        Centralized so indexing, searching and cache invalidation all
        agree. (Previously invalidation used .lower() without .strip(),
        so names with stray whitespace could dodge invalidation.)
        """
        return suspect.lower().strip()

    def _create_partition(self, name: str) -> "IndexPartition":
        """Create a new index partition.

        FAISS cannot build an empty index, so the partition is seeded
        with a placeholder doc tagged type='system'; search methods
        filter that tag out of their results.
        """
        vectorstore = _vectorstore_class.from_texts(
            [f"Index partition for {name} initialized."],
            self._embeddings,
            metadatas=[{"type": "system", "partition": name}]
        )
        return IndexPartition(
            name=name,
            vectorstore=vectorstore,
            documents=[],
            version=0,
            last_updated=datetime.now()
        )

    def _get_suspect_partition(self, suspect: str) -> "IndexPartition":
        """Get or create the isolated partition for a suspect."""
        suspect_key = self._suspect_key(suspect)
        if suspect_key not in self._suspect_partitions:
            self._suspect_partitions[suspect_key] = self._create_partition(suspect_key)
            logger.info("[MEMORY] Created partition for suspect: %s", suspect)
        return self._suspect_partitions[suspect_key]

    def _cache_key(self, partition_name: str, query: str) -> str:
        """Generate cache key for a query.

        The partition name is kept as a plaintext "partition:" prefix
        and only the query is hashed, so that
        _invalidate_cache_for_partition can evict by prefix. (The old
        code hashed the combined string, which made every key opaque
        and invalidation a silent no-op.)
        """
        return f"{partition_name}:{hashlib.md5(query.encode()).hexdigest()}"

    def _invalidate_cache_for_partition(self, partition_name: str):
        """Evict every cache entry belonging to a partition.

        Relies on the plaintext "partition:" prefix from _cache_key.
        Version checks already prevent stale entries from being served;
        eviction keeps dead entries from accumulating indefinitely.
        """
        prefix = f"{partition_name}:"
        keys_to_remove = [k for k in self._cache if k.startswith(prefix)]
        for key in keys_to_remove:
            del self._cache[key]
        if keys_to_remove:
            logger.info("[MEMORY] Invalidated %d cache entries for %s",
                        len(keys_to_remove), partition_name)

    def _index_document(self, partition: "IndexPartition", text: str, metadata: Dict) -> float:
        """Embed one document into a partition and update bookkeeping.

        Bumps the partition version (which invalidates version-checked
        cache entries), appends to the raw document list, and returns
        the embedding time in milliseconds. Exceptions propagate to the
        caller, which logs and reports failure.
        """
        t0 = time.perf_counter()
        partition.vectorstore.add_texts([text], metadatas=[metadata])
        embed_ms = (time.perf_counter() - t0) * 1000
        partition.documents.append({"text": text, "metadata": metadata})
        partition.version += 1
        partition.last_updated = datetime.now()
        partition.doc_count += 1
        return embed_ms

    # =========================================================================
    # INDEXING (Adding new data)
    # =========================================================================
    def add_conversation(
        self,
        suspect: str,
        question: str,
        answer: str,
        turn: int,
        metadata: Optional[Dict] = None
    ) -> bool:
        """Add a conversation to a suspect's partition.

        This is ISOLATED - only goes into this suspect's index.
        Invalidates cache for this suspect.
        """
        if not self.is_available:
            return False
        partition = self._get_suspect_partition(suspect)
        try:
            # Embed a truncated form; the full Q/A text is kept in metadata.
            doc_text = f"Q: {question[:200]}\nA: {answer[:500]}"
            doc_metadata = {
                "type": "conversation",
                "suspect": suspect,
                "turn": turn,
                "question": question,
                "answer": answer,
                "timestamp": datetime.now().isoformat(),
                **(metadata or {})
            }
            embed_ms = self._index_document(partition, doc_text, doc_metadata)
            self._invalidate_cache_for_partition(self._suspect_key(suspect))
            logger.info(
                "[MEMORY] Added conversation with %s (turn %d, v%d) - %.0fms",
                suspect, turn, partition.version, embed_ms
            )
            return True
        except Exception as e:
            logger.error("[MEMORY] Failed to add conversation: %s", e)
            return False

    def add_clue(
        self,
        clue_id: str,
        description: str,
        location: str,
        significance: str,
        turn: int,
        related_suspect: Optional[str] = None
    ) -> bool:
        """Add a clue to the shared clue partition."""
        if not self.is_available or not self._clue_partition:
            return False
        try:
            doc_text = f"Clue at {location}: {description}. Significance: {significance}"
            doc_metadata = {
                "type": "clue",
                "clue_id": clue_id,
                "description": description,
                "location": location,
                "significance": significance,
                "turn": turn,
                "related_suspect": related_suspect,
                "timestamp": datetime.now().isoformat()
            }
            embed_ms = self._index_document(self._clue_partition, doc_text, doc_metadata)
            self._invalidate_cache_for_partition("clues")
            logger.info(
                "[MEMORY] Added clue '%s' (v%d) - %.0fms",
                clue_id, self._clue_partition.version, embed_ms
            )
            return True
        except Exception as e:
            logger.error("[MEMORY] Failed to add clue: %s", e)
            return False

    def add_emotional_event(
        self,
        suspect: str,
        text: str,
        metadata: Dict,
        turn: int
    ) -> bool:
        """Add an emotional event to a suspect's partition.

        Emotional events are embedded alongside conversations
        but tagged with type='emotional_event' for filtering.
        """
        if not self.is_available:
            return False
        partition = self._get_suspect_partition(suspect)
        try:
            doc_metadata = {
                "type": "emotional_event",
                "suspect": suspect,
                "turn": turn,
                "timestamp": datetime.now().isoformat(),
                **metadata
            }
            embed_ms = self._index_document(partition, text, doc_metadata)
            self._invalidate_cache_for_partition(self._suspect_key(suspect))
            logger.info(
                "[MEMORY] Added emotional event for %s (turn %d, v%d) - %.0fms",
                suspect, turn, partition.version, embed_ms
            )
            return True
        except Exception as e:
            logger.error("[MEMORY] Failed to add emotional event: %s", e)
            return False

    # =========================================================================
    # RETRIEVAL (With caching and isolation)
    # =========================================================================
    def _cache_lookup(self, cache_key: str, partition: "IndexPartition") -> "Optional[List[Tuple[str, Dict]]]":
        """Return cached results if present AND current, else None."""
        cached = self._cache.get(cache_key)
        if cached is not None and cached.index_version == partition.version:
            return cached.results
        return None

    def search_suspect(
        self,
        suspect: str,
        query: str,
        k: int = 3,
        skip_cache: bool = False
    ) -> List[Tuple[str, Dict]]:
        """Search ONLY within a specific suspect's partition.

        GUARANTEED ISOLATION: This ONLY searches that suspect's index.
        Results are cached until new data is added for this suspect.
        Args:
            suspect: Suspect name (exact match)
            query: Search query
            k: Number of results
            skip_cache: Force fresh search
        Returns:
            List of (text, metadata) tuples - ALL from this suspect only
        """
        if not self.is_available:
            return []
        suspect_key = self._suspect_key(suspect)
        self._total_searches += 1
        # No partition means the suspect has never been indexed.
        if suspect_key not in self._suspect_partitions:
            logger.info("[MEMORY] No data for suspect: %s", suspect)
            return []
        partition = self._suspect_partitions[suspect_key]
        cache_key = self._cache_key(suspect_key, query)
        if not skip_cache:
            hit = self._cache_lookup(cache_key, partition)
            if hit is not None:
                self._cache_hits += 1
                logger.info("[MEMORY] Cache HIT for %s (v%d)", suspect, partition.version)
                return hit
        self._cache_misses += 1
        try:
            t0 = time.perf_counter()
            results = partition.vectorstore.similarity_search(query, k=k)
            search_ms = (time.perf_counter() - t0) * 1000
            # Drop the type='system' placeholder each partition is seeded with.
            result_list = [
                (r.page_content, r.metadata)
                for r in results
                if r.metadata.get("type") != "system"
            ]
            self._cache[cache_key] = CacheEntry(
                query_hash=cache_key,
                results=result_list,
                created_at=datetime.now(),
                index_version=partition.version
            )
            logger.info(
                "[MEMORY] search_suspect(%s) - %.0fms, %d results (cache MISS)",
                suspect, search_ms, len(result_list)
            )
            return result_list
        except Exception as e:
            logger.error("[MEMORY] Suspect search failed: %s", e)
            return []

    def search_clues(
        self,
        query: str,
        k: int = 3,
        skip_cache: bool = False
    ) -> List[Tuple[str, Dict]]:
        """Search the clue partition."""
        if not self.is_available or not self._clue_partition:
            return []
        partition = self._clue_partition
        # FIX: this counter was not incremented here, which inflated the
        # cache_hit_rate reported by get_stats().
        self._total_searches += 1
        cache_key = self._cache_key("clues", query)
        if not skip_cache:
            hit = self._cache_lookup(cache_key, partition)
            if hit is not None:
                self._cache_hits += 1
                return hit
        self._cache_misses += 1
        try:
            t0 = time.perf_counter()
            results = partition.vectorstore.similarity_search(query, k=k)
            search_ms = (time.perf_counter() - t0) * 1000
            # Drop the type='system' placeholder each partition is seeded with.
            result_list = [
                (r.page_content, r.metadata)
                for r in results
                if r.metadata.get("type") != "system"
            ]
            self._cache[cache_key] = CacheEntry(
                query_hash=cache_key,
                results=result_list,
                created_at=datetime.now(),
                index_version=partition.version
            )
            logger.info("[MEMORY] search_clues - %.0fms, %d results", search_ms, len(result_list))
            return result_list
        except Exception as e:
            logger.error("[MEMORY] Clue search failed: %s", e)
            return []

    def get_suspect_history(self, suspect: str) -> List[Dict]:
        """Get ALL conversations with a suspect (no semantic search).

        Uses document list directly - guaranteed complete and isolated.
        Returns conversation metadata dicts sorted by turn.
        """
        suspect_key = self._suspect_key(suspect)
        if suspect_key not in self._suspect_partitions:
            return []
        partition = self._suspect_partitions[suspect_key]
        convos = [
            doc["metadata"]
            for doc in partition.documents
            if doc["metadata"].get("type") == "conversation"
        ]
        return sorted(convos, key=lambda x: x.get("turn", 0))

    def has_new_data_since(self, suspect: str, since_version: int) -> bool:
        """Check if a suspect's partition has new data since a version.

        Use this to avoid unnecessary searches.
        """
        suspect_key = self._suspect_key(suspect)
        if suspect_key not in self._suspect_partitions:
            return False
        return self._suspect_partitions[suspect_key].version > since_version

    def get_partition_version(self, suspect: str) -> int:
        """Get current version of a suspect's partition (0 if none exists)."""
        suspect_key = self._suspect_key(suspect)
        if suspect_key not in self._suspect_partitions:
            return 0
        return self._suspect_partitions[suspect_key].version

    # =========================================================================
    # CROSS-REFERENCES (Explicit, controlled access)
    # =========================================================================
    def find_cross_references(
        self,
        about_suspect: str,
        k: int = 3
    ) -> List[Tuple[str, str, Dict]]:
        """Find what OTHER suspects said about a specific suspect.

        This explicitly searches OTHER partitions for mentions.
        Returns (speaker_name, text, metadata) tuples, at most k total.
        """
        if not self.is_available:
            return []
        about_key = self._suspect_key(about_suspect)
        results = []
        # Search each OTHER suspect's partition; the target's own
        # partition is skipped to keep self-statements out.
        for suspect_key, partition in self._suspect_partitions.items():
            if suspect_key == about_key:
                continue  # Skip self
            try:
                query = f"mentions {about_suspect} or about {about_suspect}"
                search_results = partition.vectorstore.similarity_search(query, k=k)
                for r in search_results:
                    # Semantic search can return near-misses; require a
                    # literal mention of the target suspect.
                    if about_suspect.lower() in r.page_content.lower():
                        results.append((
                            r.metadata.get("suspect", suspect_key),
                            r.page_content,
                            r.metadata
                        ))
            except Exception as e:
                logger.warning("[MEMORY] Cross-ref search failed for %s: %s", suspect_key, e)
        return results[:k]

    # =========================================================================
    # CONTRADICTION DETECTION
    # =========================================================================
    async def find_contradictions(
        self,
        suspect: str,
        new_statement: Optional[str] = None
    ) -> "ContradictionResult":
        """Analyze a suspect's statements for contradictions.

        If new_statement is provided, checks it against past statements.
        Otherwise, analyzes all past statements against each other.
        Fewer than two statements can't contradict; returns an empty result.
        """
        from .contradiction_detector import detect_contradictions
        history = self.get_suspect_history(suspect)
        if len(history) < 2:
            return ContradictionResult()
        return await detect_contradictions(history, new_statement)

    # =========================================================================
    # STATS AND DEBUGGING
    # =========================================================================
    def get_stats(self) -> Dict:
        """Get memory system statistics (counters, per-partition versions)."""
        return {
            "initialized": self._initialized,
            "suspect_partitions": len(self._suspect_partitions),
            "clue_count": self._clue_partition.doc_count if self._clue_partition else 0,
            "cache_entries": len(self._cache),
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "cache_hit_rate": self._cache_hits / max(1, self._total_searches),
            "partitions": {
                name: {
                    "version": p.version,
                    "doc_count": p.doc_count,
                    "last_updated": p.last_updated.isoformat() if p.last_updated else None
                }
                for name, p in self._suspect_partitions.items()
            }
        }

    def clear(self):
        """Clear all memory (for game reset). Requires re-initialize()."""
        self._suspect_partitions = {}
        self._clue_partition = None
        self._cache = {}
        self._cache_hits = 0
        self._cache_misses = 0
        self._total_searches = 0
        self._initialized = False
        logger.info("[MEMORY] Cleared all data")
@dataclass
class ContradictionResult:
    """Result of contradiction analysis.

    Returned by GameMemory.find_contradictions; a default instance
    (no contradictions) is used when fewer than two statements exist.
    """
    has_contradictions: bool = False  # True when at least one contradiction was found
    contradictions: List[Dict] = field(default_factory=list)
    # Each contradiction: {statement1, statement2, explanation}