import asyncio
import json
import re
import os
import sys
import traceback
import signal
from pathlib import Path
from typing import Any, Optional, Dict, List
# Handle shutdown signals for clean exit
def signal_handler(sig, frame):
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
from mcp.server.fastmcp import FastMCP
from mcp.types import CallToolResult, TextContent
# -----------------------------------------------------------------------------
# RLM Monkeypatching
# -----------------------------------------------------------------------------
# We MUST monkeypatch before RLM imports or uses these functions.
# These patches are required for handling large files and relaxed answer detection.
try:
import rlm.utils.parsing
import rlm.core.rlm
# 1. Allow larger REPL outputs (default 20k -> 100k)
# This prevents truncation of large files in RLM history.
rlm.utils.parsing.format_iteration.__defaults__ = (100000,)
rlm.core.rlm.format_iteration = rlm.utils.parsing.format_iteration
# 2. Relaxed final answer detection (not just start of line)
    def _relaxed_find_final_answer(text: str, environment: Any = None) -> Optional[str]:
        """Find FINAL()/FINAL_VAR() markers anywhere in the text, not only at line start."""
        if not text: return None
# Check for FINAL_VAR pattern
m_var = re.search(r"FINAL_VAR\((?P<name>.*?)\)", text, re.DOTALL | re.IGNORECASE)
if m_var:
var_name = m_var.group("name").strip().strip('"').strip("'")
if environment is not None:
try:
res = environment.execute_code(f"print(FINAL_VAR({var_name!r}))")
ans = res.stdout.strip() or res.stderr.strip()
if ans: return ans
except Exception:
pass
# Check for FINAL pattern
m_final = re.search(r"FINAL\((?P<content>.*?)\)", text, re.DOTALL | re.IGNORECASE)
if m_final:
return m_final.group("content").strip()
return None
rlm.utils.parsing.find_final_answer = _relaxed_find_final_answer
rlm.core.rlm.find_final_answer = _relaxed_find_final_answer
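    # Example (illustrative): with the relaxed matcher, a reply such as
    #   "Scanned the repo. FINAL(The bug is in parser.py)"
    # resolves to "The bug is in parser.py" even though FINAL() is not at the start of a line.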
except Exception as e:
sys.stderr.write(f"RLM: Monkeypatch failure: {e}\n")
try:
from rlm.core.rlm import RLM
except ImportError:
sys.stderr.write("Error: Could not import rlm.\n")
RLM = None
from rlm_mcp_server.ingest import read_paths
from rlm_mcp_server.validate import validate_request, validate_result
from rlm_mcp_server.hashing import sha256_text
from rlm_mcp_server.provider import resolve_provider_config, prepare_backend_kwargs
mcp = FastMCP("rlm")
def _markdown_safe(text: str) -> str:
"""
Sanitize text for safe inclusion in Markdown.
Replaces ` (U+0060) with ˋ (U+02CB) to prevent breaking code blocks.
"""
if not text or not isinstance(text, str):
return text
return text.replace("`", "\u02CB")
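# Example (illustrative): _markdown_safe("run `ls -la`") -> "run ˋls -laˋ", so tool
# output cannot prematurely close a Markdown code fence in the client UI.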
def _extract_first_json(text: str) -> Any:
"""Best-effort JSON extraction from model response."""
if not text:
raise ValueError("Empty response")
def _parse_any(s: str) -> Any:
s = s.strip()
try:
return json.loads(s)
except json.JSONDecodeError:
try:
import ast
res = ast.literal_eval(s)
json.dumps(res) # Verify JSON compatibility
return res
except (ValueError, SyntaxError, ImportError, TypeError):
raise
# 1. Markdown blocks
m = re.search(r"```(?:json|python)?\s*([\s\S]*?)```", text)
if m:
try:
return _parse_any(m.group(1))
except Exception:
pass
# 2. Bracket matching
s = text.strip()
start = None
for i, ch in enumerate(s):
if ch in "[{":
start = i
break
if start is None:
raise ValueError("No JSON opener found in response")
open_ch = s[start]
close_ch = "]" if open_ch == "[" else "}"
depth = 0
    in_str = False
    str_ch = ""
    esc = False
    for j in range(start, len(s)):
        c = s[j]
        if in_str:
            # Track which quote opened the string so an apostrophe inside a
            # double-quoted value does not end string state prematurely.
            if esc: esc = False
            elif c == "\\": esc = True
            elif c == str_ch: in_str = False
        else:
            if c == '"' or (c == "'" and open_ch == "{"):
                in_str = True
                str_ch = c
elif c == open_ch: depth += 1
elif c == close_ch:
depth -= 1
if depth == 0:
payload = s[start:j+1]
try:
return _parse_any(payload)
except Exception as e:
raise ValueError(f"Invalid payload: {e}")
raise ValueError("Unterminated JSON payload")
def _resolve_rlm_response(result: Any) -> str:
"""Defensive unwrapping for RLM completion objects."""
raw = getattr(result, "response", "")
fa = getattr(result, "final_answer", None)
if fa: return str(fa)
if isinstance(raw, (tuple, list)) and len(raw) == 2:
return str(raw[0])
if isinstance(raw, str) and raw.strip() == "final_answer":
if fa is not None: return str(fa)
def _try_extract_var(var_name: str, text: str) -> Optional[str]:
if hasattr(result, var_name):
val = getattr(result, var_name)
if val is not None: return str(val)
for container in ["memory", "variables", "state"]:
cont = getattr(result, container, None)
if isinstance(cont, dict) and var_name in cont:
return str(cont[var_name])
        # var_name is matched literally below; escape it in case it is not a plain identifier.
        name = re.escape(var_name)
        patterns = [
            fr"{name}\s*=\s*\"\"\"(?P<val>[\s\S]*)\"\"\"",
            fr"{name}\s*=\s*'''(?P<val>[\s\S]*)'''",
            fr"{name}\s*=\s*\"(?P<val>[\s\S]*?)\"",
            fr"{name}\s*=\s*'(?P<val>[\s\S]*?)'"
        ]
for p in patterns:
m = re.search(p, text, re.MULTILINE)
if m: return m.group("val").strip()
return None
if isinstance(raw, str):
patterns = [
r"FINAL(?:_VAR|_ANSWER)?\((?P<content>[\s\S]*)\)",
r"```FINAL\s*\n?(?P<content>[\s\S]*)```",
r"(?:^|\n)FINAL(?:_ANSWER)?:\s*(?P<content>[\s\S]*)$"
]
for pattern in patterns:
matches = list(re.finditer(pattern, raw, re.IGNORECASE))
if matches:
parts = []
for m in matches:
content = m.group("content").strip()
for q in ['"""', "'''", '"', "'"]:
if content.startswith(q) and content.endswith(q):
content = content[len(q):-len(q)].strip()
break
matched_str = m.group(0).upper()
if "FINAL_VAR" in matched_str or re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", content):
val = _try_extract_var(content, raw)
if val is not None: content = val
elif "FINAL_VAR" in matched_str: continue
if content and content not in parts: parts.append(content)
if parts: return "\n\n".join(parts)
if isinstance(raw, str):
val = _try_extract_var("final_answer", raw)
if val is not None: return val
repl_error_match = re.search(r"Error: Variable '(?P<leaked>.*)' not found", raw)
if repl_error_match:
leaked = repl_error_match.group("leaked").strip()
if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", leaked):
leaked = re.sub(r'^f?["\']', '', leaked).rstrip('"\']')
return leaked
cleaned = raw
for pattern in [
r"FINAL(?:_VAR|_ANSWER)?\([\s\S]*?\)",
r"```FINAL\s*[\s\S]*?```",
r"(?:^|\n)FINAL(?:_ANSWER)?:\s*.*?$",
r"```repl[\s\S]*?```"
]:
cleaned = re.sub(pattern, "", cleaned, flags=re.MULTILINE | re.IGNORECASE).strip()
if cleaned: return cleaned
return str(raw)
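# Example (illustrative): a completion whose .response ends in FINAL_VAR(report), with
#   report = """..."""  visible in the REPL transcript, resolves to the contents of
# report; a plain FINAL(some text) resolves to "some text".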
def _normalize_obj(arg: Any) -> Dict[str, Any]:
    """Coerce a dict or JSON-object string into a dict; anything else becomes {}."""
    if not arg: return {}
if isinstance(arg, dict): return arg
if isinstance(arg, str):
try:
parsed = json.loads(arg)
if isinstance(parsed, dict): return parsed
except Exception: pass
return {}
@mcp.tool()
async def solve(
query: str,
text: Optional[str] = None,
globs: Optional[list[str]] = None,
ingest: Optional[Dict[str, Any]] = None,
rlm: Optional[Dict[str, Any]] = None,
rlm_config: Optional[Dict[str, Any]] = None,
provider: Optional[Dict[str, Any]] = None,
provider_preset: str = "ollama_local",
output: Optional[Dict[str, Any]] = None,
model_name: Optional[str] = None,
backend: Optional[str] = None,
other_model_name: Optional[str] = None,
environment: str = "docker",
max_iterations: int = 12,
temperature: float = 0.2,
timeout_sec: int = 300,
) -> CallToolResult:
"""Run Recursive Language Model (RLM) reasoning over context."""
n_rlm = _normalize_obj(rlm or rlm_config)
n_ingest = _normalize_obj(ingest)
n_provider = _normalize_obj(provider)
n_output = _normalize_obj(output)
config = {
"backend": backend or "openai",
"model_name": model_name,
"other_model_name": other_model_name,
"environment": environment,
"max_iterations": max_iterations,
"temperature": temperature,
"timeout_sec": timeout_sec
}
config.update(n_rlm)
if not model_name and config.get("model_name"):
model_name = config["model_name"]
request_data = {"query": query, "rlm": config}
if text is not None: request_data["text"] = text
if globs is not None: request_data["globs"] = globs
if n_ingest: request_data["ingest"] = n_ingest
if n_output: request_data["output"] = n_output
if n_provider: request_data["provider"] = n_provider
if provider_preset: request_data["provider_preset"] = provider_preset
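    # Auto-upgrade the default local preset when only a cloud API key is configured,
    # so the tool still works without a local Ollama instance.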
if provider_preset == "ollama_local":
if os.environ.get("OPENROUTER_API_KEY") and not os.environ.get("OLLAMA_API_KEY"):
provider_preset = "openrouter"
elif os.environ.get("OPENAI_API_KEY") and not os.environ.get("OLLAMA_API_KEY"):
provider_preset = "openai"
try:
validate_request(request_data)
resolved_provider = resolve_provider_config(provider_preset, n_provider)
except Exception as e:
return CallToolResult(isError=True, content=[TextContent(type="text", text=str(e))])
repo_root = Path(os.getcwd()).resolve()
try:
if globs:
context_payload = read_paths(
repo_root, globs,
max_file_bytes=n_ingest.get("max_file_bytes", 200_000),
max_total_bytes=n_ingest.get("max_total_bytes", 2_000_000),
include_extensions=n_ingest.get("include_extensions"),
exclude_globs=n_ingest.get("exclude_globs")
)
else:
context_payload = text or ""
except Exception as e:
return CallToolResult(isError=True, content=[TextContent(type="text", text=f"Ingest Error: {str(e)}")] )
if RLM is None:
return CallToolResult(isError=True, content=[TextContent(type="text", text="RLM library not found.")] )
class MemoryLogger:
def __init__(self): self.iterations = []
def log_metadata(self, metadata): pass
def log(self, iteration): self.iterations.append(iteration)
actual_timeout = config.get("timeout_sec", timeout_sec)
try:
backend_type = config.get("backend", backend or "openai")
if resolved_provider["mode"] == "openai_compatible":
backend_type = "openai"
primary_model = config.get("model_name") or model_name
if not primary_model:
env_default = os.environ.get("RLM_DEFAULT_MODEL")
if env_default: primary_model = env_default
elif provider_preset == "openrouter": primary_model = "google/gemini-2.0-flash-001"
elif resolved_provider["mode"] == "openai_cloud": primary_model = "gpt-4o-mini"
else: primary_model = "qwen-2.5-coder-32b-instruct"
backend_kwargs = prepare_backend_kwargs(
resolved_provider, model_name=primary_model,
temperature=config.get("temperature", temperature),
extra_kwargs=config.get("backend_kwargs")
)
other_name = config.get("other_model_name") or other_model_name
if not other_name:
other_name = os.environ.get("RLM_DEFAULT_RECURSION_MODEL") or primary_model
other_backends = [backend_type]
other_backend_kwargs = [prepare_backend_kwargs(
resolved_provider, model_name=other_name,
temperature=0.0, extra_kwargs=config.get("other_backend_kwargs")
)]
mem_logger = MemoryLogger()
rlm_inst = RLM(
backend=backend_type, backend_kwargs=backend_kwargs,
other_backends=other_backends, other_backend_kwargs=other_backend_kwargs,
environment=config.get("environment", environment),
max_iterations=config.get("max_iterations", max_iterations),
logger=mem_logger, verbose=False,
)
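        # RLM.completion is blocking; run it in the default thread pool so the MCP
        # event loop stays responsive, and enforce the configured timeout.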
loop = asyncio.get_running_loop()
result = await asyncio.wait_for(
loop.run_in_executor(None, rlm_inst.completion, context_payload, query),
timeout=actual_timeout
)
raw_response = _resolve_rlm_response(result)
except Exception as e:
return CallToolResult(isError=True, content=[TextContent(type="text", text=str(e))])
output_data = {
"version": "rlm.solve.result.v1",
"status": "ok",
"answer": str(raw_response),
"answer_json": None,
"provenance": {
"backend": backend_type, "model_name": primary_model,
"other_model_name": other_name, "environment": config.get("environment", environment)
},
"metrics": {"wall_time_sec": getattr(result, "execution_time", 0.0), "iterations": 0, "recursive_calls": 0}
}
usage_summary = getattr(result, "usage_summary", None)
if usage_summary:
input_tokens = output_tokens = total_calls = 0
summaries = getattr(usage_summary, "model_usage_summaries", {})
for summary in summaries.values():
input_tokens += getattr(summary, "total_input_tokens", 0)
output_tokens += getattr(summary, "total_output_tokens", 0)
total_calls += getattr(summary, "total_calls", 0)
output_data["metrics"]["token_usage"] = {"input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": input_tokens + output_tokens}
output_data["metrics"]["recursive_calls"] = total_calls
    # usage_summary only exposes per-model aggregates, so the peak input size of a
    # single call has to be recovered from the iteration log below (with a
    # conservative fallback to the aggregated totals if no per-call usage is recorded).
    peak_input_tokens = 0
for it in mem_logger.iterations:
for block in getattr(it, "code_blocks", []):
res = getattr(block, "result", None)
if not res: continue
# If the block itself has usage (some RLM versions)
if hasattr(res, "usage_summary"):
sums = getattr(res.usage_summary, "model_usage_summaries", {})
for s in sums.values():
peak_input_tokens = max(peak_input_tokens, getattr(s, "total_input_tokens", 0))
for call in getattr(res, "rlm_calls", []):
call_usage = getattr(call, "usage_summary", None)
if call_usage:
sums = getattr(call_usage, "model_usage_summaries", {})
for s in sums.values():
peak_input_tokens = max(peak_input_tokens, getattr(s, "total_input_tokens", 0))
# If we still have 0 but we have total input tokens, the root call was likely the peak
if peak_input_tokens == 0:
total_in = output_data["metrics"].get("token_usage", {}).get("input_tokens", 0)
num_calls = output_data["metrics"].get("recursive_calls", 1)
if num_calls > 0:
# Fallback: assume root call was at least total/calls (very conservative)
# or just use total if num_calls is 1
if num_calls == 1:
peak_input_tokens = total_in
else:
# We can't know for sure, but 0 is definitely wrong if total > 0
peak_input_tokens = total_in // num_calls
output_data["metrics"]["peak_input_tokens"] = peak_input_tokens
try:
parsed = _extract_first_json(raw_response)
output_data["answer_json"] = parsed
if isinstance(parsed, dict):
if "structured" in parsed: output_data["structured"] = parsed["structured"]
if "evidence" in parsed and isinstance(parsed["evidence"], list):
validated_evidence = []
source_context = context_payload if isinstance(context_payload, dict) else {}
for item in parsed["evidence"]:
if not isinstance(item, dict): continue
path_key = item.get("path")
if path_key and path_key in source_context:
lines = source_context[path_key].splitlines()
start = max(1, item.get("start_line", 1))
end = min(len(lines), max(start, item.get("end_line", start)))
item["excerpt"] = "\n".join(lines[start-1:end])
item["excerpt_hash"] = sha256_text(item.get("excerpt", ""))
if "excerpt" in item: item["excerpt"] = _markdown_safe(item["excerpt"])
validated_evidence.append(item)
output_data["evidence"] = validated_evidence
if "patch" in parsed: output_data["patch"] = parsed["patch"]
if "answer" in parsed and isinstance(parsed["answer"], str): output_data["answer"] = parsed["answer"]
except Exception: pass
try:
validate_result(output_data)
except Exception as e:
return CallToolResult(isError=True, content=[TextContent(type="text", text=f"Result Validation Error: {str(e)}")] )
if n_output.get("include_metrics") or n_output.get("include_evidence") or n_output.get("include_patch"):
output_data["answer"] = _markdown_safe(output_data["answer"])
text_out = json.dumps(output_data, ensure_ascii=False, indent=2)
else:
text_out = _markdown_safe(output_data["answer"])
return CallToolResult(content=[TextContent(type="text", text=text_out)], meta={"structured_content": output_data})
if __name__ == "__main__":
mcp.run()
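# Example (illustrative) MCP client registration; the module path and command are
# assumptions, adjust them to how this package is installed:
#
#   {
#     "mcpServers": {
#       "rlm": {"command": "python", "args": ["-m", "rlm_mcp_server.server"]}
#     }
#   }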