import argparse
import asyncio
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import OpenAI
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
console = Console()
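# Benchmark harness: ingest a set of files as JSON context, answer a query two ways
# (a single-shot "baseline" call to an OpenAI-compatible API, and the RLM MCP "solve"
# tool), then score both answers and compare cost, token usage, and wall time.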
# Prices in USD per 1M tokens (Input, Output)
# As of Jan 2026 (approximate/examples)
MODEL_COSTS = {
"qwen/qwen-2.5-coder-32b-instruct": (0.07, 0.16),
"qwen/qwen-2.5-coder-7b-instruct": (0.03, 0.09),
"google/gemini-2.0-flash-001": (0.10, 0.40),
"google/gemini-2.0-flash-lite-001": (0.07, 0.30),
"google/gemini-2.5-flash-lite": (0.07, 0.30), # Added for compatibility with user query
"openai/gpt-4o-mini": (0.15, 0.60),
"openai/gpt-4o": (2.50, 10.00),
"anthropic/claude-sonnet-4.5": (3.00, 15.00),
"ollama_local": (0.0, 0.0),
"qwen2.5-coder-32k": (0.0, 0.0),
"qwen3:4b": (0.0, 0.0),
}
def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
"""Calculate estimated cost in USD."""
cost_pair = (0.0, 0.0)
for m_key, costs in MODEL_COSTS.items():
if model == m_key or (m_key in model and "/" in m_key):
cost_pair = costs
break
in_price, out_price = cost_pair
return (input_tokens / 1_000_000 * in_price) + (output_tokens / 1_000_000 * out_price)
def count_tokens_approx(text: str) -> int:
"""A rough approximation: 4 chars per token."""
return len(text) // 4
def ingest_context(globs: List[str], repo_root: Optional[Path] = None) -> str:
"""Ingest files matching globs into a JSON string."""
if repo_root is None:
repo_root = Path(os.getcwd()).resolve()
from rlm_mcp_server.ingest import read_paths
files = read_paths(repo_root, globs, max_file_bytes=200_000, max_total_bytes=2_000_000)
return json.dumps(files, ensure_ascii=False, indent=2)
def _extract_first_json(text: str) -> Any:
"""Robust JSON extraction logic."""
s = (text or "").strip()
if s.startswith("```"):
lines = s.splitlines()
if len(lines) >= 2 and lines[-1].strip().startswith("```"):
s = "\n".join(lines[1:-1]).strip()
start = None
for i, ch in enumerate(s):
if ch in "[{":
start = i
break
if start is None:
raise ValueError("No JSON opener found")
open_ch = s[start]
close_ch = "]" if open_ch == "[" else "}"
    depth = 0
    in_str = False  # currently inside a JSON string literal
    esc = False     # previous character was a backslash inside a string
for j in range(start, len(s)):
c = s[j]
        if in_str:
            if esc:
                esc = False
            elif c == "\\":
                esc = True
            elif c == '"':
                in_str = False
else:
if c == '"': in_str = True
elif c == open_ch: depth += 1
elif c == close_ch:
depth -= 1
if depth == 0:
payload = s[start:j+1]
return json.loads(payload)
raise ValueError("Unterminated JSON payload")
def score_result(raw_text: str, structured_content: Optional[Any] = None, expected_type: Optional[type] = None) -> Dict[str, Any]:
"""Score the result based on robustness rubric."""
from rlm_mcp_server.validate import is_valid_result
scores = {
"parse_ok": 0,
"schema_ok": 0,
"is_expected_type": 0,
"no_prose": 0
}
parsed = None
    # Structured content (even an empty container) short-circuits text parsing.
    if structured_content is not None:
parsed = structured_content
scores["parse_ok"] = 1
scores["no_prose"] = 1
scores["schema_ok"] = 1 if is_valid_result(structured_content) else 0
else:
try:
parsed = _extract_first_json(raw_text)
scores["parse_ok"] = 1
scores["schema_ok"] = 1 if is_valid_result(parsed) else 0
trimmed = raw_text.strip()
is_pure_json = (trimmed.startswith("{") and trimmed.endswith("}")) or (trimmed.startswith("[") and trimmed.endswith("]"))
is_pure_fence = trimmed.startswith("```") and trimmed.endswith("```") and trimmed.count("```") == 2
if is_pure_json or is_pure_fence:
scores["no_prose"] = 1
except Exception:
scores["parse_ok"] = 0
if scores["parse_ok"] == 1:
if expected_type:
if isinstance(parsed, expected_type):
scores["is_expected_type"] = 1
else:
scores["is_expected_type"] = 1
return scores
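# Rough rubric semantics: structured output from the MCP path auto-earns parse_ok and
# no_prose; text answers must be pure JSON (or a single fenced block) to earn no_prose,
# and schema_ok defers to rlm_mcp_server.validate.is_valid_result in both cases.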
def get_provider_config(preset: str) -> Tuple[str, Optional[str]]:
"""Get base URL and API key for a provider preset."""
if preset == "openrouter":
return "https://openrouter.ai/api/v1", os.environ.get("OPENROUTER_API_KEY")
elif preset == "openai":
return "https://api.openai.com/v1", os.environ.get("OPENAI_API_KEY")
elif preset == "ollama_local":
return "http://localhost:11434/v1", "ollama"
return "", None
def baseline_call_openai_compatible(
provider_preset: str,
model: str,
prompt: str,
query: str,
temperature: float = 0.0,
max_output_tokens: int = 1200,
) -> dict:
"""Standard one-shot call to an OpenAI-compatible API."""
base_url, api_key = get_provider_config(provider_preset)
if not api_key:
raise ValueError(f"API key for {provider_preset} not found")
client = OpenAI(base_url=base_url, api_key=api_key)
messages = [
{"role": "system", "content": "Answer the query using the provided JSON context. Return raw JSON only. Be concise."},
{"role": "user", "content": f"QUERY:\n{query}\n\nCONTEXT_JSON:\n{prompt}"},
]
t0 = time.time()
resp = client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_output_tokens,
temperature=temperature
)
dt = time.time() - t0
    usage = getattr(resp, "usage", None)
    answer = resp.choices[0].message.content or ""
    full_prompt = messages[0]["content"] + "\n" + messages[1]["content"]
    # Fall back to the rough character-based estimate when the provider omits usage data.
    input_tokens = getattr(usage, "prompt_tokens", 0) or count_tokens_approx(full_prompt)
    output_tokens = getattr(usage, "completion_tokens", 0) or count_tokens_approx(answer)
    return {
        "answer": answer,
        "full_prompt": full_prompt,
        "wall_time_sec": dt,
        "usage": {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": getattr(usage, "total_tokens", 0) or (input_tokens + output_tokens),
        },
    }
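# Note: the baseline stuffs the entire ingested context into one prompt, so its reported
# input tokens should land near count_tokens_approx(context); main() flags large gaps as
# likely truncation.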
async def mcp_rlm_call(server_py: str, tool_args: dict) -> Tuple[dict, Optional[Any]]:
"""Call the RLM MCP tool and return the result envelope and structured content."""
params = StdioServerParameters(command="uv", args=["run", "python", "-u", server_py], env=os.environ.copy())
async with stdio_client(params) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()
res = await session.call_tool("solve", tool_args)
structured = None
text_content = ""
res_dump = res.model_dump() if hasattr(res, "model_dump") else {}
meta = res_dump.get("meta") or {}
if "structured_content" in meta:
structured = meta["structured_content"]
if res.content:
for item in res.content:
if hasattr(item, "text"):
text_content += item.text
            if structured is not None:
return structured, structured
try:
parsed_text = json.loads(text_content)
# If we got the full envelope, return it
if isinstance(parsed_text, dict) and "version" in parsed_text and "status" in parsed_text:
return parsed_text, None
return {"status": "ok", "answer": text_content, "answer_json": parsed_text}, None
except json.JSONDecodeError:
return {"status": "ok", "answer": text_content}, None
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--query", required=True)
ap.add_argument("--globs", nargs="+", required=True)
ap.add_argument("--provider_preset", default="openrouter")
ap.add_argument("--model", required=True)
ap.add_argument("--other-model", default=None, help="Optional recursion model.")
ap.add_argument("--server_py", default="rlm_mcp_server/server.py")
ap.add_argument("--environment", default="docker")
ap.add_argument("--dump-dir", help="Directory to save artifacts")
ap.add_argument("--temperature", type=float, default=0.0)
ap.add_argument("--baseline-max-output-tokens", type=int, default=1200)
ap.add_argument("--rlm-max-iterations", type=int, default=12)
args = ap.parse_args()
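    # Example invocation (script path, model, and globs are illustrative):
    #   uv run python benchmark_rlm.py \
    #       --query "List all public functions" \
    #       --globs "src/**/*.py" \
    #       --provider_preset openrouter \
    #       --model qwen/qwen-2.5-coder-32b-instruct \
    #       --dump-dir out/run1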
repo_root = Path(os.getcwd()).resolve()
with console.status("[bold green]Ingesting context..."):
context_str = ingest_context(args.globs, repo_root)
context_tokens_approx = count_tokens_approx(context_str)
# --- BASELINE ---
console.print(f"\n[bold blue]Running baseline ({args.model})...[/bold blue]")
with console.status("Waiting for baseline response..."):
baseline = baseline_call_openai_compatible(
args.provider_preset,
args.model,
context_str,
args.query,
temperature=args.temperature,
max_output_tokens=args.baseline_max_output_tokens
)
# --- RLM MCP ---
console.print(f"\n[bold magenta]Running RLM MCP ({args.model})...[/bold magenta]")
tool_args = {
"query": args.query,
"globs": args.globs,
"provider_preset": args.provider_preset,
"rlm": {
"model_name": args.model,
"other_model_name": args.other_model,
"timeout_sec": 1200,
"max_iterations": args.rlm_max_iterations,
},
"environment": args.environment,
"output": {"include_metrics": True},
"temperature": args.temperature
}
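    # The field names above are assumed to mirror the rlm_mcp_server "solve" tool schema
    # (query / globs / provider_preset / rlm / environment / output / temperature).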
with console.status("Waiting for RLM response..."):
rlm_out, structured = asyncio.run(mcp_rlm_call(args.server_py, tool_args))
# --- RESULTS ---
    # Crude heuristic: queries phrased as "list/array/set/all ..." are expected to return a JSON array.
    expected_type = list if any(kw in args.query.lower() for kw in ["list", "array", "set", "all"]) else dict
baseline_scores = score_result(baseline["answer"], expected_type=expected_type)
rlm_ans_text = rlm_out.get("answer", "")
rlm_scores = score_result(rlm_ans_text, structured_content=structured, expected_type=expected_type)
table = Table(title="Benchmark Results (Consolidated)")
table.add_column("Metric", style="cyan")
table.add_column("Baseline", justify="right")
table.add_column("RLM", justify="right")
table.add_column("Delta", justify="right")
# Scoring rows
for key in ["parse_ok", "schema_ok", "is_expected_type", "no_prose"]:
b_val = baseline_scores[key]
r_val = rlm_scores[key]
table.add_row(key.replace("_", " ").title(), str(b_val), str(r_val), f"{r_val - b_val:+}")
# Usage rows
rlm_metrics = rlm_out.get("metrics", {})
rlm_tokens = rlm_metrics.get("token_usage", {})
base_usage = baseline["usage"]
# Cost
base_cost = calculate_cost(args.model, base_usage["input_tokens"], base_usage["output_tokens"])
rlm_cost = calculate_cost(args.model, rlm_tokens.get("input_tokens", 0), rlm_tokens.get("output_tokens", 0))
table.add_row("Cost ($)", f"${base_cost:.4f}", f"${rlm_cost:.4f}", f"${rlm_cost - base_cost:+.4f}")
# Tokens
table.add_row("Total Tokens", str(base_usage["total_tokens"]), str(rlm_tokens.get("total_tokens", 0)), f"{rlm_tokens.get('total_tokens', 0) - base_usage['total_tokens']:+}")
table.add_row("Input Tokens", str(base_usage["input_tokens"]), str(rlm_tokens.get("input_tokens", 0)), f"{rlm_tokens.get('input_tokens', 0) - base_usage['input_tokens']:+}")
table.add_row("Output Tokens", str(base_usage["output_tokens"]), str(rlm_tokens.get("output_tokens", 0)), f"{rlm_tokens.get('output_tokens', 0) - base_usage['output_tokens']:+}")
peak_rlm = rlm_metrics.get("peak_input_tokens", 0)
peak_rlm_str = str(peak_rlm) if peak_rlm > 0 else "0 (No sub-calls)"
table.add_row("Peak Prompt Tokens", str(base_usage["input_tokens"]), peak_rlm_str, f"{peak_rlm - base_usage['input_tokens']:+}" if peak_rlm > 0 else "N/A")
# Time
rlm_time = rlm_metrics.get("wall_time_sec", 0)
table.add_row("Wall Time (sec)", f"{baseline['wall_time_sec']:.2f}", f"{rlm_time:.2f}", f"{rlm_time - baseline['wall_time_sec']:.2f}")
console.print(table)
# Warnings
if base_usage["input_tokens"] == 4096:
console.print(Panel("[bold red]WARNING: Baseline input tokens exactly 4096. Truncation likely (Ollama default context limit).[/bold red]", title="Truncation Detected"))
elif base_usage["input_tokens"] < context_tokens_approx * 0.9:
console.print(Panel(f"[bold yellow]WARNING: Baseline input ({base_usage['input_tokens']}) is significantly less than estimated context ({context_tokens_approx}). Truncation likely.[/bold yellow]", title="Possible Truncation"))
if args.dump_dir:
        d = Path(args.dump_dir)
        d.mkdir(parents=True, exist_ok=True)
        (d / "baseline_answer.txt").write_text(baseline["answer"], encoding="utf-8")
        (d / "rlm_answer.txt").write_text(str(rlm_ans_text), encoding="utf-8")
        (d / "rlm_full_response.json").write_text(json.dumps(rlm_out, indent=2, ensure_ascii=False), encoding="utf-8")
        (d / "ingested_context.json").write_text(context_str, encoding="utf-8")
console.print(f"[dim]Artifacts saved to {args.dump_dir}[/dim]")
if __name__ == "__main__":
main()