# visualize_benchmarks.py
import json
import matplotlib.pyplot as plt
import numpy as np
import sys
def load_benchmark_data(filename):
    """Load benchmark data from a JSON file.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The decoded JSON document, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8 so that non-ASCII text (e.g. system names)
    # is read the same way regardless of the platform's locale encoding.
    with open(filename, "r", encoding="utf-8") as f:
        return json.load(f)
def _collect_metric_series(data, metric):
    """Return per-system values and asymmetric error lengths for one metric.

    For every entry in ``data`` the value stored under ``metric`` is used as
    the bar height. If the entry also has a ``"<metric> Error"`` item, its
    element ``[0]`` is treated as the lower bound and ``[1]`` as the upper
    bound, and they are converted to distances from the value. Entries
    without the metric contribute 0 for the value and both error lengths;
    entries without the error item contribute zero-length errors.

    Returns:
        A ``(means, errors_lower, errors_upper)`` tuple of equal-length lists.
    """
    error_key = f"{metric} Error"
    means, errors_lower, errors_upper = [], [], []
    for system in data:
        if metric not in system:
            means.append(0)
            errors_lower.append(0)
            errors_upper.append(0)
            continue

        value = system[metric]
        means.append(value)
        if error_key in system:
            bounds = system[error_key]
            errors_lower.append(value - bounds[0])
            errors_upper.append(bounds[1] - value)
        else:
            errors_lower.append(0)
            errors_upper.append(0)
    return means, errors_lower, errors_upper


def visualize_benchmarks(benchmark_file, output_file=None):
    """Visualize benchmark results as grouped bars with error bars.

    The benchmark file is loaded with ``load_benchmark_data`` and iterated as
    a sequence of mappings. Each mapping must have a ``"system"`` item (used
    as the x-axis label) and may have any of the plotted metrics plus a
    matching ``"<metric> Error"`` item holding ``[lower, upper]`` bounds.

    Args:
        benchmark_file: Path of the JSON benchmark summary to plot.
        output_file: Where to save the figure. When falsy, the figure is
            shown interactively instead of being saved.
    """
    # Load data first so that a bad input file does not leave a figure behind.
    data = load_benchmark_data(benchmark_file)

    # Metrics to plot and the bar colour used for each of them.
    metrics = ["Human-like Correctness", "DeepEval Correctness", "DeepEval EM", "DeepEval F1"]
    colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]
    bar_width = 0.2

    fig, ax = plt.subplots(figsize=(14, 8))

    systems = [system["system"] for system in data]
    x_pos = np.arange(len(systems))

    for i, (metric, color) in enumerate(zip(metrics, colors)):
        means, errors_lower, errors_upper = _collect_metric_series(data, metric)
        bar_x = x_pos + i * bar_width
        ax.bar(bar_x, means, bar_width, label=metric, color=color, alpha=0.8)

        # Only show error bars for non-zero values; draw them with a single
        # call per metric instead of one call per bar.
        heights = np.asarray(means, dtype=float)
        has_bar = heights > 0
        if has_bar.any():
            ax.errorbar(
                bar_x[has_bar],
                heights[has_bar],
                yerr=[
                    np.asarray(errors_lower, dtype=float)[has_bar],
                    np.asarray(errors_upper, dtype=float)[has_bar],
                ],
                fmt="none",
                color="black",
                capsize=3,
                capthick=1,
            )

    # Customize plot
    ax.set_xlabel("Systems", fontsize=12)
    ax.set_ylabel("Score", fontsize=12)
    ax.set_title("Benchmark Results", fontsize=14, fontweight="bold")
    # Centre each tick under its group of bars; derived from the bar width and
    # metric count so it stays correct if either of them changes.
    ax.set_xticks(x_pos + bar_width * (len(metrics) - 1) / 2)
    ax.set_xticklabels(systems, rotation=45, ha="right")
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1.1)

    fig.tight_layout()

    # Save or show
    if output_file:
        try:
            fig.savefig(output_file, dpi=300, bbox_inches="tight")
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # release it so repeated calls do not accumulate figures.
            plt.close(fig)
        print(f"Plot saved as {output_file}")
    else:
        plt.show()
if __name__ == "__main__":
    from pathlib import Path

    # Default input file. A different summary can be visualized without
    # editing this script by passing its path as the first command-line
    # argument, e.g. "benchmark_summary_cognee.json".
    default_benchmark_file = "benchmark_summary_competition.json"
    benchmark_file = sys.argv[1] if len(sys.argv) > 1 else default_benchmark_file

    # Swap only the final extension for ".png". Unlike str.replace, this never
    # touches other parts of the path and never yields the input path itself,
    # so the input file cannot be overwritten by the plot.
    output_file = str(Path(benchmark_file).with_suffix(".png"))
    visualize_benchmarks(benchmark_file, output_file)