"""
Metrics for evaluating semantic search quality.
Key metrics:
- Precision@K: Fraction of returned results that are relevant
- Recall@K: Fraction of all relevant items that were returned
- F1@K: Harmonic mean of Precision and Recall
- Success@K: Whether at least one relevant item was returned
NEW - Granular metrics:
- File Discovery Rate: % of expected files found (regardless of content)
- Substring Coverage: % of required substrings found across expected files
- Partial Match Rate: % of expected files found but missing some required content
Stability metrics (for multiple runs):
- Standard deviation of metrics across runs
- Coefficient of variation (CV = std/mean)
- Min/Max values across runs
Token metrics (for monitoring "token bloat"):
- Total characters in returned snippets
- Total lines in returned snippets
- Average snippet length
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import List
@dataclass
class TokenStats:
"""Statistics about token/character usage in returned snippets."""
total_chars: int = 0
total_lines: int = 0
snippet_count: int = 0
@property
def avg_chars_per_snippet(self) -> float:
"""Average characters per snippet."""
if self.snippet_count == 0:
return 0.0
return self.total_chars / self.snippet_count
@property
def avg_lines_per_snippet(self) -> float:
"""Average lines per snippet."""
if self.snippet_count == 0:
return 0.0
return self.total_lines / self.snippet_count
def add_snippet(self, content: str) -> None:
"""Add a snippet's statistics."""
self.total_chars += len(content)
self.total_lines += (content.count('\n') + 1) if content else 0
self.snippet_count += 1
def __add__(self, other: "TokenStats") -> "TokenStats":
"""Combine two TokenStats objects."""
return TokenStats(
total_chars=self.total_chars + other.total_chars,
total_lines=self.total_lines + other.total_lines,
snippet_count=self.snippet_count + other.snippet_count,
)
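# Usage sketch for TokenStats (illustrative snippets only): stats accumulate
# per snippet and can be merged across queries with `+`.
#   stats = TokenStats()
#   stats.add_snippet("def foo():\n    return 1")  # 23 chars, 2 lines
#   stats.add_snippet("x = 1")                     # 5 chars, 1 line
#   stats.avg_chars_per_snippet  -> 14.0
#   merged = stats + TokenStats(total_chars=100, total_lines=10, snippet_count=2)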
@dataclass
class FileMatchDetail:
"""Detailed matching info for a single expected file."""
file_path: str
found: bool # Was the file in returned results?
full_match: bool # Did it contain ALL required substrings?
substrings_required: int
substrings_found: int
missing_substrings: List[str] = field(default_factory=list)
@property
def substring_coverage(self) -> float:
"""What fraction of required substrings were found."""
if self.substrings_required == 0:
return 1.0
return self.substrings_found / self.substrings_required
@property
def is_partial_match(self) -> bool:
"""Found file but not all content."""
return self.found and not self.full_match
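# Illustrative example (hypothetical file path): a file that was returned but
# is missing one of its required substrings counts as a partial match.
#   detail = FileMatchDetail(
#       file_path="src/search/index.py", found=True, full_match=False,
#       substrings_required=3, substrings_found=2,
#       missing_substrings=["build_index"],
#   )
#   detail.substring_coverage  -> 0.666...
#   detail.is_partial_match    -> True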
@dataclass
class QueryMetrics:
"""Metrics for a single query evaluation."""
query_id: str
precision: float
recall: float
f1: float
success: bool
token_stats: TokenStats = field(default_factory=TokenStats)
# Detailed counts for debugging
true_positives: int = 0
false_positives: int = 0
false_negatives: int = 0
# Timing metrics (in milliseconds)
execution_time_ms: float = 0.0
# NEW: Granular metrics
files_expected: int = 0
files_found: int = 0 # Files in results (regardless of content)
files_fully_matched: int = 0 # Files with ALL substrings
files_partially_matched: int = 0 # Files found but missing some substrings
total_substrings_required: int = 0
total_substrings_found: int = 0
# Detailed per-file breakdown
file_details: List[FileMatchDetail] = field(default_factory=list)
@property
def success_float(self) -> float:
"""Success as float (1.0 or 0.0)."""
return 1.0 if self.success else 0.0
@property
def execution_time_seconds(self) -> float:
"""Execution time in seconds."""
return self.execution_time_ms / 1000.0
@property
def meets_latency_target(self) -> bool:
"""Check if execution time is under 10 seconds (task requirement)."""
return self.execution_time_seconds < 10.0
@property
def file_discovery_rate(self) -> float:
"""What fraction of expected files were found (any match)."""
if self.files_expected == 0:
return 0.0
return self.files_found / self.files_expected
@property
def substring_coverage(self) -> float:
"""What fraction of all required substrings were found."""
if self.total_substrings_required == 0:
return 0.0
return self.total_substrings_found / self.total_substrings_required
@property
def partial_match_rate(self) -> float:
"""Fraction of expected files that were partial matches."""
if self.files_expected == 0:
return 0.0
return self.files_partially_matched / self.files_expected
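# Worked example of the granular properties (numbers chosen for illustration):
# with files_expected=4, files_found=3, files_partially_matched=1,
# total_substrings_required=6 and total_substrings_found=5:
#   file_discovery_rate  = 3 / 4 = 0.75
#   partial_match_rate   = 1 / 4 = 0.25
#   substring_coverage   = 5 / 6 ≈ 0.833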
@dataclass
class AggregatedMetrics:
"""Aggregated metrics across all queries."""
num_queries: int
# Macro-averaged metrics (average of per-query metrics)
macro_precision: float
macro_recall: float
macro_f1: float
success_rate: float
# Token statistics
total_token_stats: TokenStats
# Timing statistics (in milliseconds)
total_time_ms: float = 0.0
avg_time_ms: float = 0.0
min_time_ms: float = 0.0
max_time_ms: float = 0.0
queries_meeting_latency_target: int = 0
# NEW: Aggregated granular metrics
avg_file_discovery_rate: float = 0.0
avg_substring_coverage: float = 0.0
avg_partial_match_rate: float = 0.0
total_files_expected: int = 0
total_files_found: int = 0
total_files_fully_matched: int = 0
total_files_partially_matched: int = 0
# Per-query results for detailed analysis
query_results: List[QueryMetrics] = field(default_factory=list)
@property
def avg_chars_per_query(self) -> float:
"""Average characters returned per query."""
if self.num_queries == 0:
return 0.0
return self.total_token_stats.total_chars / self.num_queries
@property
def avg_lines_per_query(self) -> float:
"""Average lines returned per query."""
if self.num_queries == 0:
return 0.0
return self.total_token_stats.total_lines / self.num_queries
@property
def avg_snippets_per_query(self) -> float:
"""Average number of snippets returned per query."""
if self.num_queries == 0:
return 0.0
return self.total_token_stats.snippet_count / self.num_queries
@property
def latency_target_rate(self) -> float:
"""Fraction of queries meeting the <10s latency target."""
if self.num_queries == 0:
return 0.0
return self.queries_meeting_latency_target / self.num_queries
@property
def avg_time_seconds(self) -> float:
"""Average execution time in seconds."""
return self.avg_time_ms / 1000.0
@property
def total_time_seconds(self) -> float:
"""Total execution time in seconds."""
return self.total_time_ms / 1000.0
@property
def overall_file_discovery_rate(self) -> float:
"""Overall file discovery rate across all queries."""
if self.total_files_expected == 0:
return 0.0
return self.total_files_found / self.total_files_expected
def is_perfect(self) -> bool:
"""
Check if results are perfect (Precision=1.0 and Recall=1.0 for all queries).
"""
return (
abs(self.macro_precision - 1.0) < 1e-9 and
abs(self.macro_recall - 1.0) < 1e-9
)
def calculate_precision(true_positives: int, false_positives: int) -> float:
"""Calculate Precision@K = TP / (TP + FP)"""
total = true_positives + false_positives
if total == 0:
return 0.0
return true_positives / total
def calculate_recall(true_positives: int, false_negatives: int) -> float:
"""Calculate Recall@K = TP / (TP + FN)"""
total = true_positives + false_negatives
if total == 0:
return 0.0
return true_positives / total
def calculate_f1(precision: float, recall: float) -> float:
"""Calculate F1 = 2 * P * R / (P + R)"""
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
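# Worked example (illustrative counts, not from any real run): suppose a query
# returns 4 results of which 3 are relevant, and 2 relevant items were missed.
#   calculate_precision(true_positives=3, false_positives=1)  -> 0.75
#   calculate_recall(true_positives=3, false_negatives=2)     -> 0.6
#   calculate_f1(0.75, 0.6)  -> 2 * 0.75 * 0.6 / 1.35 ≈ 0.667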
def aggregate_metrics(query_metrics: List[QueryMetrics]) -> AggregatedMetrics:
"""Aggregate metrics across all queries using macro-averaging."""
if not query_metrics:
return AggregatedMetrics(
num_queries=0,
macro_precision=0.0,
macro_recall=0.0,
macro_f1=0.0,
success_rate=0.0,
total_token_stats=TokenStats(),
total_time_ms=0.0,
avg_time_ms=0.0,
min_time_ms=0.0,
max_time_ms=0.0,
queries_meeting_latency_target=0,
query_results=[],
)
n = len(query_metrics)
total_precision = sum(qm.precision for qm in query_metrics)
total_recall = sum(qm.recall for qm in query_metrics)
total_f1 = sum(qm.f1 for qm in query_metrics)
total_success = sum(1 for qm in query_metrics if qm.success)
total_token_stats = TokenStats()
for qm in query_metrics:
total_token_stats = total_token_stats + qm.token_stats
# Timing statistics
times = [qm.execution_time_ms for qm in query_metrics]
total_time_ms = sum(times)
avg_time_ms = total_time_ms / n
min_time_ms = min(times) if times else 0.0
max_time_ms = max(times) if times else 0.0
queries_meeting_target = sum(1 for qm in query_metrics if qm.meets_latency_target)
# NEW: Aggregate granular metrics
total_file_discovery = sum(qm.file_discovery_rate for qm in query_metrics)
total_substring_coverage = sum(qm.substring_coverage for qm in query_metrics)
total_partial_rate = sum(qm.partial_match_rate for qm in query_metrics)
total_files_expected = sum(qm.files_expected for qm in query_metrics)
total_files_found = sum(qm.files_found for qm in query_metrics)
total_files_fully_matched = sum(qm.files_fully_matched for qm in query_metrics)
total_files_partially_matched = sum(qm.files_partially_matched for qm in query_metrics)
return AggregatedMetrics(
num_queries=n,
macro_precision=total_precision / n,
macro_recall=total_recall / n,
macro_f1=total_f1 / n,
success_rate=total_success / n,
total_token_stats=total_token_stats,
total_time_ms=total_time_ms,
avg_time_ms=avg_time_ms,
min_time_ms=min_time_ms,
max_time_ms=max_time_ms,
queries_meeting_latency_target=queries_meeting_target,
# NEW granular metrics
avg_file_discovery_rate=total_file_discovery / n,
avg_substring_coverage=total_substring_coverage / n,
avg_partial_match_rate=total_partial_rate / n,
total_files_expected=total_files_expected,
total_files_found=total_files_found,
total_files_fully_matched=total_files_fully_matched,
total_files_partially_matched=total_files_partially_matched,
query_results=query_metrics,
)
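# Macro-averaging illustration: each query contributes equally regardless of
# how many expected files it has. Two queries with precisions 1.0 and 0.5
# yield macro_precision = (1.0 + 0.5) / 2 = 0.75, even if the second query
# involved far more files than the first.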
def _calculate_std(values: List[float], mean: float) -> float:
"""Calculate standard deviation."""
if len(values) < 2:
return 0.0
variance = sum((v - mean) ** 2 for v in values) / (len(values) - 1)
return math.sqrt(variance)
def _calculate_cv(std: float, mean: float) -> float:
"""Calculate coefficient of variation (CV = std/mean)."""
if mean == 0:
return 0.0
return std / mean
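# Worked example for the helpers above (illustrative values): for recall
# values [0.8, 0.9, 1.0] across three runs, mean = 0.9, sample variance =
# (0.01 + 0.0 + 0.01) / 2 = 0.01, so std = 0.1 and CV = 0.1 / 0.9 ≈ 0.111.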
@dataclass
class MetricStability:
"""Stability statistics for a single metric across multiple runs."""
mean: float
std: float
min_val: float
max_val: float
cv: float # Coefficient of variation
@property
def range(self) -> float:
"""Range of values (max - min)."""
return self.max_val - self.min_val
@property
def is_stable(self) -> bool:
"""Consider metric stable if CV < 10%."""
return self.cv < 0.1
def __str__(self) -> str:
return f"{self.mean:.4f} ± {self.std:.4f} (CV={self.cv:.1%})"
@dataclass
class QueryStabilityMetrics:
"""Stability metrics for a single query evaluated multiple times."""
query_id: str
num_runs: int
precision: MetricStability
recall: MetricStability
f1: MetricStability
success_rate: MetricStability
execution_time: MetricStability
file_discovery_rate: MetricStability
substring_coverage: MetricStability
# All individual run results
run_results: List[QueryMetrics] = field(default_factory=list)
@property
def overall_stability_score(self) -> float:
"""
Overall stability score (0-1, higher is better).
Based on average of (1 - CV) for key metrics.
"""
metrics = [
self.precision,
self.recall,
self.f1,
self.file_discovery_rate,
self.substring_coverage,
]
scores = [max(0, 1 - m.cv) for m in metrics]
return sum(scores) / len(scores)
@property
def is_stable(self) -> bool:
"""Check if all key metrics are stable."""
return (
self.precision.is_stable and
self.recall.is_stable and
self.f1.is_stable
)
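# Illustration of overall_stability_score (hypothetical CVs): if precision,
# recall, f1, file_discovery_rate and substring_coverage have CVs of
# 0.05, 0.00, 0.03, 0.10 and 0.02, the per-metric scores are
# 0.95, 1.00, 0.97, 0.90 and 0.98, giving an overall score of 0.96.
# is_stable additionally requires the precision, recall and f1 CVs to be
# below 0.1, which holds in this example.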
@dataclass
class AggregatedStabilityMetrics:
"""Aggregated stability metrics across all queries."""
num_queries: int
num_runs_per_query: int
# Aggregated stability for each metric
precision: MetricStability
recall: MetricStability
f1: MetricStability
success_rate: MetricStability
execution_time: MetricStability
file_discovery_rate: MetricStability
substring_coverage: MetricStability
# Average stability score across all queries
avg_stability_score: float
stable_queries_count: int
# Per-query stability results
query_stability_results: List[QueryStabilityMetrics] = field(default_factory=list)
def calculate_metric_stability(values: List[float]) -> MetricStability:
"""Calculate stability metrics for a list of values."""
if not values:
return MetricStability(mean=0.0, std=0.0, min_val=0.0, max_val=0.0, cv=0.0)
mean = sum(values) / len(values)
std = _calculate_std(values, mean)
cv = _calculate_cv(std, mean)
return MetricStability(
mean=mean,
std=std,
min_val=min(values),
max_val=max(values),
cv=cv,
)
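# Usage sketch (illustrative values, continuing the std/CV example above):
#   stability = calculate_metric_stability([0.8, 0.9, 1.0])
#   stability.mean       -> 0.9
#   stability.std        -> 0.1
#   stability.cv         -> ~0.111
#   stability.is_stable  -> False  (CV >= 10%)
#   str(stability)       -> "0.9000 ± 0.1000 (CV=11.1%)"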
def aggregate_query_runs(query_id: str, run_results: List[QueryMetrics]) -> QueryStabilityMetrics:
"""Aggregate multiple runs of the same query into stability metrics."""
if not run_results:
empty_stability = MetricStability(mean=0.0, std=0.0, min_val=0.0, max_val=0.0, cv=0.0)
return QueryStabilityMetrics(
query_id=query_id,
num_runs=0,
precision=empty_stability,
recall=empty_stability,
f1=empty_stability,
success_rate=empty_stability,
execution_time=empty_stability,
file_discovery_rate=empty_stability,
substring_coverage=empty_stability,
run_results=[],
)
return QueryStabilityMetrics(
query_id=query_id,
num_runs=len(run_results),
precision=calculate_metric_stability([r.precision for r in run_results]),
recall=calculate_metric_stability([r.recall for r in run_results]),
f1=calculate_metric_stability([r.f1 for r in run_results]),
success_rate=calculate_metric_stability([r.success_float for r in run_results]),
execution_time=calculate_metric_stability([r.execution_time_ms for r in run_results]),
file_discovery_rate=calculate_metric_stability([r.file_discovery_rate for r in run_results]),
substring_coverage=calculate_metric_stability([r.substring_coverage for r in run_results]),
run_results=run_results,
)
def aggregate_stability_metrics(
query_stabilities: List[QueryStabilityMetrics],
num_runs_per_query: int,
) -> AggregatedStabilityMetrics:
"""Aggregate stability metrics across all queries."""
if not query_stabilities:
empty_stability = MetricStability(mean=0.0, std=0.0, min_val=0.0, max_val=0.0, cv=0.0)
return AggregatedStabilityMetrics(
num_queries=0,
num_runs_per_query=num_runs_per_query,
precision=empty_stability,
recall=empty_stability,
f1=empty_stability,
success_rate=empty_stability,
execution_time=empty_stability,
file_discovery_rate=empty_stability,
substring_coverage=empty_stability,
avg_stability_score=0.0,
stable_queries_count=0,
query_stability_results=[],
)
# Collect mean values across all queries
precision_means = [qs.precision.mean for qs in query_stabilities]
recall_means = [qs.recall.mean for qs in query_stabilities]
f1_means = [qs.f1.mean for qs in query_stabilities]
success_rate_means = [qs.success_rate.mean for qs in query_stabilities]
execution_time_means = [qs.execution_time.mean for qs in query_stabilities]
file_discovery_means = [qs.file_discovery_rate.mean for qs in query_stabilities]
substring_coverage_means = [qs.substring_coverage.mean for qs in query_stabilities]
# Calculate stability scores
stability_scores = [qs.overall_stability_score for qs in query_stabilities]
stable_count = sum(1 for qs in query_stabilities if qs.is_stable)
return AggregatedStabilityMetrics(
num_queries=len(query_stabilities),
num_runs_per_query=num_runs_per_query,
precision=calculate_metric_stability(precision_means),
recall=calculate_metric_stability(recall_means),
f1=calculate_metric_stability(f1_means),
success_rate=calculate_metric_stability(success_rate_means),
execution_time=calculate_metric_stability(execution_time_means),
file_discovery_rate=calculate_metric_stability(file_discovery_means),
substring_coverage=calculate_metric_stability(substring_coverage_means),
avg_stability_score=sum(stability_scores) / len(stability_scores),
stable_queries_count=stable_count,
query_stability_results=query_stabilities,
)
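# ---------------------------------------------------------------------------
# Minimal end-to-end sketch tying the helpers together. The query id and all
# metric values below are made up purely to exercise the aggregation code;
# they do not come from any real evaluation run.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Two simulated runs of the same query with slightly different outcomes.
    run_a = QueryMetrics(
        query_id="q1", precision=1.0, recall=0.5, f1=calculate_f1(1.0, 0.5),
        success=True, true_positives=1, false_positives=0, false_negatives=1,
        execution_time_ms=1200.0, files_expected=2, files_found=1,
        files_fully_matched=1, total_substrings_required=3,
        total_substrings_found=2,
    )
    run_b = QueryMetrics(
        query_id="q1", precision=1.0, recall=1.0, f1=1.0,
        success=True, true_positives=2, false_positives=0, false_negatives=0,
        execution_time_ms=950.0, files_expected=2, files_found=2,
        files_fully_matched=2, total_substrings_required=3,
        total_substrings_found=3,
    )

    # Treat the two runs first as two separate queries (macro aggregation),
    # then as repeated runs of the same query (stability aggregation).
    aggregated = aggregate_metrics([run_a, run_b])
    print(f"macro P/R/F1: {aggregated.macro_precision:.2f} / "
          f"{aggregated.macro_recall:.2f} / {aggregated.macro_f1:.2f}")
    print(f"latency target rate: {aggregated.latency_target_rate:.0%}")

    stability = aggregate_query_runs("q1", [run_a, run_b])
    print(f"recall stability: {stability.recall}")

    overall = aggregate_stability_metrics([stability], num_runs_per_query=2)
    print(f"avg stability score: {overall.avg_stability_score:.3f}")
    print(f"stable queries: {overall.stable_queries_count}/{overall.num_queries}")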