"""
Evaluation script for semantic search quality.
Metrics used:
- Precision@K: Fraction of relevant results among returned items.
- Recall@K: Fraction of all relevant items that were found.
- F1@K: Harmonic mean of Precision and Recall.
- Success Rate@K: Fraction of queries where at least one relevant snippet was returned.
NEW Granular Metrics:
- File Discovery Rate: % of expected files found (regardless of content completeness)
- Substring Coverage: Average % of required substrings found per file
- Partial Match Rate: % of results that found the file but not all required content
Stability Metrics (for multiple runs):
- Standard deviation of metrics across runs
- Coefficient of variation (CV = std/mean)
- Min/Max values across runs
These help distinguish between:
- "Didn't find the file at all" vs "Found file but incomplete content"
- Partial credit for finding some of the required information
- Consistency of results across multiple runs
"""
from __future__ import annotations
import json
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from src.core import run_semantic_search
from .metrics import (
AggregatedMetrics,
AggregatedStabilityMetrics,
FileMatchDetail,
MetricStability,
QueryMetrics,
QueryStabilityMetrics,
TokenStats,
aggregate_metrics,
aggregate_query_runs,
aggregate_stability_metrics,
calculate_f1,
calculate_precision,
calculate_recall,
)
DATASET_PATH = Path("data/dataset.jsonl")
DATASET_EASY_PATH = Path("data/dataset_easy.jsonl")
REPOS_BASE_PATH = Path("data/issues")
K = 10 # top-K for metrics calculation
NUM_RUNS = 10  # number of runs per query in stability mode
MAX_WORKERS = 10  # max parallel workers for stability runs
@dataclass
class ExpectedItem:
"""A single expected item in the ground truth."""
file_path: str
must_include_substrings: List[str]
@dataclass
class Example:
"""A single evaluation example."""
id: str
description: str
repo_path: str
query: str
path: Optional[str]
expected_items: List[ExpectedItem]
dataset_name: str = "unknown" # Track which dataset this example came from
def load_dataset(path: Path, dataset_name: Optional[str] = None) -> List[Example]:
"""Load evaluation dataset from JSONL file."""
if dataset_name is None:
dataset_name = path.stem
examples: List[Example] = []
with path.open("r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
raw = json.loads(line)
expected_items = [
ExpectedItem(
file_path=item["file_path"],
must_include_substrings=item["must_include_substrings"],
)
for item in raw["expected_items"]
]
examples.append(
Example(
id=raw["id"],
description=raw.get("description", ""),
repo_path=raw["repo_path"],
query=raw["query"],
path=raw.get("path"),
expected_items=expected_items,
dataset_name=dataset_name,
)
)
return examples
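# Shape of a single dataset line, reconstructed from the parsing above (all
# field values below are hypothetical):
#   {"id": "example-001",
#    "description": "optional human-readable note",
#    "repo_path": "some_repo_dir",
#    "query": "where is the retry limit configured?",
#    "path": null,
#    "expected_items": [
#        {"file_path": "src/client.py",
#         "must_include_substrings": ["RETRY_LIMIT", "def with_retries"]}]}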
def load_all_datasets(
dataset_path: Optional[Path] = None,
dataset_easy_path: Optional[Path] = None,
) -> List[Example]:
"""
Load both main dataset and easy dataset.
Returns combined list of examples with dataset_name tracking.
"""
main_path = dataset_path or DATASET_PATH
easy_path = dataset_easy_path or DATASET_EASY_PATH
examples: List[Example] = []
# Load main dataset
if main_path.exists():
main_examples = load_dataset(main_path, dataset_name="main")
examples.extend(main_examples)
print(f"Loaded {len(main_examples)} examples from main dataset: {main_path}")
else:
print(f"Warning: Main dataset not found at {main_path}")
# Load easy dataset
if easy_path.exists():
easy_examples = load_dataset(easy_path, dataset_name="easy")
examples.extend(easy_examples)
print(f"Loaded {len(easy_examples)} examples from easy dataset: {easy_path}")
else:
print(f"Warning: Easy dataset not found at {easy_path}")
return examples
def is_match(pred_item: Dict[str, str], gold_item: ExpectedItem) -> bool:
"""
Check if a predicted item FULLY matches a gold item.
Requires: file path match + ALL required substrings present.
"""
if pred_item.get("file_path") != gold_item.file_path:
return False
content = pred_item.get("content", "")
return all(substr in content for substr in gold_item.must_include_substrings)
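# Minimal illustration of the strict criterion above (hypothetical data):
#   gold = ExpectedItem("src/app.py", ["def handle", "return Response"])
#   is_match({"file_path": "src/app.py", "content": "def handle(req): ..."}, gold)
#   -> False, because "return Response" is not in the snippet content.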
def analyze_file_match(
pred_items: List[Dict[str, str]],
gold_item: ExpectedItem
) -> FileMatchDetail:
"""
Analyze how well we matched a single expected file.
Returns detailed info about what was found vs missing.
"""
# Find all predictions for this file
matching_preds = [
p for p in pred_items
if p.get("file_path") == gold_item.file_path
]
if not matching_preds:
# File not found at all
return FileMatchDetail(
file_path=gold_item.file_path,
found=False,
full_match=False,
substrings_required=len(gold_item.must_include_substrings),
substrings_found=0,
missing_substrings=list(gold_item.must_include_substrings),
)
# Combine content from all snippets of this file
combined_content = "\n".join(p.get("content", "") for p in matching_preds)
# Check each required substring
found_substrings = []
missing_substrings = []
for substr in gold_item.must_include_substrings:
if substr in combined_content:
found_substrings.append(substr)
else:
missing_substrings.append(substr)
return FileMatchDetail(
file_path=gold_item.file_path,
found=True,
full_match=len(missing_substrings) == 0,
substrings_required=len(gold_item.must_include_substrings),
substrings_found=len(found_substrings),
missing_substrings=missing_substrings,
)
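# Continuing the hypothetical example from is_match above: a single snippet from
# "src/app.py" containing only "def handle" would yield a FileMatchDetail with
# found=True, full_match=False, substrings_found=1 of substrings_required=2, and
# missing_substrings=["return Response"] -- partial credit instead of a miss.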
def evaluate_example(
example: Example,
repos_base_path: Path = REPOS_BASE_PATH,
searcher_type: Optional[str] = None,
) -> QueryMetrics:
"""Evaluate a single example and return detailed metrics."""
import os
from src.core import SearcherType, _SEARCHER_INSTANCES
repo_path = repos_base_path / example.repo_path / "repo"
# Determine searcher type
s_type = None
type_str = searcher_type or os.getenv("SEMANTIC_SEARCHER", "")
if type_str:
try:
s_type = SearcherType(type_str)
# Clear cache to ensure fresh searcher
_SEARCHER_INSTANCES.clear()
except ValueError:
pass
# Run the search
result = run_semantic_search(
query=example.query,
repo_path=str(repo_path),
path=example.path,
searcher_type=s_type,
)
predicted = result.get("items", [])[:K]
gold = example.expected_items
execution_time_ms = result.get("execution_time_ms", 0.0) or 0.0
# Track which gold items have been matched
used_gold = [False] * len(gold)
# Count matches
true_positives = 0
false_positives = 0
# Token statistics
token_stats = TokenStats()
    # Greedily match each predicted snippet to the first unmatched gold item it
    # fully satisfies (one-to-one matching)
for pred_item in predicted:
content = pred_item.get("content", "")
token_stats.add_snippet(content)
matched = False
for i, gold_item in enumerate(gold):
if used_gold[i]:
continue
if is_match(pred_item, gold_item):
used_gold[i] = True
true_positives += 1
matched = True
break
if not matched:
false_positives += 1
false_negatives = used_gold.count(False)
# Calculate standard metrics
precision = calculate_precision(true_positives, false_positives)
recall = calculate_recall(true_positives, false_negatives)
f1 = calculate_f1(precision, recall)
success = true_positives > 0
# NEW: Calculate granular metrics
file_details = []
files_found = 0
files_fully_matched = 0
files_partially_matched = 0
total_substrings_required = 0
total_substrings_found = 0
for gold_item in gold:
detail = analyze_file_match(predicted, gold_item)
file_details.append(detail)
total_substrings_required += detail.substrings_required
total_substrings_found += detail.substrings_found
if detail.found:
files_found += 1
if detail.full_match:
files_fully_matched += 1
else:
files_partially_matched += 1
return QueryMetrics(
query_id=example.id,
precision=precision,
recall=recall,
f1=f1,
success=success,
token_stats=token_stats,
true_positives=true_positives,
false_positives=false_positives,
false_negatives=false_negatives,
execution_time_ms=execution_time_ms,
# NEW granular metrics
files_expected=len(gold),
files_found=files_found,
files_fully_matched=files_fully_matched,
files_partially_matched=files_partially_matched,
total_substrings_required=total_substrings_required,
total_substrings_found=total_substrings_found,
file_details=file_details,
)
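# On-disk layout assumed by evaluate_example:
#   data/issues/<example.repo_path>/repo/   <- checked-out repository to search
# A minimal programmatic call (assuming the dataset and checkout exist):
#   qm = evaluate_example(load_dataset(DATASET_PATH)[0])
#   print(qm.f1, qm.substring_coverage)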
def evaluate_example_single_run(
example: Example,
run_id: int,
repos_base_path: Path = REPOS_BASE_PATH,
searcher_type: Optional[str] = None,
) -> Tuple[int, QueryMetrics]:
"""Evaluate a single example once, returning run_id for ordering."""
qm = evaluate_example(example, repos_base_path, searcher_type)
return run_id, qm
def evaluate_example_multiple_runs(
example: Example,
repos_base_path: Path = REPOS_BASE_PATH,
searcher_type: Optional[str] = None,
num_runs: int = NUM_RUNS,
max_workers: int = MAX_WORKERS,
) -> QueryStabilityMetrics:
"""
Evaluate a single example multiple times in parallel.
Returns stability metrics across all runs.
"""
run_results: List[QueryMetrics] = [None] * num_runs # type: ignore
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(
evaluate_example_single_run,
example,
run_id,
repos_base_path,
searcher_type,
): run_id
for run_id in range(num_runs)
}
for future in as_completed(futures):
run_id, qm = future.result()
run_results[run_id] = qm
return aggregate_query_runs(example.id, run_results)
def _format_stability(stab: MetricStability) -> str:
"""Format a MetricStability for compact display."""
stability_icon = "✓" if stab.is_stable else "⚠"
return f"{stab.mean:.3f}±{stab.std:.3f} {stability_icon}"
def print_query_result(qm: QueryMetrics, verbose: bool = False) -> None:
"""Print results for a single query (legacy single-run mode)."""
status = "✓" if qm.success else "✗"
latency_status = "⚡" if qm.meets_latency_target else "🐢"
# Basic info
print(
f"[{qm.query_id}] {status} "
f"P={qm.precision:.3f} R={qm.recall:.3f} F1={qm.f1:.3f} "
f"| TP={qm.true_positives} FP={qm.false_positives} FN={qm.false_negatives} "
f"| {latency_status} {qm.execution_time_seconds:.1f}s"
)
# NEW: Granular info
file_icon = "📁" if qm.files_found > 0 else "❌"
print(
f" {file_icon} Files: {qm.files_found}/{qm.files_expected} found "
f"({qm.files_fully_matched} full, {qm.files_partially_matched} partial) "
f"| Substrings: {qm.total_substrings_found}/{qm.total_substrings_required} "
f"({qm.substring_coverage:.0%})"
)
if verbose:
print(
f" Tokens: {qm.token_stats.total_chars} chars, "
f"{qm.token_stats.total_lines} lines, "
f"{qm.token_stats.snippet_count} snippets"
)
# Show per-file details
for detail in qm.file_details:
if detail.full_match:
icon = "✅"
elif detail.found:
icon = "🟡" # Partial match
else:
icon = "❌"
print(
f" {icon} {detail.file_path}: "
f"{detail.substrings_found}/{detail.substrings_required} substrings"
)
if detail.missing_substrings and not detail.full_match:
missing_preview = detail.missing_substrings[:2]
if len(detail.missing_substrings) > 2:
missing_preview.append("...")
print(f" Missing: {missing_preview}")
def print_stability_query_result(qs: QueryStabilityMetrics, verbose: bool = False) -> None:
"""Print results for a single query with stability metrics."""
stability_icon = "🔒" if qs.is_stable else "⚠️"
# Basic info with stability
print(
f"[{qs.query_id}] {stability_icon} "
f"P={_format_stability(qs.precision)} "
f"R={_format_stability(qs.recall)} "
f"F1={_format_stability(qs.f1)}"
)
# Success rate and timing
print(
f" Success: {qs.success_rate.mean:.0%} | "
f"Time: {qs.execution_time.mean/1000:.1f}s±{qs.execution_time.std/1000:.1f}s | "
f"Stability: {qs.overall_stability_score:.1%}"
)
# Granular metrics
print(
f" 📁 FileDiscovery: {_format_stability(qs.file_discovery_rate)} | "
f"SubstringCoverage: {_format_stability(qs.substring_coverage)}"
)
if verbose:
print(f" Runs: {qs.num_runs}")
# Show individual run results
for i, qm in enumerate(qs.run_results):
status = "✓" if qm.success else "✗"
print(
f" Run {i+1}: {status} "
f"P={qm.precision:.3f} R={qm.recall:.3f} F1={qm.f1:.3f} "
f"| {qm.execution_time_seconds:.1f}s"
)
def print_summary(agg: AggregatedMetrics) -> None:
"""Print aggregated metrics summary."""
print("\n" + "=" * 70)
print("EVALUATION SUMMARY")
print("=" * 70)
print(f"\nQueries evaluated: {agg.num_queries}")
print("\n--- Standard Metrics (Strict Matching) ---")
print(f" Precision@{K}: {agg.macro_precision:.4f}")
print(f" Recall@{K}: {agg.macro_recall:.4f}")
print(f" F1@{K}: {agg.macro_f1:.4f}")
print(f" Success Rate@{K}: {agg.success_rate:.4f}")
print("\n--- Granular Metrics (Partial Credit) ---")
print(f" File Discovery Rate: {agg.avg_file_discovery_rate:.4f} "
f"({agg.total_files_found}/{agg.total_files_expected} files found)")
print(f" Substring Coverage: {agg.avg_substring_coverage:.4f} "
f"(avg % of required substrings found)")
print(f" Full Match Rate: {agg.total_files_fully_matched}/{agg.total_files_expected} "
f"({agg.total_files_fully_matched/max(1,agg.total_files_expected):.1%} of expected files)")
print(f" Partial Match Rate: {agg.total_files_partially_matched}/{agg.total_files_expected} "
f"({agg.avg_partial_match_rate:.1%} - found file but incomplete)")
print("\n--- Latency Statistics (target: <10s) ---")
print(f" Total time: {agg.total_time_seconds:.1f}s")
print(f" Avg time per query: {agg.avg_time_seconds:.1f}s")
print(f" Min time: {agg.min_time_ms / 1000:.1f}s")
print(f" Max time: {agg.max_time_ms / 1000:.1f}s")
print(f" Queries under 10s: {agg.queries_meeting_latency_target}/{agg.num_queries} ({agg.latency_target_rate:.1%})")
print("\n--- Token Statistics ---")
print(f" Total chars returned: {agg.total_token_stats.total_chars:,}")
print(f" Total lines returned: {agg.total_token_stats.total_lines:,}")
print(f" Total snippets returned: {agg.total_token_stats.snippet_count}")
print(f" Avg chars per query: {agg.avg_chars_per_query:,.1f}")
print(f" Avg lines per query: {agg.avg_lines_per_query:.1f}")
print(f" Avg snippets per query: {agg.avg_snippets_per_query:.1f}")
print(f" Avg chars per snippet: {agg.total_token_stats.avg_chars_per_snippet:,.1f}")
print(f" Avg lines per snippet: {agg.total_token_stats.avg_lines_per_snippet:.1f}")
# Quality assessment
print("\n--- Quality Assessment ---")
if agg.is_perfect():
print(" ✓ PERFECT: All queries have P=1.0 and R=1.0")
else:
if agg.macro_precision < 1.0:
print(f" ✗ Precision < 1.0: Some irrelevant snippets returned")
if agg.macro_recall < 1.0:
print(f" ✗ Recall < 1.0: Some relevant snippets missed")
# Interpretation help
print("\n--- Interpretation Guide ---")
if agg.avg_file_discovery_rate > agg.success_rate:
gap = agg.avg_file_discovery_rate - agg.success_rate
print(f" 💡 File discovery ({agg.avg_file_discovery_rate:.1%}) > Success rate ({agg.success_rate:.1%})")
print(f" → Finding correct files but snippets too small to contain all substrings")
print(f" → Consider larger context or merging snippets from same file")
if agg.avg_substring_coverage > 0.5 and agg.success_rate < 0.5:
print(f" 💡 Good substring coverage ({agg.avg_substring_coverage:.1%}) but low success ({agg.success_rate:.1%})")
print(f" → Content is fragmented across multiple small snippets")
print(f" → Consider returning larger file chunks")
if agg.total_files_partially_matched > agg.total_files_fully_matched:
print(f" 💡 More partial matches ({agg.total_files_partially_matched}) than full ({agg.total_files_fully_matched})")
print(f" → Snippets need to be larger to capture all required content")
# Warnings
print("\n--- Warnings ---")
warnings_printed = False
if agg.latency_target_rate < 1.0:
print(f" ⚠ Latency target missed: {100 - agg.latency_target_rate * 100:.0f}% of queries took >10s")
warnings_printed = True
if agg.total_token_stats.avg_lines_per_snippet > 200:
print(f" ⚠ Very large snippets ({agg.total_token_stats.avg_lines_per_snippet:.0f} lines avg)")
print(" May be returning too much irrelevant code.")
warnings_printed = True
if agg.success_rate < 0.5:
print(f" ⚠ Low success rate ({agg.success_rate:.1%})")
print(" Many queries return no fully matching results.")
warnings_printed = True
if not warnings_printed:
print(" ✓ No warnings")
def print_stability_summary(agg: AggregatedStabilityMetrics) -> None:
"""Print aggregated stability metrics summary."""
print("\n" + "=" * 70)
print(f"STABILITY EVALUATION SUMMARY ({agg.num_runs_per_query} runs per query)")
print("=" * 70)
print(f"\nQueries evaluated: {agg.num_queries}")
print(f"Total runs: {agg.num_queries * agg.num_runs_per_query}")
print("\n--- Averaged Metrics (mean ± std) ---")
print(f" Precision@{K}: {agg.precision}")
print(f" Recall@{K}: {agg.recall}")
print(f" F1@{K}: {agg.f1}")
print(f" Success Rate@{K}: {agg.success_rate}")
print(f" File Discovery Rate: {agg.file_discovery_rate}")
print(f" Substring Coverage: {agg.substring_coverage}")
print("\n--- Latency Statistics ---")
print(f" Avg time per query: {agg.execution_time.mean/1000:.2f}s ± {agg.execution_time.std/1000:.2f}s")
print(f" Min avg time: {agg.execution_time.min_val/1000:.2f}s")
print(f" Max avg time: {agg.execution_time.max_val/1000:.2f}s")
print("\n--- Stability Assessment ---")
print(f" Overall Stability Score: {agg.avg_stability_score:.1%}")
print(f" Stable Queries: {agg.stable_queries_count}/{agg.num_queries} "
f"({agg.stable_queries_count/max(1,agg.num_queries):.1%})")
# Per-metric stability
print("\n--- Per-Metric Stability (CV = Coefficient of Variation) ---")
metrics_info = [
("Precision", agg.precision),
("Recall", agg.recall),
("F1", agg.f1),
("Success Rate", agg.success_rate),
("File Discovery", agg.file_discovery_rate),
("Substring Coverage", agg.substring_coverage),
]
for name, metric in metrics_info:
stability_icon = "✓" if metric.is_stable else "⚠"
print(f" {stability_icon} {name:20s} CV={metric.cv:6.1%} "
f"range=[{metric.min_val:.3f}, {metric.max_val:.3f}]")
# Warnings
print("\n--- Stability Warnings ---")
warnings_printed = False
unstable_metrics = [(n, m) for n, m in metrics_info if not m.is_stable]
if unstable_metrics:
print(f" ⚠ High variance detected in {len(unstable_metrics)} metric(s):")
for name, metric in unstable_metrics:
print(f" - {name}: CV={metric.cv:.1%} (>10% threshold)")
warnings_printed = True
if agg.stable_queries_count < agg.num_queries * 0.8:
unstable_count = agg.num_queries - agg.stable_queries_count
print(f" ⚠ {unstable_count} queries show unstable results across runs")
print(" Consider investigating these queries for non-deterministic behavior")
warnings_printed = True
if agg.execution_time.cv > 0.5:
print(f" ⚠ High latency variance: CV={agg.execution_time.cv:.1%}")
print(" Execution time varies significantly between runs")
warnings_printed = True
if not warnings_printed:
print(" ✓ Results are stable across all runs")
# Interpretation
print("\n--- Interpretation Guide ---")
if agg.avg_stability_score > 0.95:
print(" 🔒 Excellent stability: Results are highly reproducible")
elif agg.avg_stability_score > 0.85:
print(" ✓ Good stability: Results are mostly reproducible with minor variance")
elif agg.avg_stability_score > 0.7:
print(" ⚠ Moderate stability: Some variance observed, results may differ between runs")
else:
print(" ❌ Poor stability: Results vary significantly between runs")
print(" This may indicate non-deterministic search behavior or system issues")
def export_results_json(agg: AggregatedMetrics, output_path: Path) -> None:
"""Export detailed results to JSON file."""
results = {
"summary": {
"num_queries": agg.num_queries,
"macro_precision": agg.macro_precision,
"macro_recall": agg.macro_recall,
"macro_f1": agg.macro_f1,
"success_rate": agg.success_rate,
"is_perfect": agg.is_perfect(),
# NEW granular metrics
"avg_file_discovery_rate": agg.avg_file_discovery_rate,
"avg_substring_coverage": agg.avg_substring_coverage,
"total_files_expected": agg.total_files_expected,
"total_files_found": agg.total_files_found,
"total_files_fully_matched": agg.total_files_fully_matched,
"total_files_partially_matched": agg.total_files_partially_matched,
},
"latency_stats": {
"total_time_ms": agg.total_time_ms,
"avg_time_ms": agg.avg_time_ms,
"min_time_ms": agg.min_time_ms,
"max_time_ms": agg.max_time_ms,
"queries_under_10s": agg.queries_meeting_latency_target,
"latency_target_rate": agg.latency_target_rate,
},
"token_stats": {
"total_chars": agg.total_token_stats.total_chars,
"total_lines": agg.total_token_stats.total_lines,
"total_snippets": agg.total_token_stats.snippet_count,
"avg_chars_per_query": agg.avg_chars_per_query,
"avg_lines_per_query": agg.avg_lines_per_query,
"avg_snippets_per_query": agg.avg_snippets_per_query,
},
"queries": [
{
"id": qm.query_id,
"precision": qm.precision,
"recall": qm.recall,
"f1": qm.f1,
"success": qm.success,
"true_positives": qm.true_positives,
"false_positives": qm.false_positives,
"false_negatives": qm.false_negatives,
"execution_time_ms": qm.execution_time_ms,
# NEW granular
"files_expected": qm.files_expected,
"files_found": qm.files_found,
"files_fully_matched": qm.files_fully_matched,
"files_partially_matched": qm.files_partially_matched,
"file_discovery_rate": qm.file_discovery_rate,
"substring_coverage": qm.substring_coverage,
"file_details": [
{
"file_path": fd.file_path,
"found": fd.found,
"full_match": fd.full_match,
"substrings_found": fd.substrings_found,
"substrings_required": fd.substrings_required,
"substring_coverage": fd.substring_coverage,
"missing_substrings": fd.missing_substrings,
}
for fd in qm.file_details
],
}
for qm in agg.query_results
],
}
with output_path.open("w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"\nResults exported to: {output_path}")
def _metric_stability_to_dict(ms: MetricStability) -> dict:
"""Convert MetricStability to dict for JSON export."""
return {
"mean": ms.mean,
"std": ms.std,
"min": ms.min_val,
"max": ms.max_val,
"cv": ms.cv,
"is_stable": ms.is_stable,
}
def export_stability_results_json(agg: AggregatedStabilityMetrics, output_path: Path) -> None:
"""Export stability results to JSON file."""
results = {
"summary": {
"num_queries": agg.num_queries,
"num_runs_per_query": agg.num_runs_per_query,
"total_runs": agg.num_queries * agg.num_runs_per_query,
"avg_stability_score": agg.avg_stability_score,
"stable_queries_count": agg.stable_queries_count,
"stable_queries_rate": agg.stable_queries_count / max(1, agg.num_queries),
},
"metrics": {
"precision": _metric_stability_to_dict(agg.precision),
"recall": _metric_stability_to_dict(agg.recall),
"f1": _metric_stability_to_dict(agg.f1),
"success_rate": _metric_stability_to_dict(agg.success_rate),
"file_discovery_rate": _metric_stability_to_dict(agg.file_discovery_rate),
"substring_coverage": _metric_stability_to_dict(agg.substring_coverage),
"execution_time_ms": _metric_stability_to_dict(agg.execution_time),
},
"queries": [
{
"id": qs.query_id,
"num_runs": qs.num_runs,
"overall_stability_score": qs.overall_stability_score,
"is_stable": qs.is_stable,
"metrics": {
"precision": _metric_stability_to_dict(qs.precision),
"recall": _metric_stability_to_dict(qs.recall),
"f1": _metric_stability_to_dict(qs.f1),
"success_rate": _metric_stability_to_dict(qs.success_rate),
"file_discovery_rate": _metric_stability_to_dict(qs.file_discovery_rate),
"substring_coverage": _metric_stability_to_dict(qs.substring_coverage),
"execution_time_ms": _metric_stability_to_dict(qs.execution_time),
},
"runs": [
{
"run_id": i,
"precision": qm.precision,
"recall": qm.recall,
"f1": qm.f1,
"success": qm.success,
"execution_time_ms": qm.execution_time_ms,
"file_discovery_rate": qm.file_discovery_rate,
"substring_coverage": qm.substring_coverage,
}
for i, qm in enumerate(qs.run_results)
],
}
for qs in agg.query_stability_results
],
}
with output_path.open("w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"\nStability results exported to: {output_path}")
def _split_metrics_by_dataset(
query_metrics: List[QueryMetrics],
examples: List[Example],
) -> Dict[str, Tuple[List[QueryMetrics], List[Example]]]:
"""Split metrics and examples by dataset name."""
dataset_map: Dict[str, Tuple[List[QueryMetrics], List[Example]]] = {}
for qm, ex in zip(query_metrics, examples):
dataset_name = ex.dataset_name
if dataset_name not in dataset_map:
dataset_map[dataset_name] = ([], [])
dataset_map[dataset_name][0].append(qm)
dataset_map[dataset_name][1].append(ex)
return dataset_map
def _print_dataset_summary(
dataset_name: str,
agg: AggregatedMetrics,
examples: List[Example],
) -> None:
"""Print summary for a specific dataset."""
print("\n" + "=" * 70)
print(f"DATASET: {dataset_name.upper()} ({len(examples)} examples)")
print("=" * 70)
print_summary(agg)
def main(
dataset_path: Optional[Path] = None,
dataset_easy_path: Optional[Path] = None,
repos_base_path: Optional[Path] = None,
output_json: Optional[Path] = None,
verbose: bool = False,
use_both_datasets: bool = True,
) -> AggregatedMetrics:
"""
Run evaluation (single run per query) and return aggregated metrics.
Args:
dataset_path: Path to main dataset (default: data/dataset.jsonl)
dataset_easy_path: Path to easy dataset (default: data/dataset_easy.jsonl)
repos_base_path: Base path for repositories
output_json: Path to export JSON results
verbose: Print detailed per-query statistics
use_both_datasets: If True, load both main and easy datasets (default: True)
"""
repos_path = repos_base_path or REPOS_BASE_PATH
if use_both_datasets:
examples = load_all_datasets(dataset_path, dataset_easy_path)
else:
path = dataset_path or DATASET_PATH
examples = load_dataset(path)
print(f"\nRunning evaluation on {len(examples)} total examples...")
print(f"Repositories base path: {repos_path.absolute()}\n")
query_metrics: List[QueryMetrics] = []
for ex in examples:
qm = evaluate_example(ex, repos_base_path=repos_path)
query_metrics.append(qm)
        dataset_prefix = f"[{ex.dataset_name}] " if use_both_datasets else ""
        print(dataset_prefix, end="")
        print_query_result(qm, verbose=verbose)
# Aggregate all metrics
aggregated = aggregate_metrics(query_metrics)
# Print overall summary
print_summary(aggregated)
# Print per-dataset summaries if using both datasets
if use_both_datasets:
dataset_map = _split_metrics_by_dataset(query_metrics, examples)
for dataset_name, (metrics, dataset_examples) in sorted(dataset_map.items()):
dataset_agg = aggregate_metrics(metrics)
_print_dataset_summary(dataset_name, dataset_agg, dataset_examples)
# Export if requested
if output_json:
export_results_json(aggregated, output_json)
return aggregated
def _split_stability_by_dataset(
query_stabilities: List[QueryStabilityMetrics],
examples: List[Example],
) -> Dict[str, Tuple[List[QueryStabilityMetrics], List[Example]]]:
"""Split stability metrics and examples by dataset name."""
dataset_map: Dict[str, Tuple[List[QueryStabilityMetrics], List[Example]]] = {}
for qs, ex in zip(query_stabilities, examples):
dataset_name = ex.dataset_name
if dataset_name not in dataset_map:
dataset_map[dataset_name] = ([], [])
dataset_map[dataset_name][0].append(qs)
dataset_map[dataset_name][1].append(ex)
return dataset_map
def _print_dataset_stability_summary(
dataset_name: str,
agg: AggregatedStabilityMetrics,
examples: List[Example],
) -> None:
"""Print stability summary for a specific dataset."""
print("\n" + "=" * 70)
print(f"DATASET: {dataset_name.upper()} ({len(examples)} examples)")
print("=" * 70)
print_stability_summary(agg)
def main_with_stability(
dataset_path: Optional[Path] = None,
dataset_easy_path: Optional[Path] = None,
repos_base_path: Optional[Path] = None,
output_json: Optional[Path] = None,
verbose: bool = False,
num_runs: int = NUM_RUNS,
max_workers: int = MAX_WORKERS,
use_both_datasets: bool = True,
) -> AggregatedStabilityMetrics:
"""
Run evaluation with multiple parallel runs per query.
Returns stability metrics showing consistency across runs.
Args:
dataset_path: Path to main dataset (default: data/dataset.jsonl)
dataset_easy_path: Path to easy dataset (default: data/dataset_easy.jsonl)
repos_base_path: Base path for repositories
output_json: Path to export JSON results
verbose: Print detailed per-query statistics
num_runs: Number of runs per query
max_workers: Maximum parallel workers
use_both_datasets: If True, load both main and easy datasets (default: True)
"""
repos_path = repos_base_path or REPOS_BASE_PATH
if use_both_datasets:
examples = load_all_datasets(dataset_path, dataset_easy_path)
else:
path = dataset_path or DATASET_PATH
examples = load_dataset(path)
print(f"\nRunning STABILITY evaluation on {len(examples)} total examples...")
print(f"Each query will be run {num_runs} times in parallel (max {max_workers} workers)")
print(f"Repositories base path: {repos_path.absolute()}\n")
query_stabilities: List[QueryStabilityMetrics] = []
for i, ex in enumerate(examples, 1):
dataset_prefix = f"[{ex.dataset_name}]" if use_both_datasets else ""
print(f"{dataset_prefix} [{i}/{len(examples)}] Evaluating query '{ex.id}' ({num_runs} runs)...")
qs = evaluate_example_multiple_runs(
example=ex,
repos_base_path=repos_path,
num_runs=num_runs,
max_workers=max_workers,
)
query_stabilities.append(qs)
print_stability_query_result(qs, verbose=verbose)
print()
# Aggregate all stability metrics
aggregated = aggregate_stability_metrics(query_stabilities, num_runs)
# Print overall summary
print_stability_summary(aggregated)
# Print per-dataset summaries if using both datasets
if use_both_datasets:
dataset_map = _split_stability_by_dataset(query_stabilities, examples)
for dataset_name, (stabilities, dataset_examples) in sorted(dataset_map.items()):
dataset_agg = aggregate_stability_metrics(stabilities, num_runs)
_print_dataset_stability_summary(dataset_name, dataset_agg, dataset_examples)
# Export if requested
if output_json:
export_stability_results_json(aggregated, output_json)
return aggregated
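# Example invocations (the module path "evals.evaluate" is hypothetical; run from
# the project root so that `src.core` and the relative `.metrics` import resolve):
#   python -m evals.evaluate --verbose --output results.json
#   python -m evals.evaluate --stability --runs 5 --workers 5
#   python -m evals.evaluate --single-dataset --dataset data/dataset.jsonl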
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Evaluate semantic search quality")
parser.add_argument(
"--dataset",
type=Path,
default=DATASET_PATH,
help="Path to main JSONL dataset file (default: data/dataset.jsonl)",
)
parser.add_argument(
"--dataset-easy",
type=Path,
default=DATASET_EASY_PATH,
help="Path to easy JSONL dataset file (default: data/dataset_easy.jsonl)",
)
parser.add_argument(
"--single-dataset",
action="store_true",
help="Use only the main dataset (disable automatic loading of easy dataset)",
)
parser.add_argument(
"--repos",
type=Path,
default=REPOS_BASE_PATH,
help="Base path for evaluation repositories",
)
parser.add_argument(
"--output",
type=Path,
default=None,
help="Path to export JSON results",
)
parser.add_argument(
"-v", "--verbose",
action="store_true",
help="Print detailed per-query statistics",
)
parser.add_argument(
"--stability",
action="store_true",
help="Run each query multiple times to measure stability",
)
parser.add_argument(
"--runs",
type=int,
default=NUM_RUNS,
help=f"Number of runs per query in stability mode (default: {NUM_RUNS})",
)
parser.add_argument(
"--workers",
type=int,
default=MAX_WORKERS,
help=f"Maximum parallel workers for stability runs (default: {MAX_WORKERS})",
)
args = parser.parse_args()
use_both = not args.single_dataset
if args.stability:
result = main_with_stability(
dataset_path=args.dataset,
dataset_easy_path=args.dataset_easy,
repos_base_path=args.repos,
output_json=args.output,
verbose=args.verbose,
num_runs=args.runs,
max_workers=args.workers,
use_both_datasets=use_both,
)
# Exit with error code if low stability
if result.avg_stability_score < 0.5:
print("\n❌ Stability score below 50%, exiting with error")
sys.exit(1)
else:
result = main(
dataset_path=args.dataset,
dataset_easy_path=args.dataset_easy,
repos_base_path=args.repos,
output_json=args.output,
verbose=args.verbose,
use_both_datasets=use_both,
)
# Exit with error code if no queries succeeded
if result.success_rate == 0 and result.num_queries > 0:
sys.exit(1)