# plot_metrics.py • 4.61 kB
import json
import sys
from pathlib import Path
from typing import Any, Dict, List
import matplotlib.pyplot as plt
import numpy as np
# Default benchmark summaries read when no CLI arguments are supplied.
INPUT_COGNEE: Path = Path("evals/benchmark_summary_cognee.json")
INPUT_COMPETITION: Path = Path("evals/benchmark_summary_competition.json")

# Destination image files, one per chart.
OUT_OPTIMISED: str = "evals/optimized_cognee_configurations.png"
OUT_COMP: str = "evals/comprehensive_metrics_comparison.png"

# Maps each metric name to its bar colour. The insertion order is the order
# in which bars are drawn inside a group and listed in the legend.
METRIC_KEYS: Dict[str, str] = {
    "Human-like Correctness": "#4ade80",  # green
    "DeepEval Correctness": "#818cf8",  # indigo
    "DeepEval F1": "#c084fc",  # light purple
    "DeepEval EM": "#6b7280",  # grey
}

# Y-axis range shared by every chart.
Y_LIM = (0.0, 1.05)
def _load(path: Path) -> List[Dict[str, Any]]:
    """Read benchmark records from a JSON file.

    Two layouts are accepted: a top-level list of records, or a top-level
    object whose ``"data"`` key holds that list.

    Args:
        path: Location of the JSON file.

    Returns:
        The list of records exactly as parsed from the file.

    Raises:
        ValueError: If the document matches neither supported layout.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (which breaks non-ASCII names on some hosts).
    with path.open(encoding="utf-8") as f:
        obj = json.load(f)

    if isinstance(obj, list):
        return obj
    if isinstance(obj, dict) and isinstance(obj.get("data"), list):
        return obj["data"]
    raise ValueError(f"Unsupported format in {path}")
def _extract_matrix(records: List[Dict[str, Any]]):
    """Turn benchmark records into per-metric arrays ready for plotting.

    Each record must carry a ``"system"`` key. For every metric in
    ``METRIC_KEYS`` the record may carry the mean under the metric name and a
    ``[low, high]`` pair under ``"<metric> Error"``.

    Args:
        records: One dict per system.

    Returns:
        A 4-tuple ``(systems, means, error_minus, error_plus)``:
            systems     -> list[str], in record order
            means       -> dict[metric] = array(len(systems))
            error_minus -> dict[metric] = array(len(systems)), ``mean - low``
            error_plus  -> dict[metric] = array(len(systems)), ``high - mean``

        A mean that is absent or ``None`` (JSON ``null``) is filled with 0.
        An error pair that is absent or ``None`` collapses to ``[mean, mean]``,
        i.e. zero-length error bars.

    Raises:
        KeyError: If a record has no ``"system"`` key.
    """
    systems = [r["system"] for r in records]
    means, err_m, err_p = {}, {}, {}
    for metric in METRIC_KEYS:
        error_key = f"{metric} Error"
        m, e_m, e_p = [], [], []
        for r in records:
            # Treat an explicit null the same as a missing key; otherwise the
            # arithmetic below would fail on None.
            mean = r.get(metric)
            if mean is None:
                mean = 0.0
            bounds = r.get(error_key)
            if bounds is None:
                bounds = (mean, mean)
            low, high = bounds
            m.append(mean)
            e_m.append(mean - low)
            e_p.append(high - mean)
        means[metric] = np.asarray(m)
        err_m[metric] = np.asarray(e_m)
        err_p[metric] = np.asarray(e_p)
    return systems, means, err_m, err_p
def _plot_grouped_bar(
    systems: List[str],
    means: Dict[str, np.ndarray],
    err_m: Dict[str, np.ndarray],
    err_p: Dict[str, np.ndarray],
    title: str,
    outfile: str,
    rotate_xticks: bool = False,
) -> None:
    """Draw a grouped bar chart of all metrics and save it to ``outfile``.

    One group of bars is drawn per system; within a group there is one bar
    per entry of ``METRIC_KEYS``, coloured as that mapping specifies.

    Args:
        systems: X-axis tick labels, one per group.
        means: Bar heights, keyed by metric name.
        err_m: Lower error-bar lengths, keyed by metric name.
        err_p: Upper error-bar lengths, keyed by metric name.
        title: Chart title.
        outfile: Path of the image file to write. Missing parent directories
            are created.
        rotate_xticks: When true, tick labels are rotated 15 degrees and
            right-aligned; otherwise they are horizontal and centred.
    """
    n_metrics = len(METRIC_KEYS)
    ind = np.arange(len(systems))
    # All bars of a group share 80% of the slot; the rest is the gap between
    # neighbouring groups.
    width = 0.8 / n_metrics

    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)
    try:
        ax.set_ylim(*Y_LIM)
        ax.set_title(title, fontsize=16, fontweight="bold", pad=15)
        ax.set_ylabel("Score")
        ax.set_xticks(ind)
        ha = "right" if rotate_xticks else "center"
        ax.set_xticklabels(
            systems,
            rotation=15 if rotate_xticks else 0,
            ha=ha,
        )

        for i, (metric, colour) in enumerate(METRIC_KEYS.items()):
            # Centre the group on the tick: shift bar i by its distance from
            # the middle of the group.
            offset = ind + (i - (n_metrics - 1) / 2) * width
            ax.bar(
                offset,
                means[metric],
                width,
                label=metric,
                color=colour,
                yerr=[err_m[metric], err_p[metric]],
                capsize=4,
                ecolor="#374151",
            )
            # Value labels; bars of height zero are left unlabelled.
            for x, y in zip(offset, means[metric]):
                if y > 0:
                    ax.text(
                        x, y + 0.02, f"{y:.2f}", ha="center", va="bottom", fontsize=8
                    )

        ax.grid(axis="y", linestyle="--", alpha=0.4)
        ax.legend()
        fig.tight_layout()

        # savefig does not create directories, so make sure the target exists.
        Path(outfile).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(outfile)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
def main() -> None:
    """Render the two benchmark charts and report where they were written.

    The first and second command-line arguments, when given, replace the
    default Cognee and competition summary paths respectively.

    Raises:
        FileNotFoundError: If either input file does not exist.
    """
    cli_args = sys.argv[1:]
    cognee_file = Path(cli_args[0]) if cli_args else INPUT_COGNEE
    comp_file = Path(cli_args[1]) if len(cli_args) >= 2 else INPUT_COMPETITION

    # Validate both inputs up front so nothing is plotted if one is missing.
    for required in (cognee_file, comp_file):
        if not required.exists():
            raise FileNotFoundError(f"{required} not found")

    # Chart 1: optimised Cognee configurations.
    cognee_matrix = _extract_matrix(_load(cognee_file))
    _plot_grouped_bar(
        *cognee_matrix,
        title="Optimized Cognee Configurations",
        outfile=OUT_OPTIMISED,
        rotate_xticks=True,
    )
    print(f"Wrote {OUT_OPTIMISED}")

    # Chart 2: Cognee vs. competition. The "Graphiti" entry is relabelled
    # before its name reaches the x-axis.
    comp_records = _load(comp_file)
    for entry in comp_records:
        if entry.get("system") == "Graphiti":
            entry["system"] = "Graphiti (Previous Results)"
    comp_matrix = _extract_matrix(comp_records)
    _plot_grouped_bar(
        *comp_matrix,
        title="Comprehensive Metrics Comparison",
        outfile=OUT_COMP,
    )
    print(f"Wrote {OUT_COMP}")


if __name__ == "__main__":
    main()