ab_testing.py
#!/usr/bin/env python3
"""A/B testing framework - for experimenting with different grouping parameters."""
import json
import sys
from datetime import datetime

# Make sure the project modules are importable
sys.path.insert(0, "src")

from paperlib_mcp.db import get_db, query_all, query_one


def collect_metrics() -> dict:
    """Collect metrics describing the current grouping state."""
    # Overall counts of claims, groups, and group memberships
    basic = query_one("""
        SELECT
            (SELECT COUNT(*) FROM claims) as total_claims,
            (SELECT COUNT(*) FROM claim_groups) as total_groups,
            (SELECT COUNT(*) FROM claim_group_members) as total_members
    """)
    # Coverage (share of claims assigned to at least one group) and group-size distribution
    stats = query_one("""
        SELECT
            COUNT(DISTINCT m.claim_id)::float / NULLIF((SELECT COUNT(*) FROM claims), 0) as coverage,
            MAX(n) as max_group_size,
            AVG(n) as avg_group_size,
            percentile_cont(0.5) WITHIN GROUP (ORDER BY n) as p50_group_size,
            percentile_cont(0.9) WITHIN GROUP (ORDER BY n) as p90_group_size
        FROM claim_group_members m
        JOIN (SELECT group_id, COUNT(*) as n FROM claim_group_members GROUP BY group_id) t
            ON t.group_id = m.group_id
    """)
    # Share of claims whose outcome_family was mapped to something more specific than 'general'
    tax_hit = query_one("""
        SELECT 1 - COUNT(*)::float / NULLIF((SELECT COUNT(*) FROM claim_features), 0) as hit_rate
        FROM claim_features WHERE outcome_family = 'general'
    """)
    subgroups = query_one("SELECT COUNT(*) as n FROM claim_groups WHERE parent_group_id IS NOT NULL")
    return {
        "total_claims": basic["total_claims"],
        "groups_created": basic["total_groups"],
        "total_members": basic["total_members"],
        "coverage": float(stats["coverage"]) if stats["coverage"] else 0,
        "max_group_size": stats["max_group_size"],
        "avg_group_size": float(stats["avg_group_size"]) if stats["avg_group_size"] else 0,
        "p50_group_size": float(stats["p50_group_size"]) if stats["p50_group_size"] else 0,
        "p90_group_size": float(stats["p90_group_size"]) if stats["p90_group_size"] else 0,
        "taxonomy_hit_rate": float(tax_hit["hit_rate"]) if tax_hit["hit_rate"] else 0,
        "subgroups": subgroups["n"],
    }


def create_experiment(experiment_id: str, description: str, **params):
    """Create (or update) an experiment configuration."""
    params_json = json.dumps(params, sort_keys=True)
    with get_db() as conn:
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO experiments (experiment_id, description, params_json)
                VALUES (%s, %s, %s::jsonb)
                ON CONFLICT (experiment_id) DO UPDATE SET
                    description = EXCLUDED.description,
                    params_json = EXCLUDED.params_json
            """, (experiment_id, description, params_json))
    print(f"✓ Created experiment: {experiment_id}")
    return params


def run_experiment(experiment_id: str):
    """Run an experiment and record the current metrics snapshot."""
    exp = query_one("SELECT params_json FROM experiments WHERE experiment_id = %s", (experiment_id,))
    if not exp:
        print(f"✗ Experiment {experiment_id} not found")
        return None
    metrics = collect_metrics()
    with get_db() as conn:
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO experiment_runs (experiment_id, metrics)
                VALUES (%s, %s::jsonb)
                RETURNING run_id
            """, (experiment_id, json.dumps(metrics)))
            run_id = cur.fetchone()["run_id"]
    print(f"✓ Run #{run_id} recorded for {experiment_id}")
    return {"run_id": run_id, "metrics": metrics}


def rate_run(run_id: int, rating: int, notes: str = None):
    """Attach a user rating (and optional notes) to a run."""
    with get_db() as conn:
        with conn.cursor() as cur:
            cur.execute("""
                UPDATE experiment_runs
                SET user_rating = %s, user_notes = %s
                WHERE run_id = %s
            """, (rating, notes, run_id))
    print(f"✓ Rated run #{run_id}: {rating}/5")


def compare_experiments(experiment_ids: list = None):
    """Compare experiment results."""
    where = ""
    params = ()
    if experiment_ids:
        where = "WHERE e.experiment_id = ANY(%s)"
        params = (experiment_ids,)
    rows = query_all(f"""
        SELECT e.experiment_id, e.description, r.run_id, r.metrics, r.user_rating, r.created_at
        FROM experiments e
        LEFT JOIN experiment_runs r ON r.experiment_id = e.experiment_id
        {where}
        ORDER BY e.experiment_id, r.created_at DESC
    """, params)
    print("\n" + "=" * 80)
    print("EXPERIMENT COMPARISON")
    print("=" * 80)
    current_exp = None
    for row in rows:
        if row["experiment_id"] != current_exp:
            current_exp = row["experiment_id"]
            print(f"\n📊 {row['experiment_id']}: {row['description']}")
        if row["run_id"]:
            m = row["metrics"]
            rating = f"⭐{row['user_rating']}" if row["user_rating"] else "unrated"
            print(f"  Run #{row['run_id']} ({rating})")
            print(f"    coverage: {m['coverage']:.2%}, groups: {m['groups_created']}, max: {m['max_group_size']}")
            print(f"    taxonomy: {m['taxonomy_hit_rate']:.2%}, subgroups: {m['subgroups']}")


def list_experiments():
    """List all experiments."""
    rows = query_all("""
        SELECT e.experiment_id, e.description,
               e.params_json->>'split_threshold' as split_threshold,
               COUNT(r.run_id) as run_count,
               AVG(r.user_rating) as avg_rating
        FROM experiments e
        LEFT JOIN experiment_runs r ON r.experiment_id = e.experiment_id
        GROUP BY e.experiment_id, e.description, e.params_json
        ORDER BY e.created_at DESC
    """)
    print("\n" + "=" * 60)
    print("EXPERIMENTS")
    print("=" * 60)
    for r in rows:
        rating = f"⭐{r['avg_rating']:.1f}" if r["avg_rating"] else "—"
        print(f"  {r['experiment_id']}: {r['description']}")
        print(f"    threshold={r['split_threshold']}, runs={r['run_count']}, {rating}")


# ============================================================
# CLI entry point
# ============================================================
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="A/B Testing Framework")
    subparsers = parser.add_subparsers(dest="command")

    # create
    p_create = subparsers.add_parser("create", help="Create experiment")
    p_create.add_argument("id", help="Experiment ID")
    p_create.add_argument("description", help="Description")
    p_create.add_argument("--split", type=int, default=150, help="Split threshold")

    # run
    p_run = subparsers.add_parser("run", help="Run experiment")
    p_run.add_argument("id", help="Experiment ID")

    # rate
    p_rate = subparsers.add_parser("rate", help="Rate a run")
    p_rate.add_argument("run_id", type=int, help="Run ID")
    p_rate.add_argument("rating", type=int, choices=[1, 2, 3, 4, 5], help="Rating 1-5")
    p_rate.add_argument("--notes", help="Notes")

    # compare
    p_compare = subparsers.add_parser("compare", help="Compare experiments")
    p_compare.add_argument("ids", nargs="*", help="Experiment IDs to compare")

    # list
    subparsers.add_parser("list", help="List experiments")

    # metrics
    subparsers.add_parser("metrics", help="Show current metrics")

    args = parser.parse_args()

    if args.command == "create":
        create_experiment(args.id, args.description, split_threshold=args.split)
    elif args.command == "run":
        result = run_experiment(args.id)
        if result:
            print("\nMetrics:")
            for k, v in result["metrics"].items():
                print(f"  {k}: {v}")
    elif args.command == "rate":
        rate_run(args.run_id, args.rating, args.notes)
    elif args.command == "compare":
        compare_experiments(args.ids if args.ids else None)
    elif args.command == "list":
        list_experiments()
    elif args.command == "metrics":
        print("\nCurrent Metrics:")
        for k, v in collect_metrics().items():
            print(f"  {k}: {v}")
    else:
        parser.print_help()
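The script reads from and writes to experiments and experiment_runs tables but never creates them, so they are assumed to exist elsewhere in the paperlib-mcp schema. Below is a minimal setup sketch for a compatible schema, using the same get_db helper; the column names follow the queries above, while the exact types and defaults are assumptions rather than the project's actual DDL.

#!/usr/bin/env python3
"""Hypothetical one-off setup for the tables ab_testing.py expects.

Column names mirror the queries in ab_testing.py; types and defaults are
assumptions, not the actual paperlib-mcp schema.
"""
import sys

sys.path.insert(0, "src")

from paperlib_mcp.db import get_db

DDL = """
CREATE TABLE IF NOT EXISTS experiments (
    experiment_id text PRIMARY KEY,
    description   text,
    params_json   jsonb,
    created_at    timestamptz DEFAULT now()
);

CREATE TABLE IF NOT EXISTS experiment_runs (
    run_id        serial PRIMARY KEY,
    experiment_id text REFERENCES experiments (experiment_id),
    metrics       jsonb,
    user_rating   int,
    user_notes    text,
    created_at    timestamptz DEFAULT now()
);
"""

if __name__ == "__main__":
    # Create both tables if they are missing, then commit via the context manager.
    with get_db() as conn:
        with conn.cursor() as cur:
            cur.execute(DDL)
    print("✓ experiments / experiment_runs tables ready")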

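For reference, a typical cycle looks like the sketch below when the module is used programmatically; the same flow is available through the CLI subcommands (create, run, rate, compare, list, metrics). The experiment id, rating, and notes here are made-up examples, and the sketch assumes the paperlib database is reachable and the tables above exist.

# Hypothetical end-to-end cycle using ab_testing.py as a module.
from ab_testing import create_experiment, run_experiment, rate_run, compare_experiments

# Register a configuration; extra keyword args land in params_json.
create_experiment("split-150", "Baseline split threshold", split_threshold=150)

# Snapshot collect_metrics() into experiment_runs for this experiment.
result = run_experiment("split-150")
if result:
    rate_run(result["run_id"], 4, notes="coverage acceptable, groups a bit large")

# Print per-run coverage and group-size summaries for the chosen experiments.
compare_experiments(["split-150"])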