#!/usr/bin/env python3
"""
LongMemEval Official Evaluation Harness for OMEGA.
Runs OMEGA against the real LongMemEval_S dataset (500 questions, ~40 sessions
each) from UCLA/Tencent (Wang et al., 2024). Produces hypothesis JSONL compatible
with the official grading script.
Every system on the LongMemEval leaderboard (Mastra 94.87%, Hindsight 91.4%,
Zep 71.2%) runs this exact protocol: ingest → retrieve → generate → GPT-4o judge.
Usage:
# Full run
python scripts/longmemeval_official.py --api-key $OPENAI_API_KEY
# Test with 5 questions
python scripts/longmemeval_official.py --questions 5 --verbose
# Retrieve-only (no GPT-4o cost)
python scripts/longmemeval_official.py --dry-run --verbose
# Grade existing hypothesis file
python scripts/longmemeval_official.py --grade-only longmemeval_hypothesis.jsonl
# Resume interrupted run
python scripts/longmemeval_official.py --resume
"""
import argparse
import json
import os
import re
import sys
import tempfile
import time
import urllib.request
from datetime import datetime, timedelta
from pathlib import Path
# Ensure omega is importable
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
# ── Feature toggles (set from CLI args in main()) ─────────────────────────
ENABLE_RERANK = False # Disabled — MS-MARCO cross-encoder hurts conversational memory retrieval
ENABLE_QUERY_EXPANSION = True
ENABLE_RECENCY_BOOST = True
ENABLE_COMPRESSION = False # Disabled by default — proven harmful in testing
# Provider registry: model prefix → config for OpenAI-compatible or Anthropic
PROVIDER_REGISTRY = {
"gpt-": {"base_url": None, "env_key": "OPENAI_API_KEY"},
"o1": {"base_url": None, "env_key": "OPENAI_API_KEY"},
"o3": {"base_url": None, "env_key": "OPENAI_API_KEY"},
"o4": {"base_url": None, "env_key": "OPENAI_API_KEY"},
"gemini-": {
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
"env_key": "GOOGLE_API_KEY",
},
"grok-": {"base_url": "https://api.x.ai/v1", "env_key": "XAI_API_KEY"},
"llama-": {"base_url": "https://api.groq.com/openai/v1", "env_key": "GROQ_API_KEY"},
"meta-llama/": {
"base_url": "https://api.groq.com/openai/v1",
"env_key": "GROQ_API_KEY",
},
"qwen/": {"base_url": "https://api.groq.com/openai/v1", "env_key": "GROQ_API_KEY"},
"claude-": {"provider": "anthropic", "env_key": "ANTHROPIC_API_KEY"},
}
def _resolve_provider(model: str) -> dict:
"""Look up provider config from model name prefix."""
for prefix, config in PROVIDER_REGISTRY.items():
if model.startswith(prefix):
return config
raise ValueError(
f"Unknown model '{model}'. Supported prefixes: "
+ ", ".join(PROVIDER_REGISTRY.keys())
)
def _call_llm(
messages: list[dict],
model: str,
max_tokens: int = 256,
temperature: float = 0,
api_key: str | None = None,
) -> str:
"""Call an LLM via the appropriate provider, return response text."""
config = _resolve_provider(model)
key = api_key or os.environ.get(config["env_key"])
if not key:
raise ValueError(
f"No API key for model '{model}'. "
f"Set ${config['env_key']} or pass --api-key."
)
if config.get("provider") == "anthropic":
import anthropic
client = anthropic.Anthropic(api_key=key)
# Convert OpenAI-style messages to Anthropic format
system = None
user_messages = []
for m in messages:
if m["role"] == "system":
system = m["content"]
else:
user_messages.append(m)
kwargs = {
"model": model,
"messages": user_messages,
"max_tokens": max_tokens,
"temperature": temperature,
}
if system:
kwargs["system"] = system
response = client.messages.create(**kwargs)
return response.content[0].text.strip()
else:
import openai
client = openai.OpenAI(base_url=config["base_url"], api_key=key)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
return response.choices[0].message.content.strip()
DATASET_URL = (
"https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned"
"/resolve/main/longmemeval_s_cleaned.json"
)
CACHE_DIR = Path.home() / ".omega" / "benchmarks"
CACHE_FILE = CACHE_DIR / "longmemeval_s_cleaned.json"
# ── Category-aware RAG prompts ──────────────────────────────────────────────
# Best prompt per category determined empirically across 4 benchmark runs.
# Vanilla: best for SS-assistant (100%), SS-user (97.1%), abstention (93.3%)
RAG_PROMPT_VANILLA = """\
I will give you several notes from past conversations between you and a user. \
Please answer the question based on the relevant notes. \
If the question cannot be answered based on the provided notes, say so.
Notes from past conversations:
{sessions}
Current Date: {question_date}
Question: {question}
Answer:"""
# Enhanced-confident: best for knowledge-update (87.2%), temporal (70.7%),
# preference (50.0%). Recency + aggregation + confidence.
RAG_PROMPT_ENHANCED = """\
I will give you several notes from past conversations between you and a user, \
ordered from oldest to newest. Please answer the question based on the relevant notes. \
If the question cannot be answered based on the provided notes, say so.
You MUST follow this process for EVERY question:
STEP 1 — Scan ALL notes for mentions of the queried topic. List every note that \
discusses it, with its note number and date.
STEP 2 — If the topic appears in multiple notes, compare the values. The note \
with the LATEST date is the ONLY correct one. Earlier values are SUPERSEDED and WRONG.
STEP 3 — Answer using ONLY the value from the latest note.
CRITICAL rules:
- Notes are in chronological order (oldest first). Higher note numbers are more recent.
- For questions about current state (e.g., "what is my current X?", "how many times \
have I done Y?"), the answer ALWAYS comes from the LAST note mentioning that topic.
- If a quantity changes across notes (e.g., worn 4 times → worn 6 times), the \
LATEST number replaces all earlier ones. Do NOT add or average them.
- If the question references a role, title, or name that does NOT exactly match \
what appears in the notes, say the information is not enough to answer.
- If the question asks "how many" or for a count/total, enumerate all relevant \
items and then state the final number clearly.
- Give a direct, concise answer. Do not hedge if the evidence is clear.
Notes from past conversations:
{sessions}
Current Date: {question_date}
Question: {question}
Answer:"""
# Enhanced-cautious: best for multi-session (69.9%). Recency + aggregation
# but no confidence push (avoids noise-induced wrong answers).
RAG_PROMPT_MULTISESSION = """\
I will give you several notes from past conversations between you and a user, \
ordered from oldest to newest. Please answer the question based on the relevant notes. \
If the question cannot be answered based on the provided notes, say so.
Important:
- Notes are in chronological order. When the same fact appears in multiple \
notes with different values, always use the value from the MOST RECENT note.
- If the question asks "how many", for a count, or for a total:
1. You MUST list EVERY matching item individually, citing its source as [Note #].
2. VERIFY each item: re-read the question and confirm each item EXACTLY matches \
what was asked. If the question asks about "types of citrus fruits", only count \
distinct fruit types the user actually used, not every mention of citrus. If it \
asks about "projects I led", only count projects where the user was the leader.
3. REMOVE items that don't strictly match the question's criteria. But NEVER dismiss \
something the USER claims they did (bought, attended, downloaded, etc.) just because \
the assistant questioned whether it's real. The user's statement is ground truth.
4. After filtering, count the remaining items and state the total clearly.
5. For "how much total" questions: list each amount with its source [Note #], \
then sum them and state the total.
- When the same fact is UPDATED in a later note (e.g., a number changes from X to Y), \
use ONLY the latest value. The earlier value is superseded.
- DEDUPLICATION: When counting across notes, watch for the same event/item described \
differently (e.g., "cousin's wedding" and "Rachel's wedding at a vineyard" may be the \
same event). If two items could be the same, count them as ONE. Err on the side of \
merging duplicates rather than double-counting.
- For questions about an "increase", "decrease", or "change" in a quantity: you MUST find \
BOTH the starting value AND the ending value, then compute the DIFFERENCE. Do NOT report \
the final total as the increase. Example: if followers went from 250 to 350, the increase is 100.
- Do NOT skip notes. Scan every note for potential matches before answering.
- Give a direct, concise answer. Do not hedge if the evidence is clear.
- NEVER guess, estimate, or calculate values that are not explicitly stated in the notes. \
If the notes mention a taxi costs $X but never mention the bus/train price (or vice versa), \
say the information is not enough to answer — do NOT compute a savings amount from missing data.
Notes from past conversations:
{sessions}
Current Date: {question_date}
Question: {question}
Answer:"""
# Preference: best for single-session-preference. Focuses on personal info recall.
RAG_PROMPT_PREFERENCE = """\
I will give you several notes from past conversations between you and a user. \
Please answer the question based on the user's stated preferences, habits, and \
personal information found in these notes. \
If the question cannot be answered based on the provided notes, say so.
Important:
- Focus on what the user explicitly said about their preferences, likes, dislikes, \
habits, routines, and personal details.
- When the same preference appears in multiple notes with different values, always \
use the value from the MOST RECENT note (higher note number = more recent).
- If the question asks for a recommendation or suggestion, USE the user's stated \
preferences to tailor your response. Do NOT say you lack information if the notes \
contain ANY relevant preferences, interests, or habits — apply them creatively.
- Even if the notes don't mention the exact topic, look for RELATED preferences \
(e.g., if asked about hotels, use stated preferences about views, amenities, \
luxury vs budget, or location preferences from ANY context).
- When the user mentions a place, activity, or event, ALWAYS check if the notes \
contain a SPECIFIC PAST EXPERIENCE with that place/activity. If so, reference it \
directly (e.g., "You mentioned enjoying X when you visited Denver before" or \
"Given your experience with Y in high school").
- Your answer MUST reference at least one specific detail from the notes. Generic \
advice that could apply to anyone is WRONG. The answer should be clearly \
personalized — someone reading it should be able to tell it was written for this \
specific user.
- Give a direct, specific answer. Quote the user's own words when possible.
Notes from past conversations:
{sessions}
Current Date: {question_date}
Question: {question}
Answer:"""
# Temporal: best for temporal-reasoning. Date-focused instructions for counting
# durations and ordering events chronologically.
RAG_PROMPT_TEMPORAL = """\
I will give you several notes from past conversations between you and a user, \
ordered from oldest to newest. Each note has a date stamp. Please answer the \
question based on the relevant notes. \
If the question cannot be answered based on the provided notes, say so.
You MUST follow these steps for ALL time-based questions:
STEP 1 — Convert every relative date to an ABSOLUTE date:
For each event mentioned in the notes, write its absolute date. Convert ALL \
relative references using the note's own date stamp:
- "last Saturday" = the most recent Saturday BEFORE the note's date
- "yesterday" = the day before the note's date
- "two weeks ago" = 14 days before the note's date
- "last month" = the calendar month before the note's date
- "next Friday" = the first Friday AFTER the note's date
STEP 2 — Find ALL candidate events, not just the first match:
When the question asks about something at a specific time (e.g., "two weeks ago", \
"last Saturday"), scan ALL notes and list every event that could match both the \
time reference AND the event description. Do NOT stop at the first event near \
the target date.
STEP 3 — Select the best match by verifying BOTH date AND description:
- The event must match the question's description (e.g., "art event", "business \
milestone", "life event of a relative"). A nearby event of the wrong type is wrong.
- Among events matching the description, pick the one closest to the exact \
target date. Prefer events within ±2 days; only consider ±3-7 days if no closer match exists.
- If a note says "I went to X last week" and the note is dated near the target, \
resolve "last week" to find the EXACT event date, not the note date.
STEP 4 — Compute the answer using ONLY the absolute dates:
- For "how many days/weeks/months between X and Y": subtract the two absolute \
dates and convert to the requested unit.
- For ordering questions: list each event with its absolute date, then sort by \
date (earliest first).
- For "how many times" or counting: enumerate each matching event with its \
absolute date, then state the total count.
- For "when" questions: state the absolute date directly.
CRITICAL rules:
- RECOLLECTION ≠ ACTION: When a note says "I was thinking about X", "I remembered X", \
or "I was reminiscing about X", the event X did NOT happen on that note's date. \
The note's date is when the user RECALLED the event, not when it occurred. \
Only use notes where the user describes PERFORMING an action to date that action.
- Notes are in chronological order. When the same fact appears in multiple \
notes with different values, always use the value from the MOST RECENT note.
- Give a direct, concise answer. Do not hedge if the evidence is clear.
- Show your date arithmetic briefly before giving the final answer.
- If you can infer the answer by combining information across multiple notes, DO SO. \
Do not refuse to answer simply because no single note contains the complete answer.
- When a relative time reference (e.g., "last Saturday", "two weeks ago") appears \
in a note, ALWAYS resolve it to an absolute date using that note's date stamp \
before comparing to the question date.
- BEFORE saying "not enough information": re-read every note looking for SYNONYMS \
or INDIRECT references. "Investment for a competition" could be "bought tools for \
a contest." "Kitchen appliance" could be "smoker" or "grill." "Piece of jewelry" \
could be "ring" or "necklace." Try harder to match before abstaining.
Notes from past conversations:
{sessions}
Current Date: {question_date}
Question: {question}
Answer:"""
# Category → prompt mapping
_CATEGORY_PROMPT = {
"single-session-assistant": RAG_PROMPT_VANILLA,
"single-session-user": RAG_PROMPT_VANILLA,
"knowledge-update": RAG_PROMPT_ENHANCED,
"multi-session": RAG_PROMPT_MULTISESSION,
"temporal-reasoning": RAG_PROMPT_TEMPORAL,
"single-session-preference": RAG_PROMPT_PREFERENCE,
}
# Category → filter/generation config
_CATEGORY_CONFIG = {
# Vanilla categories: tighter filter (fewer results = less noise)
"single-session-assistant": {"min_rel": 0.15, "min_res": 2, "max_res": 10, "max_tokens": 512},
"single-session-user": {"min_rel": 0.12, "min_res": 3, "max_res": 12, "max_tokens": 512},
# Enhanced categories: adaptive filter, more tokens
"knowledge-update": {"min_rel": 0.15, "min_res": 3, "max_res": 15, "max_tokens": 2048},
"multi-session": {"min_rel": 0.08, "min_res": 4, "max_res": 20, "max_tokens": 2048},
"temporal-reasoning": {"min_rel": 0.10, "min_res": 5, "max_res": 20, "max_tokens": 2048},
"single-session-preference": {"min_rel": 0.12, "min_res": 3, "max_res": 10, "max_tokens": 2048},
}
_DEFAULT_CONFIG = {"min_rel": 0.15, "min_res": 3, "max_res": 10, "max_tokens": 512}
# ---------- Grading prompt templates (from official evaluate_qa.py) ----------
GRADE_PROMPTS = {
"default": """\
I will give you a question, a correct answer, and a response from a model. \
Please answer yes if the response contains the correct answer. Otherwise, answer no. \
If the response is equivalent to the correct answer or contains all the intermediate \
steps to get the correct answer, you should also answer yes. If the response only \
contains a subset of the information required by the answer, answer no.
Question: {question}
Correct Answer: {answer}
Model Response: {hypothesis}
Is the model response correct? Answer yes or no only.""",
"temporal-reasoning": """\
I will give you a question, a correct answer, and a response from a model. \
Please answer yes if the response contains the correct answer. Otherwise, answer no. \
If the response is equivalent to the correct answer or contains all the intermediate \
steps to get the correct answer, you should also answer yes. If the response only \
contains a subset of the information required by the answer, answer no. \
In addition, do not penalize off-by-one errors for the number of days. If the question \
asks for the number of days/weeks/months, etc., and the model makes off-by-one errors \
(e.g., predicting 19 days when the answer is 18), the model's response is still correct.
Question: {question}
Correct Answer: {answer}
Model Response: {hypothesis}
Is the model response correct? Answer yes or no only.""",
"knowledge-update": """\
I will give you a question, a correct answer, and a response from a model. \
Please answer yes if the response contains the correct answer. Otherwise, answer no. \
If the response contains some previous information along with an updated answer, the \
response should be considered as correct as long as the updated answer is the required answer.
Question: {question}
Correct Answer: {answer}
Model Response: {hypothesis}
Is the model response correct? Answer yes or no only.""",
"single-session-preference": """\
I will give you a question, a rubric for desired personalized response, and a response \
from a model. Please answer yes if the response satisfies the desired response. \
Otherwise, answer no. The model does not need to reflect all the points in the rubric. \
The response is correct as long as it recalls and utilizes the user's personal \
information correctly.
Question: {question}
Rubric: {answer}
Model Response: {hypothesis}
Is the model response correct? Answer yes or no only.""",
"abstention": """\
I will give you an unanswerable question, an explanation, and a response from a model. \
Please answer yes if the model correctly identifies the question as unanswerable. \
The model could say that the information is incomplete, or some other information is \
given but the asked information is not.
Question: {question}
Explanation: {answer}
Model Response: {hypothesis}
Does the model correctly identify the question as unanswerable? Answer yes or no only.""",
}
# ──────────────────────────── Dataset helpers ────────────────────────────────
def download_dataset() -> list:
"""Download LongMemEval_S dataset from HuggingFace, caching locally."""
CACHE_DIR.mkdir(parents=True, exist_ok=True)
if CACHE_FILE.exists():
print(f" Using cached dataset: {CACHE_FILE}")
with open(CACHE_FILE) as f:
return json.load(f)
print(" Downloading dataset from HuggingFace...")
urllib.request.urlretrieve(DATASET_URL, CACHE_FILE)
print(f" Cached at: {CACHE_FILE}")
with open(CACHE_FILE) as f:
return json.load(f)
def parse_longmemeval_date(date_str: str) -> str:
"""Parse LongMemEval date 'YYYY/MM/DD (Day) HH:MM' → ISO format."""
cleaned = re.sub(r"\s*\([A-Za-z]+\)\s*", " ", date_str).strip()
try:
dt = datetime.strptime(cleaned, "%Y/%m/%d %H:%M")
return dt.isoformat()
except ValueError:
return date_str
def format_session_text(turns: list) -> str:
"""Format a session's turns as natural language for OMEGA ingestion."""
lines = []
for turn in turns:
lines.append(f"{turn['role']}: {turn['content']}")
return "\n".join(lines)
def format_session_for_prompt(content: str, date_str: str, index: int) -> str:
"""Format a retrieved session as a Chain-of-Note block with source attribution."""
return (
f"[Note {index} | Date: {date_str}]\n"
f"{content}\n"
f"[End Note {index}]"
)
def answer_to_str(answer) -> str:
"""Convert mixed-type answer field to string."""
if isinstance(answer, list):
return ", ".join(str(a) for a in answer)
return str(answer)
# ──────────────────────────── Core pipeline ──────────────────────────────────
_FACT_EXTRACTION_PROMPT = """\
Extract 2-3 key facts from this conversation. Preserve exact dates, names, \
numbers, preferences, and opinions. Output as numbered list. Be brief.
Conversation:
{content}
Key facts:"""
def _extract_key_facts(content: str, api_key: str | None = None) -> str | None:
"""Extract key facts from session content using a cheap LLM.
Returns a string like "[Key Facts: 1) ... 2) ... 3) ...]" or None on failure.
"""
try:
facts = _call_llm(
messages=[{"role": "user", "content": _FACT_EXTRACTION_PROMPT.format(
content=content[:3000]
)}],
model="gpt-4.1-mini",
max_tokens=200,
temperature=0,
api_key=api_key,
)
return f"\n[Key Facts: {facts.strip()}]"
except Exception:
return None
def ingest_question(question_data: dict, tmpdir: str, api_key: str | None = None,
extract_facts: bool = False):
"""Create a fresh SQLiteStore and ingest all haystack sessions for one question.
If extract_facts is True, appends LLM-extracted key facts to each session
for better retrieval (fact-augmented keys, a la Hindsight TEMPR).
"""
from omega.sqlite_store import SQLiteStore
db_path = os.path.join(tmpdir, "bench.db")
os.environ["OMEGA_HOME"] = tmpdir
store = SQLiteStore(db_path=db_path)
sessions = question_data["haystack_sessions"]
session_ids = question_data["haystack_session_ids"]
dates = question_data["haystack_dates"]
for session_turns, sid, date_str in zip(sessions, session_ids, dates):
content = format_session_text(session_turns)
iso_date = parse_longmemeval_date(date_str)
# Fact-augmented keys: append extracted facts for better embedding/retrieval
if extract_facts:
facts = _extract_key_facts(content, api_key=api_key)
if facts:
content = content + facts
store.store(
content=content,
session_id=sid,
metadata={
"event_type": "session_summary",
"referenced_date": iso_date,
"priority": 3,
},
skip_inference=True,
)
return store
def _infer_temporal_range_anchored(query_text: str, anchor_date: str) -> tuple | None:
"""Infer a temporal range from relative time references in the question.
Resolves phrases like "N weeks/months/days ago" against the question's
anchor_date (not datetime.now()), returning an ISO (start, end) tuple
or None if no temporal signal is detected.
"""
# Parse the anchor date from LongMemEval format "YYYY/MM/DD (Day) HH:MM"
cleaned = re.sub(r"\s*\([A-Za-z]+\)\s*", " ", anchor_date).strip()
try:
anchor = datetime.strptime(cleaned, "%Y/%m/%d %H:%M")
except ValueError:
try:
anchor = datetime.fromisoformat(anchor_date)
except ValueError:
return None
# Word-to-number mapping for "four weeks ago" etc.
_WORD_TO_NUM = {
"one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
"six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
"eleven": 11, "twelve": 12, "thirteen": 13, "fourteen": 14,
"fifteen": 15, "twenty": 20, "thirty": 30,
}
# Pattern: "last (Monday|Tuesday|...|Sunday|weekend)"
day_match = re.search(
r"last\s+(Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday|weekend)",
query_text, re.IGNORECASE,
)
if day_match:
day_name = day_match.group(1).capitalize()
if day_name == "Weekend":
# "last weekend" → most recent Saturday before anchor
target_weekday = 5 # Saturday
else:
_DAY_MAP = {
"Monday": 0, "Tuesday": 1, "Wednesday": 2, "Thursday": 3,
"Friday": 4, "Saturday": 5, "Sunday": 6,
}
target_weekday = _DAY_MAP[day_name]
# Find the most recent target_weekday before anchor
days_back = (anchor.weekday() - target_weekday) % 7
if days_back == 0:
days_back = 7 # "last X" means the previous one, not today
target_date = anchor - timedelta(days=days_back)
start = (target_date - timedelta(days=2)).isoformat()
end = (target_date + timedelta(days=2)).isoformat()
return (start, end)
# Pattern: "N days/weeks/months/years ago" (digits or words)
m = re.search(r"(\d+|[a-z]+)\s+(day|week|month|year)s?\s+ago", query_text, re.IGNORECASE)
if m:
raw_n = m.group(1).lower()
if raw_n.isdigit():
n = int(raw_n)
elif raw_n in _WORD_TO_NUM:
n = _WORD_TO_NUM[raw_n]
else:
n = None
if n is not None:
unit = m.group(2).lower()
if unit == "day":
delta = timedelta(days=n)
elif unit == "week":
delta = timedelta(weeks=n)
elif unit == "month":
delta = timedelta(days=n * 30)
elif unit == "year":
delta = timedelta(days=n * 365)
else:
delta = None
if delta is not None:
# Window: from (anchor - delta - buffer) to (anchor - delta + buffer)
center = anchor - delta
buffer = max(delta * 0.25, timedelta(days=3))
start = (center - buffer).isoformat()
end = (center + buffer).isoformat()
return (start, end)
# Pattern: "between DATE and DATE" or "from DATE to DATE"
m = re.search(
r"(?:between|from)\s+(\d{4}[/-]\d{1,2}[/-]\d{1,2})\s+(?:and|to)\s+(\d{4}[/-]\d{1,2}[/-]\d{1,2})",
query_text,
re.IGNORECASE,
)
if m:
try:
d1 = datetime.strptime(m.group(1).replace("/", "-"), "%Y-%m-%d")
d2 = datetime.strptime(m.group(2).replace("/", "-"), "%Y-%m-%d")
start = (min(d1, d2) - timedelta(days=1)).isoformat()
end = (max(d1, d2) + timedelta(days=1)).isoformat()
return (start, end)
except ValueError:
pass
# Pattern: "last N days/weeks/months" (digits or words)
m = re.search(r"(?:last|past|previous)\s+(\d+|[a-z]+)\s+(day|week|month|year)s?", query_text, re.IGNORECASE)
if m:
raw_n = m.group(1).lower()
if raw_n.isdigit():
n = int(raw_n)
else:
n = _WORD_TO_NUM.get(raw_n)
if n is None:
return None
unit = m.group(2).lower()
if unit == "day":
delta = timedelta(days=n)
elif unit == "week":
delta = timedelta(weeks=n)
elif unit == "month":
delta = timedelta(days=n * 30)
elif unit == "year":
delta = timedelta(days=n * 365)
else:
return None
start = (anchor - delta - timedelta(days=1)).isoformat()
end = anchor.isoformat()
return (start, end)
# Pattern: specific month/year like "in January 2024" or "in 2024"
m = re.search(r"in\s+(January|February|March|April|May|June|July|August|September|October|November|December)\s+(\d{4})", query_text, re.IGNORECASE)
if m:
month_name = m.group(1)
year = int(m.group(2))
month_num = datetime.strptime(month_name, "%B").month
start = datetime(year, month_num, 1) - timedelta(days=1)
if month_num == 12:
end = datetime(year + 1, 1, 1) + timedelta(days=1)
else:
end = datetime(year, month_num + 1, 1) + timedelta(days=1)
return (start.isoformat(), end.isoformat())
return None
def _resolve_relative_dates(query: str, anchor: datetime) -> list[str]:
"""Resolve ALL relative time references in a query to absolute date keywords.
Returns a list of date-keyword strings to append to the query for better
semantic matching. Covers every pattern that _infer_temporal_range_anchored
handles, so the query text benefits from the same date resolution.
"""
q_lower = query.lower()
resolved = []
_WORD_TO_NUM = {
"one": 1, "a": 1, "two": 2, "three": 3, "four": 4, "five": 5,
"six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
"eleven": 11, "twelve": 12, "thirteen": 13, "fourteen": 14,
"fifteen": 15, "twenty": 20, "thirty": 30,
}
_DAY_NAMES = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
_MONTH_NAMES = [
"January", "February", "March", "April", "May", "June",
"July", "August", "September", "October", "November", "December",
]
# 1. "last (Monday|Tuesday|...|Sunday)" → resolved absolute date + day name
day_match = re.search(
r"last\s+(Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday)",
query, re.IGNORECASE,
)
if day_match:
day_name = day_match.group(1).capitalize()
_DAY_MAP = {d: i for i, d in enumerate(_DAY_NAMES)}
target_weekday = _DAY_MAP[day_name]
days_back = (anchor.weekday() - target_weekday) % 7
if days_back == 0:
days_back = 7
target_date = anchor - timedelta(days=days_back)
resolved.append(f"{day_name} {target_date.strftime('%Y-%m-%d')} {target_date.strftime('%B %d')}")
# 2. "last weekend" → Saturday + Sunday dates
if "last weekend" in q_lower:
sat = anchor - timedelta(days=(anchor.weekday() + 2) % 7 or 7)
sun = sat + timedelta(days=1)
resolved.append(f"Saturday Sunday {sat.strftime('%Y-%m-%d')} {sun.strftime('%Y-%m-%d')}")
# 3. "yesterday" → absolute date
if "yesterday" in q_lower:
yest = anchor - timedelta(days=1)
resolved.append(f"{yest.strftime('%Y-%m-%d')} {yest.strftime('%B %d')}")
# 4. "last week" → date range
if "last week" in q_lower and "weekend" not in q_lower:
start = anchor - timedelta(days=anchor.weekday() + 7)
end = start + timedelta(days=6)
resolved.append(f"{start.strftime('%Y-%m-%d')} {end.strftime('%Y-%m-%d')}")
# 5. "N days/weeks/months/years ago" → resolved center date + month name
m = re.search(r"(\d+|[a-z]+)\s+(day|week|month|year)s?\s+ago", query, re.IGNORECASE)
if m:
raw_n = m.group(1).lower()
if raw_n.isdigit():
n = int(raw_n)
else:
n = _WORD_TO_NUM.get(raw_n)
if n is not None:
unit = m.group(2).lower()
if unit == "day":
delta = timedelta(days=n)
elif unit == "week":
delta = timedelta(weeks=n)
elif unit == "month":
delta = timedelta(days=n * 30)
elif unit == "year":
delta = timedelta(days=n * 365)
else:
delta = None
if delta:
center = anchor - delta
resolved.append(
f"{center.strftime('%Y-%m-%d')} {center.strftime('%B')} {center.strftime('%d')}"
)
# 6. "last N days/weeks/months" or "past N months" → date range
m = re.search(r"(?:last|past|previous)\s+(\d+|[a-z]+)\s+(day|week|month|year)s?", query, re.IGNORECASE)
if m:
raw_n = m.group(1).lower()
if raw_n.isdigit():
n = int(raw_n)
else:
n = _WORD_TO_NUM.get(raw_n)
if n is not None:
unit = m.group(2).lower()
if unit == "day":
delta = timedelta(days=n)
elif unit == "week":
delta = timedelta(weeks=n)
elif unit == "month":
delta = timedelta(days=n * 30)
elif unit == "year":
delta = timedelta(days=n * 365)
else:
delta = None
if delta:
start = anchor - delta
resolved.append(
f"{start.strftime('%Y-%m-%d')} {anchor.strftime('%Y-%m-%d')} "
f"{start.strftime('%B')} {anchor.strftime('%B')}"
)
# 7. "in [Month]" (without year) → resolve to month within recent context
m = re.search(
r"in\s+(January|February|March|April|May|June|July|August|September|October|November|December)\b",
query, re.IGNORECASE,
)
if m and not re.search(r"in\s+" + m.group(1) + r"\s+\d{4}", query, re.IGNORECASE):
month_name = m.group(1).capitalize()
month_num = _MONTH_NAMES.index(month_name) + 1
# Assume the most recent occurrence of that month before/at the anchor
if month_num <= anchor.month:
year = anchor.year
else:
year = anchor.year - 1
resolved.append(f"{month_name} {year} {year}-{month_num:02d}")
# 8. "in [Month] [Year]" → explicit month+year
m = re.search(
r"in\s+(January|February|March|April|May|June|July|August|September|October|November|December)\s+(\d{4})",
query, re.IGNORECASE,
)
if m:
month_name = m.group(1).capitalize()
year = int(m.group(2))
month_num = _MONTH_NAMES.index(month_name) + 1
resolved.append(f"{month_name} {year} {year}-{month_num:02d}")
return resolved
def _expand_query(query: str, question_date: str | None = None) -> str:
"""Expand a query with temporal, entity, and counting signals for better retrieval.
- Temporal: resolves ALL relative dates to absolute date keywords
- Entity: extracts proper nouns as explicit search terms
- Counting: adds enumeration cues for aggregation questions
"""
expansions = []
# 1. Counting signals
q_lower = query.lower()
if any(sig in q_lower for sig in ("how many", "how much", "how often", "total number", "count")):
expansions.append("every instance all occurrences each time")
# 2. Temporal keyword expansion — resolve ALL relative dates
if question_date:
cleaned = re.sub(r"\s*\([A-Za-z]+\)\s*", " ", question_date).strip()
try:
anchor = datetime.strptime(cleaned, "%Y/%m/%d %H:%M")
except ValueError:
try:
anchor = datetime.fromisoformat(question_date)
except ValueError:
anchor = None
if anchor:
resolved = _resolve_relative_dates(query, anchor)
expansions.extend(resolved)
# 3. Entity extraction: proper nouns (capitalized words not at sentence start)
_COMMON = {
"I", "The", "A", "An", "My", "What", "When", "Where", "Who", "How",
"Which", "Why", "Do", "Does", "Did", "Is", "Are", "Was", "Were",
"Have", "Has", "Had", "Can", "Could", "Would", "Should", "Will",
"If", "In", "On", "At", "To", "For", "Of", "And", "Or", "But",
"Not", "That", "This", "It", "He", "She", "They", "We", "You",
"Please", "Tell", "Me", "About",
}
words = re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", query)
entities = [w for w in words if w not in _COMMON and len(w) > 1]
if entities:
expansions.append(" ".join(entities))
if not expansions:
return query
return query + " " + " ".join(expansions)
_QUERY_AUGMENT_PROMPT = """\
Given this question from a user about their personal history, generate 2-3 alternative \
search queries that could find the answer in their conversation logs. Focus on SYNONYMS \
and RELATED TERMS that the user might have used instead of the exact words in the question.
Question: {question}
Output ONLY the alternative search queries, one per line. Keep each under 15 words."""
ENABLE_QUERY_AUGMENT = False # Controlled by --query-augment flag
def _augment_query(question: str, api_key: str | None = None) -> list[str]:
"""Use a cheap LLM to generate alternative search queries."""
if not ENABLE_QUERY_AUGMENT:
return []
try:
response = _call_llm(
messages=[{"role": "user", "content": _QUERY_AUGMENT_PROMPT.format(question=question)}],
model="gpt-4.1-mini",
max_tokens=100,
temperature=0.3,
api_key=api_key,
)
return [line.strip() for line in response.strip().split("\n") if line.strip()]
except Exception:
return []
def _boost_recency(results: list) -> list:
"""Boost relevance scores by recency for knowledge-update questions.
Newer sessions get a multiplicative boost so the most recent fact about
a topic ranks higher than older (potentially outdated) mentions.
"""
# Collect all dates to establish range
dates = []
for r in results:
if r.metadata:
d = r.metadata.get("referenced_date", "")
if d:
dates.append(d)
if not dates:
return results
# Sort dates to find range
dates.sort()
try:
earliest = datetime.fromisoformat(dates[0])
latest = datetime.fromisoformat(dates[-1])
except (ValueError, IndexError):
return results
span = (latest - earliest).total_seconds()
if span <= 0:
return results
for r in results:
d = (r.metadata or {}).get("referenced_date", "")
if d:
try:
t = datetime.fromisoformat(d)
# Recency factor: 1.0 (oldest) to 1.5 (newest)
# Increased from 0.3 to 0.5 — KU failures show the latest note
# often ranks below older notes that have higher semantic match.
frac = (t - earliest).total_seconds() / span
r.relevance = (r.relevance or 0) * (1.0 + 0.5 * frac)
except ValueError:
pass
# Re-sort by boosted relevance (highest first)
results.sort(key=lambda r: r.relevance or 0, reverse=True)
return results
def retrieve_context(
store,
question: str,
limit: int = 5,
question_date: str | None = None,
question_type: str | None = None,
api_key: str | None = None,
) -> list:
"""Retrieve top-K sessions for a question using OMEGA's hybrid search.
Infers a temporal range from relative time references in the question
text (for ALL question types, not just temporal-reasoning) and passes
it to the store for date-aware scoring. A secondary unfiltered retrieval
is merged to avoid missing context outside the window.
"""
# Expand query with temporal/entity/counting signals for better retrieval
expanded = _expand_query(question, question_date) if ENABLE_QUERY_EXPANSION else question
# Temporal range filtering — apply to ALL categories when a temporal
# signal is detected, not just temporal-reasoning. Multi-session questions
# like "cashback last Thursday" or "activities in December" also benefit.
temporal_range = None
if question_date:
temporal_range = _infer_temporal_range_anchored(question, question_date)
if temporal_range:
# Primary: temporal-filtered retrieval with expanded query
primary = store.query(
expanded, limit=limit, include_infrastructure=True,
temporal_range=temporal_range, query_hint=question_type,
)
# Secondary: unfiltered retrieval to catch context outside the window
secondary = store.query(
expanded, limit=limit, include_infrastructure=True,
query_hint=question_type,
)
# Tertiary: original question without expansion (catches different
# semantic matches that the date-keyword expansion might dilute)
tertiary = []
if expanded != question:
tertiary = store.query(
question, limit=limit, include_infrastructure=True,
query_hint=question_type,
)
# Merge: primary results take precedence, then fill from secondary + tertiary
seen_ids = {r.id for r in primary}
merged = list(primary)
for r in secondary + tertiary:
if r.id not in seen_ids:
merged.append(r)
seen_ids.add(r.id)
# Knowledge-update: still boost recency within temporal results
if ENABLE_RECENCY_BOOST and question_type == "knowledge-update" and merged:
merged = _boost_recency(merged)
return merged
results = store.query(
expanded, limit=limit, include_infrastructure=True,
query_hint=question_type,
)
# Multi-query: also search with original query if expanded differs
if expanded != question:
alt_results = store.query(
question, limit=limit, include_infrastructure=True,
query_hint=question_type,
)
seen_ids = {r.id for r in results}
for r in alt_results:
if r.id not in seen_ids:
results.append(r)
seen_ids.add(r.id)
# LLM-augmented query: run additional queries from semantic expansion
if ENABLE_QUERY_AUGMENT:
aug_queries = _augment_query(question, api_key=api_key)
seen_ids = {r.id for r in results}
for aq in aug_queries[:3]: # Max 3 augmented queries
aug_results = store.query(
aq, limit=limit // 2, include_infrastructure=True,
query_hint=question_type,
)
for r in aug_results:
if r.id not in seen_ids:
results.append(r)
seen_ids.add(r.id)
# Knowledge-update: boost recency so newest facts rank higher
if ENABLE_RECENCY_BOOST and question_type == "knowledge-update" and results:
results = _boost_recency(results)
return results
# ── Cross-encoder reranking ────────────────────────────────────────────────
_RERANKER = None
def _rerank_results(query: str, results: list, top_k: int = 20) -> list:
"""Rerank retrieval results using a cross-encoder for sharper relevance.
Uses ms-marco-MiniLM-L-6-v2 (same as Hindsight/Emergence).
Falls back to original order if sentence-transformers is unavailable.
"""
if not results:
return results
global _RERANKER
try:
if _RERANKER is None:
from sentence_transformers import CrossEncoder
_RERANKER = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairs = [(query, r.content[:512]) for r in results]
scores = _RERANKER.predict(pairs)
ranked = sorted(zip(results, scores), key=lambda x: x[1], reverse=True)
return [r for r, _s in ranked[:top_k]]
except ImportError:
print(" WARNING: sentence-transformers not installed, skipping reranking")
return results
def _filter_and_sort_results(retrieved: list, min_relevance: float = 0.15,
min_results: int = 3, max_results: int = 10) -> list:
"""Adaptive filtering: keep high-relevance results, with floor and ceiling.
Retrieves K=15 for recall, then trims noise before passing to the LLM.
Sessions are sorted chronologically (oldest first) so the recency prompt works.
"""
# Filter by relevance floor
strong = [r for r in retrieved if (r.relevance or 0) >= min_relevance]
# Ensure minimum coverage (fall back to top-N by score)
if len(strong) < min_results:
strong = sorted(retrieved, key=lambda r: r.relevance or 0, reverse=True)[:min_results]
# Cap to avoid overwhelming the LLM
if len(strong) > max_results:
strong = sorted(strong, key=lambda r: r.relevance or 0, reverse=True)[:max_results]
# Sort chronologically (oldest first) so LLM sees updates in order
def _date_key(r):
if r.metadata:
return r.metadata.get("referenced_date", "") or ""
return ""
strong.sort(key=_date_key)
return strong
# ── Session compression (Mastra-inspired) ─────────────────────────────────
_COMPRESSION_PROMPT = """\
Extract the 2-3 most important facts from this conversation snippet. \
Preserve exact dates, names, numbers, quantities, preferences, and opinions. \
Output as concise bullet points. Do NOT add information not present in the text.
Conversation:
{content}
Key facts:"""
# Token threshold: only compress when total session text is long enough to benefit
_COMPRESS_TOKEN_THRESHOLD = 4000
# Categories that benefit from compression (long contexts with noise)
_COMPRESS_CATEGORIES = {"multi-session", "knowledge-update", "temporal-reasoning", "single-session-preference"}
def _compress_sessions(
session_blocks: list[str],
question_type: str,
question_text: str = "",
model: str = "gpt-4o",
api_key: str | None = None,
) -> list[str]:
"""Compress session blocks into key facts using a cheap LLM call.
Only compresses when: (1) category benefits from it, (2) total text is long,
(3) question is NOT a counting/aggregation query (those need ALL details).
Uses gpt-4.1-mini for cost efficiency.
"""
if question_type not in _COMPRESS_CATEGORIES:
return session_blocks
# Skip compression for counting/aggregation questions — they need exhaustive detail
q_lower = question_text.lower()
if any(sig in q_lower for sig in (
"how many", "how much", "how often", "total number", "count",
"number of", "list all", "list every", "name all", "name every",
)):
return session_blocks
total_chars = sum(len(b) for b in session_blocks)
# ~4 chars per token rough estimate
if total_chars / 4 < _COMPRESS_TOKEN_THRESHOLD:
return session_blocks
# Use a cheap model for compression
compress_model = "gpt-4.1-mini"
compressed = []
for block in session_blocks:
# Extract the date header and content
lines = block.split("\n")
header = lines[0] if lines else "" # [Note N | Date: ...]
footer = lines[-1] if lines else "" # [End Note N]
content = "\n".join(lines[1:-1]) if len(lines) > 2 else block
try:
facts = _call_llm(
messages=[{"role": "user", "content": _COMPRESSION_PROMPT.format(content=content[:3000])}],
model=compress_model,
max_tokens=256,
temperature=0,
api_key=api_key,
)
compressed.append(f"{header}\n{facts}\n{footer}")
except Exception:
# On any error, keep original block
compressed.append(block)
return compressed
def generate_answer(
question_data: dict,
retrieved: list,
model: str = "gpt-4o",
api_key: str | None = None,
) -> str:
"""Generate answer using category-aware prompts and filtering."""
qtype = question_data["question_type"]
qid = question_data["question_id"]
is_abstention = qid.endswith("_abs")
# Select prompt: abstention uses vanilla (best at 93.3%)
if is_abstention:
rag_prompt = RAG_PROMPT_VANILLA
else:
rag_prompt = _CATEGORY_PROMPT.get(qtype, RAG_PROMPT_MULTISESSION)
# Select filter config: abstention uses vanilla-style tight filter
if is_abstention:
cfg = {"min_rel": 0.20, "min_res": 2, "max_res": 5, "max_tokens": 256}
else:
cfg = _CATEGORY_CONFIG.get(qtype, _DEFAULT_CONFIG)
# Cross-encoder reranking before filtering (sharpens relevance signal)
reranked = _rerank_results(question_data["question"], retrieved) if ENABLE_RERANK else retrieved
# Adaptive filter with category-specific params
filtered = _filter_and_sort_results(
reranked,
min_relevance=cfg["min_rel"],
min_results=cfg["min_res"],
max_results=cfg["max_res"],
)
session_blocks = []
for i, result in enumerate(filtered, 1):
date_str = "Unknown"
if result.metadata:
date_str = result.metadata.get("referenced_date", "Unknown")
session_blocks.append(
format_session_for_prompt(result.content, date_str, i)
)
# Session compression: disabled by default (--compress to enable).
# Testing showed it strips relevant facts — Mastra's approach works because they
# compress at ingest with a specialized observer, not post-hoc.
if ENABLE_COMPRESSION:
session_blocks = _compress_sessions(
session_blocks, qtype,
question_text=question_data["question"],
model=model, api_key=api_key,
)
prompt = rag_prompt.format(
sessions="\n\n".join(session_blocks),
question_date=question_data["question_date"],
question=question_data["question"],
)
return _call_llm(
messages=[{"role": "user", "content": prompt}],
model=model,
max_tokens=cfg["max_tokens"],
api_key=api_key,
)
def grade_answer(
question_data: dict,
hypothesis: str,
model: str = "gpt-4o",
api_key: str | None = None,
) -> bool:
"""Grade a hypothesis using the official LongMemEval prompt templates."""
qtype = question_data["question_type"]
qid = question_data["question_id"]
answer = answer_to_str(question_data["answer"])
is_abstention = qid.endswith("_abs")
if is_abstention:
template = GRADE_PROMPTS["abstention"]
elif qtype in GRADE_PROMPTS:
template = GRADE_PROMPTS[qtype]
else:
template = GRADE_PROMPTS["default"]
prompt = template.format(
question=question_data["question"],
answer=answer,
hypothesis=hypothesis,
)
result = _call_llm(
messages=[{"role": "user", "content": prompt}],
model=model,
max_tokens=10,
api_key=api_key,
)
return "yes" in result.lower()
# ──────────────────────────── Metrics ────────────────────────────────────────
def compute_metrics(graded: list, dataset: list) -> dict:
"""Compute per-type and overall metrics, matching official print_qa_metrics.py."""
ref = {d["question_id"]: d for d in dataset}
type_correct: dict[str, int] = {}
type_total: dict[str, int] = {}
abs_correct = 0
abs_total = 0
for r in graded:
qid = r["question_id"]
label = r.get("label", False)
qtype = ref[qid]["question_type"]
type_total[qtype] = type_total.get(qtype, 0) + 1
if label:
type_correct[qtype] = type_correct.get(qtype, 0) + 1
if qid.endswith("_abs"):
abs_total += 1
if label:
abs_correct += 1
print("\n" + "=" * 65)
print(" LongMemEval Official Benchmark Results (OMEGA)")
print("=" * 65)
type_accs = []
for qtype in sorted(type_total.keys()):
correct = type_correct.get(qtype, 0)
total = type_total[qtype]
acc = correct / total * 100 if total > 0 else 0
type_accs.append(acc)
bar = "#" * int(acc / 5) + "-" * (20 - int(acc / 5))
print(f" {qtype:30s} {correct:3d}/{total:3d} [{bar}] {acc:5.1f}%")
overall_correct = sum(type_correct.values())
overall_total = sum(type_total.values())
overall_acc = overall_correct / overall_total * 100 if overall_total > 0 else 0
task_avg = sum(type_accs) / len(type_accs) if type_accs else 0
abs_acc = abs_correct / abs_total * 100 if abs_total > 0 else 0
print(f"\n{'─' * 65}")
print(f" Overall Accuracy: {overall_correct}/{overall_total} = {overall_acc:.1f}%")
print(f" Task-Averaged Accuracy: {task_avg:.1f}%")
print(f" Abstention Accuracy: {abs_correct}/{abs_total} = {abs_acc:.1f}%")
print(f"{'─' * 65}")
return {"overall": overall_acc, "task_averaged": task_avg, "abstention": abs_acc}
def compute_retrieval_recall(retrieval_log: list) -> None:
"""Print retrieval recall stats from the run log."""
if not retrieval_log:
return
type_hits: dict[str, int] = {}
type_total: dict[str, int] = {}
for entry in retrieval_log:
qtype = entry["question_type"]
hit = entry["retrieval_hit"]
type_total[qtype] = type_total.get(qtype, 0) + 1
if hit:
type_hits[qtype] = type_hits.get(qtype, 0) + 1
total_hits = sum(type_hits.values())
total = sum(type_total.values())
print(f"\n{'─' * 65}")
print(" Retrieval Recall (evidence session in top-K)")
print(f"{'─' * 65}")
for qtype in sorted(type_total.keys()):
hits = type_hits.get(qtype, 0)
tot = type_total[qtype]
pct = hits / tot * 100 if tot > 0 else 0
print(f" {qtype:30s} {hits:3d}/{tot:3d} {pct:5.1f}%")
overall_pct = total_hits / total * 100 if total > 0 else 0
print(f" {'OVERALL':30s} {total_hits:3d}/{total:3d} {overall_pct:5.1f}%")
print()
# ──────────────────────────── Main ───────────────────────────────────────────
def run_grade_only(args, dataset: list) -> int:
"""Grade an existing hypothesis file."""
hyps = []
with open(args.grade_only) as f:
for line in f:
line = line.strip()
if line:
hyps.append(json.loads(line))
print(f" Loaded {len(hyps)} hypotheses from {args.grade_only}")
ref = {d["question_id"]: d for d in dataset}
graded = []
for i, h in enumerate(hyps):
qid = h["question_id"]
label = grade_answer(ref[qid], h["hypothesis"], args.model, args.api_key)
h["label"] = label
graded.append(h)
if args.verbose:
status = "PASS" if label else "FAIL"
print(f" [{i + 1}/{len(hyps)}] [{status}] {qid}")
else:
print(f"\r Grading [{i + 1}/{len(hyps)}]", end="", flush=True)
out_path = args.grade_only + f".eval-results-{args.model}"
with open(out_path, "w") as f:
for g in graded:
f.write(json.dumps(g) + "\n")
print(f"\n Grading results written to: {out_path}")
compute_metrics(graded, dataset)
return 0
def run_generation(args, dataset: list) -> int:
"""Main generation pipeline: ingest → retrieve → generate → output."""
questions = dataset
if getattr(args, "question_ids", None):
# Filter to specific question IDs for spot-checking
target_ids = set(args.question_ids)
questions = [q for q in dataset if q["question_id"] in target_ids]
if not questions:
print(f" ERROR: No questions matched IDs: {args.question_ids}")
return 1
print(f" Spot-checking {len(questions)} specific question(s)")
elif args.questions > 0:
questions = dataset[: args.questions]
# Resume mode
completed_ids: set[str] = set()
if args.resume and Path(args.output).exists():
with open(args.output) as f:
for line in f:
line = line.strip()
if line:
completed_ids.add(json.loads(line)["question_id"])
print(f" Resuming: {len(completed_ids)} questions already completed")
total = len(questions)
results = []
retrieval_log = []
start_time = time.time()
output_mode = "a" if args.resume else "w"
out_file = open(args.output, output_mode)
try:
for idx, q in enumerate(questions):
qid = q["question_id"]
if qid in completed_ids:
continue
elapsed = time.time() - start_time
done_count = idx + 1 - len(completed_ids)
rate = done_count / elapsed if elapsed > 0 else 0
remaining = total - idx - 1
eta = remaining / rate / 60 if rate > 0 else 0
if args.verbose:
print(f"\n[{idx + 1}/{total}] {qid} ({q['question_type']})")
print(f" Q: {q['question'][:100]}")
else:
print(
f"\r [{idx + 1}/{total}] {qid} (ETA: {eta:.0f}m) ",
end="",
flush=True,
)
# 1. Ingest into fresh temp store
with tempfile.TemporaryDirectory() as tmpdir:
store = ingest_question(
q, tmpdir,
api_key=args.api_key,
extract_facts=getattr(args, "extract_facts", False),
)
n_sessions = len(q["haystack_sessions"])
if args.verbose:
print(f" Ingested {n_sessions} sessions")
# 2. Retrieve (temporal-aware for temporal-reasoning questions)
# Boost K for questions that need broader recall
k = max(args.limit, 20) # Global K floor of 20
q_lower = q["question"].lower()
if any(sig in q_lower for sig in (
"how many", "how much", "how often", "total number",
"count", "number of",
)):
k = max(k, 45)
elif q.get("question_type") == "multi-session":
# Non-counting MS questions still benefit from more context
k = max(k, 25)
elif q.get("question_type") == "temporal-reasoning":
# Temporal questions need more context for event matching
k = max(k, 25)
retrieved = retrieve_context(
store, q["question"], limit=k,
question_date=q.get("question_date"),
question_type=q.get("question_type"),
api_key=args.api_key,
)
# Track retrieval recall
answer_sids = set(q.get("answer_session_ids", []))
retrieved_sids = set()
for r in retrieved:
if r.metadata:
retrieved_sids.add(r.metadata.get("session_id", ""))
hit = bool(answer_sids & retrieved_sids)
retrieval_log.append(
{
"question_id": qid,
"question_type": q["question_type"],
"retrieval_hit": hit,
"answer_sids": list(answer_sids),
"retrieved_sids": list(retrieved_sids),
}
)
if args.verbose:
print(f" Retrieved {len(retrieved)} results")
for j, r in enumerate(retrieved):
rel = r.relevance if r.relevance else 0
sid = r.metadata.get("session_id", "?") if r.metadata else "?"
print(f" #{j + 1} [{rel:.3f}] session={sid}")
print(f" Evidence sessions: {answer_sids}")
print(f" Retrieved sessions: {retrieved_sids}")
print(f" Retrieval hit: {'YES' if hit else 'NO'}")
# 3. Generate
if args.dry_run:
hypothesis = "[DRY RUN]"
else:
try:
hypothesis = generate_answer(
q, retrieved, args.model, args.api_key
)
except Exception as e:
print(f"\n ERROR generating {qid}: {e}")
hypothesis = f"[ERROR: {e}]"
if args.verbose:
print(f" Answer: {hypothesis[:120]}")
store.close()
# 4. Write JSONL
entry = {"question_id": qid, "hypothesis": hypothesis}
out_file.write(json.dumps(entry) + "\n")
out_file.flush()
results.append(entry)
except KeyboardInterrupt:
print(f"\n\n Interrupted after {len(results)} questions. Use --resume to continue.")
finally:
out_file.close()
elapsed = time.time() - start_time
print(f"\n\n Hypothesis file: {args.output}")
print(f" Questions processed: {len(results)}")
print(f" Time: {elapsed / 60:.1f} minutes")
# Retrieval recall summary
compute_retrieval_recall(retrieval_log)
# Optional grading
if args.grade and not args.dry_run:
print(" Grading hypotheses...")
ref = {d["question_id"]: d for d in dataset}
all_results = []
with open(args.output) as f:
for line in f:
line = line.strip()
if line:
all_results.append(json.loads(line))
graded = []
for i, h in enumerate(all_results):
qid = h["question_id"]
label = grade_answer(ref[qid], h["hypothesis"], args.model, args.api_key)
h["label"] = label
graded.append(h)
if args.verbose:
status = "PASS" if label else "FAIL"
print(f" [{i + 1}/{len(all_results)}] [{status}] {qid}")
else:
print(
f"\r Grading [{i + 1}/{len(all_results)}]",
end="",
flush=True,
)
out_path = args.output + f".eval-results-{args.model}"
with open(out_path, "w") as f:
for g in graded:
f.write(json.dumps(g) + "\n")
print(f"\n Grading results written to: {out_path}")
compute_metrics(graded, dataset)
return 0
def main():
parser = argparse.ArgumentParser(
description="OMEGA LongMemEval Official Evaluation Harness",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""\
examples:
%(prog)s --questions 10 --verbose # 10 questions with gpt-4o
%(prog)s --model gpt-4.1 --questions 10 --verbose # use GPT-4.1
%(prog)s --model claude-sonnet-4-5-20250929 --questions 10 # use Claude Sonnet
%(prog)s --model gemini-2.5-pro --questions 10 # use Gemini
%(prog)s --model grok-4-0709 --questions 10 # use Grok-4
%(prog)s --dry-run --verbose # retrieve only (no API cost)
%(prog)s --resume # continue interrupted run
%(prog)s --grade-only longmemeval_hypothesis.jsonl # grade existing file
""",
)
parser.add_argument(
"--api-key",
default=None,
help="API key override (default: auto-detect from env per provider)",
)
parser.add_argument(
"--output",
default="longmemeval_hypothesis.jsonl",
help="Output JSONL path (default: longmemeval_hypothesis.jsonl)",
)
parser.add_argument(
"--limit",
type=int,
default=15,
help="Top-K retrieval limit (default: 15)",
)
parser.add_argument(
"--questions",
type=int,
default=0,
help="Run only first N questions (0 = all, default: 0)",
)
parser.add_argument(
"--resume",
action="store_true",
help="Skip already-completed questions (append mode)",
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="Print per-question details",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Ingest + retrieve only (no GPT-4o calls)",
)
parser.add_argument(
"--grade",
action="store_true",
help="Grade hypothesis file after generation",
)
parser.add_argument(
"--grade-only",
metavar="JSONL",
help="Grade existing hypothesis JSONL (skip generation)",
)
parser.add_argument(
"--model",
default="gpt-4o",
help="Model for generation and grading. Supported: gpt-*, claude-*, "
"gemini-*, grok-*, llama-*, meta-llama/* (default: gpt-4o)",
)
parser.add_argument(
"--extract-facts",
action="store_true",
help="Extract key facts at ingest time for better retrieval (adds LLM cost)",
)
parser.add_argument(
"--question-ids",
nargs="+",
metavar="QID",
help="Run only specific question IDs (for spot-checking failures)",
)
parser.add_argument(
"--rerank",
action="store_true",
help="Enable cross-encoder reranking (disabled by default — hurts conversational retrieval)",
)
parser.add_argument(
"--no-query-expansion",
action="store_true",
help="Disable query expansion (temporal/entity/counting signals)",
)
parser.add_argument(
"--no-recency-boost",
action="store_true",
help="Disable recency boosting for knowledge-update questions",
)
parser.add_argument(
"--compress",
action="store_true",
help="Enable session compression before prompting (experimental, may hurt)",
)
parser.add_argument(
"--query-augment",
action="store_true",
help="Enable LLM-based query augmentation for better retrieval (adds ~$5 cost)",
)
args = parser.parse_args()
# Validate API key availability
if not args.dry_run:
try:
config = _resolve_provider(args.model)
except ValueError as e:
print(f"Error: {e}")
sys.exit(1)
if not args.api_key and not os.environ.get(config["env_key"]):
print(
f"Error: No API key for model '{args.model}'. "
f"Set ${config['env_key']} or pass --api-key."
)
sys.exit(1)
# Apply feature toggles from CLI flags
global ENABLE_RERANK, ENABLE_QUERY_EXPANSION, ENABLE_RECENCY_BOOST, ENABLE_COMPRESSION, ENABLE_QUERY_AUGMENT
if args.rerank:
ENABLE_RERANK = True
if args.no_query_expansion:
ENABLE_QUERY_EXPANSION = False
if args.no_recency_boost:
ENABLE_RECENCY_BOOST = False
if args.compress:
ENABLE_COMPRESSION = True
if args.query_augment:
ENABLE_QUERY_AUGMENT = True
toggles = []
if ENABLE_RERANK:
toggles.append("rerank")
if ENABLE_QUERY_EXPANSION:
toggles.append("query-expand")
if ENABLE_RECENCY_BOOST:
toggles.append("recency-boost")
if ENABLE_COMPRESSION:
toggles.append("compress")
if ENABLE_QUERY_AUGMENT:
toggles.append("query-augment")
if getattr(args, "extract_facts", False):
toggles.append("extract-facts")
print(f" Features: {', '.join(toggles) if toggles else 'none'}")
# Load dataset
print("Loading LongMemEval_S dataset...")
dataset = download_dataset()
print(f" {len(dataset)} questions loaded")
# Dispatch
if args.grade_only:
return run_grade_only(args, dataset)
return run_generation(args, dataset)
if __name__ == "__main__":
sys.exit(main() or 0)