"""Compare trained models using 8 pre-registered primary metrics.

For each run directory the script reads ``horizon_capture.csv`` and
``workload_yield.csv``, aggregates them into the primary metrics
(event-weighted, plus a macro average over age bins used only as a robustness
check), ranks models per metric, and writes leaderboard, audit, dimension
summary, matched-pair, protocol and warning files to ``--out_dir``.
"""

import argparse
import glob
import json
import os
from collections import Counter
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd

from utils import load_train_config

# -------------------------
# Pre-registered panel (DO NOT CHANGE AFTER SEEING RESULTS)
# -------------------------
PRIMARY_CAPTURE_TAUS = (1.0, 2.0, 5.0)
PRIMARY_CAPTURE_TOPKS = (20, 50)
PRIMARY_WY_TAU = 1.0
PRIMARY_WY_FRACS = (0.05, 0.10)

PRIMARY_METRICS = (
    ("C1", "capture", {"topk": 20, "tau": 1.0}),
    ("C2", "capture", {"topk": 20, "tau": 2.0}),
    ("C3", "capture", {"topk": 20, "tau": 5.0}),
    ("C4", "capture", {"topk": 50, "tau": 1.0}),
    ("C5", "capture", {"topk": 50, "tau": 2.0}),
    ("C6", "capture", {"topk": 50, "tau": 5.0}),
    ("W1", "workload", {"frac": 0.05, "tau": 1.0}),
    ("W2", "workload", {"frac": 0.10, "tau": 1.0}),
)

EXPECTED_AGE_BINS = (
    "[40.0, 45.0)",
    "[45.0, 50.0)",
    "[50.0, 55.0)",
    "[55.0, 60.0)",
    "[60.0, 65.0)",
    "[65.0, 70.0)",
    "[70.0, inf)",
)

# -------------------------
# Implementation constants (not part of the metric definitions above)
# -------------------------
# Columns each evaluation table must provide for the primary metrics to be computed.
CAPTURE_REQUIRED_COLS = ("topk", "tau_years", "denom_events", "numer_events", "capture_at_k")
WORKLOAD_REQUIRED_COLS = ("tau_years", "frac_selected", "total_events", "events_captured", "capture_rate")

# Absolute tolerance for matching float protocol keys (tau, frac) read back from CSV.
FLOAT_MATCH_TOL = 1e-9

# A model is flagged when its macro leaderboard rank is at least this many places
# worse than its primary (event-weighted) leaderboard rank.
MACRO_COLLAPSE_RANK_DELTA = 4

# Design dimensions summarised in dimension_summary.csv.
DESIGN_DIMENSIONS = ("loss_type", "age_encoder", "covariate_setting")


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the comparison run."""
    p = argparse.ArgumentParser(
        description="Compare trained models using 8 pre-registered primary metrics."
    )
    p.add_argument(
        "--runs_root",
        type=str,
        default="runs",
        help="Directory containing run subfolders (default: runs)",
    )
    p.add_argument(
        "--pattern",
        type=str,
        default="*",
        help="Glob pattern for run folder names under runs_root (default: *)",
    )
    p.add_argument(
        "--run_dirs_file",
        type=str,
        default="",
        help="Optional file with one run_dir per line (overrides runs_root/pattern)",
    )
    p.add_argument(
        "--out_dir",
        type=str,
        default="comparison",
        help="Output directory for tables (default: comparison)",
    )
    p.add_argument(
        "--expect_n",
        type=int,
        default=12,
        help="Expected number of models (default: 12). Set 0 to disable check.",
    )
    return p.parse_args()


def _discover_run_dirs(runs_root: str, pattern: str) -> List[str]:
    """Return sorted subdirectories of ``runs_root`` matching ``pattern``.

    Only directories containing both ``best_model.pt`` and ``train_config.json``
    are kept.

    Raises:
        FileNotFoundError: if ``runs_root`` is not a directory.
    """
    if not os.path.isdir(runs_root):
        raise FileNotFoundError(f"runs_root not found: {runs_root}")
    out: List[str] = []
    for d in glob.glob(os.path.join(runs_root, pattern)):
        if not os.path.isdir(d):
            continue
        if not (
            os.path.isfile(os.path.join(d, "best_model.pt"))
            and os.path.isfile(os.path.join(d, "train_config.json"))
        ):
            continue
        out.append(d)
    out.sort()
    return out


def _read_run_dirs_file(path: str) -> List[str]:
    """Read one run directory per line from ``path``, skipping blank lines.

    Raises:
        FileNotFoundError: if ``path`` is not a file.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"run_dirs_file not found: {path}")
    out: List[str] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            out.append(line)
    return out


def _safe_float_eq(a: float, b: float, tol: float = 1e-9) -> bool:
    """Return True when ``a`` and ``b`` differ by at most ``tol``."""
    return abs(float(a) - float(b)) <= tol


def _float_mask(values: pd.Series, target: float, tol: float = FLOAT_MATCH_TOL) -> pd.Series:
    """Vectorised ``_safe_float_eq``: boolean mask of entries within ``tol`` of ``target``.

    NaN entries never match. A tolerance is used because float keys read back from
    CSV are not guaranteed to round-trip bit-exactly.
    """
    return (values.astype(float) - float(target)).abs() <= tol


def _missing_columns(df: pd.DataFrame, required: Sequence[str]) -> List[str]:
    """Return the names in ``required`` that are not columns of ``df``."""
    return [c for c in required if c not in df.columns]


def _capture_values(df: pd.DataFrame, *, topk: int, tau: float) -> Tuple[float, float, int]:
    """Aggregate horizon-capture rows for one (topk, tau) cell.

    Returns:
        (event_weighted, macro, total_denom_events), where event_weighted is
        sum(numer_events) / sum(denom_events) and macro is the mean of
        ``capture_at_k`` over rows with denom_events > 0. Values are NaN (and the
        denominator 0) when ``df`` is empty, lacks required columns, or has no
        matching rows.
    """
    nan = float("nan")
    # An empty frame (missing CSV) has no columns at all; guard before indexing.
    if df.empty or _missing_columns(df, CAPTURE_REQUIRED_COLS):
        return nan, nan, 0
    sub = df[(df["topk"] == int(topk)) & _float_mask(df["tau_years"], tau)]
    if sub.empty:
        return nan, nan, 0
    denom = sub["denom_events"].astype(np.int64)
    numer = sub["numer_events"].astype(np.int64)
    total_denom = int(denom.sum())
    event_weighted = float(numer.sum() / total_denom) if total_denom > 0 else nan
    sub_pos = sub[denom > 0]
    macro = float(sub_pos["capture_at_k"].astype(float).mean()) if not sub_pos.empty else nan
    return event_weighted, macro, total_denom


def _workload_values(df: pd.DataFrame, *, frac: float, tau: float) -> Tuple[float, float, int]:
    """Aggregate workload-yield rows for one (frac, tau) cell.

    Returns:
        (event_weighted, macro, total_events), where event_weighted is
        sum(events_captured) / sum(total_events) and macro is the mean of
        ``capture_rate`` over rows with total_events > 0. Values are NaN (and the
        denominator 0) when ``df`` is empty, lacks required columns, or has no
        matching rows.
    """
    nan = float("nan")
    if df.empty or _missing_columns(df, WORKLOAD_REQUIRED_COLS):
        return nan, nan, 0
    sub = df[_float_mask(df["tau_years"], tau) & _float_mask(df["frac_selected"], frac)]
    if sub.empty:
        return nan, nan, 0
    total_events = sub["total_events"].astype(np.int64)
    captured = sub["events_captured"].astype(np.int64)
    denom = int(total_events.sum())
    event_weighted = float(captured.sum() / denom) if denom > 0 else nan
    sub_pos = sub[total_events > 0]
    macro = float(sub_pos["capture_rate"].astype(float).mean()) if not sub_pos.empty else nan
    return event_weighted, macro, denom


def _rank_points(values: pd.Series, model_names: pd.Series) -> pd.DataFrame:
    """Rank models descending by value; NaN ranked last. Deterministic tie-break on model_name.

    The best model gets ``n`` points and the worst 1; every non-finite value gets 1 point.
    Returns a frame with columns ``model_name``, ``rank``, ``points``.
    """
    df = pd.DataFrame({"model_name": model_names, "value": values.astype(float)})
    df["is_nan"] = ~np.isfinite(df["value"].to_numpy(dtype=np.float64))
    df = df.sort_values(
        ["is_nan", "value", "model_name"], ascending=[True, False, True], kind="mergesort"
    )
    df["rank"] = np.arange(1, len(df) + 1, dtype=np.int64)
    n = len(df)
    df["points"] = (n - df["rank"] + 1).astype(np.int64)
    # Force NaNs to last (1 point)
    df.loc[df["is_nan"], "points"] = 1
    return df[["model_name", "rank", "points"]]


def _load_eval_table(
    path: str, model_name: str, required: Sequence[str], warnings: List[str]
) -> pd.DataFrame:
    """Read one evaluation CSV for a run, recording problems in ``warnings``.

    A missing file yields an empty frame, so that table's metrics become NaN (and
    rank last) instead of aborting the comparison. Missing required columns and
    unexpected ``age_bin`` labels are reported but do not stop processing.
    """
    fname = os.path.basename(path)
    if not os.path.isfile(path):
        warnings.append(f"Missing evaluation output for {model_name}: {fname}")
        return pd.DataFrame()
    df = pd.read_csv(path)
    missing = _missing_columns(df, required)
    if missing:
        warnings.append(
            f"{model_name}: {fname} lacks required columns {missing}; its metrics will be NaN"
        )
    # Validate age bins (best-effort). Missing bins are allowed (no eligible records),
    # but unexpected bins likely indicate protocol mismatch.
    if "age_bin" in df.columns:
        seen_bins = sorted({str(x) for x in df["age_bin"].dropna().unique()})
        unexpected = [b for b in seen_bins if b not in EXPECTED_AGE_BINS]
        if unexpected:
            warnings.append(f"{model_name}: unexpected age_bin labels in {fname}: {unexpected}")
    return df


def _evaluate_run(run_dir: str, model_name: str, warnings: List[str]) -> Dict[str, object]:
    """Build one metrics row for a run.

    The row holds config dimensions plus, for each primary metric id, the
    event-weighted value, the ``_macro`` value and the ``_denom`` event count.
    """
    cfg = load_train_config(run_dir)
    full_cov = bool(cfg.get("full_cov", False))
    row: Dict[str, object] = {
        "model_name": model_name,
        "run_dir": run_dir,
        "loss_type": str(cfg.get("loss_type", "")),
        "age_encoder": str(cfg.get("age_encoder", cfg.get("age_encoder_type", ""))),
        "covariate_setting": "fullcov" if full_cov else "partcov",
    }

    cap_df = _load_eval_table(
        os.path.join(run_dir, "horizon_capture.csv"), model_name, CAPTURE_REQUIRED_COLS, warnings
    )
    wy_df = _load_eval_table(
        os.path.join(run_dir, "workload_yield.csv"), model_name, WORKLOAD_REQUIRED_COLS, warnings
    )

    # Primary panel: event-weighted metrics (macro kept alongside for robustness).
    for metric_id, kind, spec in PRIMARY_METRICS:
        if kind == "capture":
            v_w, v_macro, denom = _capture_values(cap_df, topk=spec["topk"], tau=spec["tau"])
        elif kind == "workload":
            v_w, v_macro, denom = _workload_values(wy_df, frac=spec["frac"], tau=spec["tau"])
        else:
            raise ValueError(f"Unknown metric kind: {kind}")
        row[metric_id] = v_w
        row[f"{metric_id}_macro"] = v_macro
        row[f"{metric_id}_denom"] = int(denom)
    return row


def _add_panel_scores(df: pd.DataFrame, *, suffix: str, total_col: str) -> pd.DataFrame:
    """Rank models on every primary metric column ``{metric_id}{suffix}``.

    Adds ``{metric_id}{suffix}_rank`` and ``{metric_id}{suffix}_points`` per metric
    plus ``total_col`` (sum of the points), and returns the extended frame.
    """
    out = df.copy()
    point_cols: List[str] = []
    for metric_id, _, _ in PRIMARY_METRICS:
        col = f"{metric_id}{suffix}"
        ranked = _rank_points(out[col], out["model_name"]).rename(
            columns={"rank": f"{col}_rank", "points": f"{col}_points"}
        )
        # one_to_one guards against duplicated model names exploding the merge.
        out = out.merge(ranked, on="model_name", how="left", validate="one_to_one")
        point_cols.append(f"{col}_points")
    out[total_col] = out[point_cols].sum(axis=1).astype(np.int64)
    return out


def _dimension_summary(merged: pd.DataFrame) -> pd.DataFrame:
    """Summarise ``total_score`` per level of each design dimension."""
    parts: List[pd.DataFrame] = []
    for dim in DESIGN_DIMENSIONS:
        summary = (
            merged.groupby(dim, dropna=False)
            .agg(
                n_models=("model_name", "count"),
                mean_total_score=("total_score", "mean"),
                std_total_score=("total_score", "std"),
                best_total_score=("total_score", "max"),
            )
            .reset_index()
        )
        parts.append(summary.assign(dimension=dim))
    return pd.concat(parts, ignore_index=True)


def _pair_deltas(merged: pd.DataFrame, fixed: Sequence[str], var: str) -> pd.DataFrame:
    """Matched-pair ``total_score`` deltas along one design dimension.

    Models are grouped by the ``fixed`` dimensions. In each group with at least two
    distinct values of ``var``, the first model (ordered by ``var`` then
    ``model_name``) is the baseline ``a`` and every other model ``b`` yields one row
    with ``delta_b_minus_a = score(b) - score(a)``.
    """
    fixed = list(fixed)
    columns = [*fixed, "var_dim", "a_var", "b_var", "a_model", "b_model",
               "a_score", "b_score", "delta_b_minus_a"]
    df = merged[fixed + [var, "model_name", "total_score"]]
    pairs: List[Dict[str, object]] = []
    for key, grp in df.groupby(fixed, dropna=False):
        if len(grp) < 2 or grp[var].nunique() < 2:
            continue
        key_values = key if isinstance(key, tuple) else (key,)
        # deterministic ordering
        grp = grp.sort_values([var, "model_name"], kind="mergesort")
        base = grp.iloc[0]
        for i in range(1, len(grp)):
            other = grp.iloc[i]
            pairs.append(
                {
                    **dict(zip(fixed, key_values)),
                    "var_dim": var,
                    "a_var": base[var],
                    "b_var": other[var],
                    "a_model": base["model_name"],
                    "b_model": other["model_name"],
                    "a_score": int(base["total_score"]),
                    "b_score": int(other["total_score"]),
                    "delta_b_minus_a": int(other["total_score"] - base["total_score"]),
                }
            )
    # Explicit columns keep the CSV header even when no pairs exist.
    return pd.DataFrame(pairs, columns=columns)


def _write_protocol(path: str) -> None:
    """Write the frozen protocol definition to ``path`` as JSON for auditability."""
    protocol = {
        "PRIMARY_METRICS": list(PRIMARY_METRICS),
        "EXPECTED_AGE_BINS": list(EXPECTED_AGE_BINS),
        "PRIMARY_CAPTURE_TAUS": list(PRIMARY_CAPTURE_TAUS),
        "PRIMARY_CAPTURE_TOPKS": list(PRIMARY_CAPTURE_TOPKS),
        "PRIMARY_WY_TAU": float(PRIMARY_WY_TAU),
        "PRIMARY_WY_FRACS": list(PRIMARY_WY_FRACS),
        "aggregation_primary": "event-weighted by denom_events / total_events",
        "aggregation_macro": "macro-average over age bins with denom>0",
        "nan_policy": "NaN ranked last (1 point)",
        "tie_break": "model_name ascending (deterministic)",
        "macro_collapse_rank_delta": int(MACRO_COLLAPSE_RANK_DELTA),
        "float_match_tol": float(FLOAT_MATCH_TOL),
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(protocol, f, indent=2, ensure_ascii=False)


def main() -> None:
    """Discover runs, score them on the primary panel, and write all comparison tables."""
    args = parse_args()

    if args.run_dirs_file:
        run_dirs = _read_run_dirs_file(args.run_dirs_file)
    else:
        run_dirs = _discover_run_dirs(args.runs_root, args.pattern)

    if args.expect_n and len(run_dirs) != int(args.expect_n):
        raise ValueError(f"Expected {args.expect_n} runs, found {len(run_dirs)}")
    if not run_dirs:
        raise ValueError("No run directories to compare")

    # model_name is the join key for every table, so it must be unique.
    model_names = [os.path.basename(os.path.normpath(d)) for d in run_dirs]
    dupes = sorted(name for name, count in Counter(model_names).items() if count > 1)
    if dupes:
        raise ValueError(f"Duplicate model names (run folder basenames): {dupes}")

    os.makedirs(args.out_dir, exist_ok=True)

    warnings: List[str] = []
    rows = [_evaluate_run(d, name, warnings) for d, name in zip(run_dirs, model_names)]
    metrics_df = pd.DataFrame(rows)

    # Scoring: ranks + points for the 8 primary metrics (event-weighted), then the
    # macro robustness score (not used for ranking, but used to flag instability).
    merged = _add_panel_scores(metrics_df, suffix="", total_col="total_score")
    merged = _add_panel_scores(merged, suffix="_macro", total_col="macro_total_score")

    # Final leaderboard ranks
    merged = merged.sort_values(
        ["total_score", "model_name"], ascending=[False, True], kind="mergesort"
    )
    merged["leaderboard_rank"] = np.arange(1, len(merged) + 1, dtype=np.int64)
    macro_order = merged.sort_values(
        ["macro_total_score", "model_name"], ascending=[False, True], kind="mergesort"
    )["model_name"]
    macro_rank_map = {name: i for i, name in enumerate(macro_order, start=1)}
    merged["macro_leaderboard_rank"] = merged["model_name"].map(macro_rank_map).astype(np.int64)
    merged["rank_delta_macro_minus_primary"] = (
        merged["macro_leaderboard_rank"] - merged["leaderboard_rank"]
    ).astype(np.int64)
    merged["macro_collapse_flag"] = (
        merged["rank_delta_macro_minus_primary"] >= MACRO_COLLAPSE_RANK_DELTA
    )

    metric_ids = [mid for mid, _, _ in PRIMARY_METRICS]

    # Output tables
    leaderboard_cols = [
        "leaderboard_rank",
        "model_name",
        "loss_type",
        "age_encoder",
        "covariate_setting",
        "total_score",
        "macro_total_score",
        "macro_leaderboard_rank",
        "rank_delta_macro_minus_primary",
        "macro_collapse_flag",
        *metric_ids,
    ]
    out_leaderboard_path = os.path.join(args.out_dir, "leaderboard.csv")
    merged[leaderboard_cols].to_csv(out_leaderboard_path, index=False)

    # Audit table: per-metric ranks/points + values
    audit_cols = ["model_name"]
    for mid in metric_ids:
        audit_cols.extend([mid, f"{mid}_rank", f"{mid}_points", f"{mid}_denom"])
    audit_cols.append("total_score")
    out_audit_path = os.path.join(args.out_dir, "per_metric_ranking.csv")
    merged[audit_cols].to_csv(out_audit_path, index=False)

    # Dimension summaries
    out_dim_path = os.path.join(args.out_dir, "dimension_summary.csv")
    _dimension_summary(merged).to_csv(out_dim_path, index=False)

    # Matched-pair deltas for interpretability
    pair_specs = (
        ("pairs_covariates.csv", ["loss_type", "age_encoder"], "covariate_setting"),
        ("pairs_age_encoder.csv", ["loss_type", "covariate_setting"], "age_encoder"),
        ("pairs_loss_type.csv", ["age_encoder", "covariate_setting"], "loss_type"),
    )
    pair_paths: List[str] = []
    for fname, fixed, var in pair_specs:
        path = os.path.join(args.out_dir, fname)
        _pair_deltas(merged, fixed, var).to_csv(path, index=False)
        pair_paths.append(path)

    # Protocol stamp for auditability
    protocol_path = os.path.join(args.out_dir, "protocol.json")
    _write_protocol(protocol_path)

    warnings_path = os.path.join(args.out_dir, "warnings.txt")
    if warnings:
        with open(warnings_path, "w", encoding="utf-8") as f:
            f.writelines(w + "\n" for w in warnings)

    print(f"Wrote {out_leaderboard_path}")
    print(f"Wrote {out_audit_path}")
    print(f"Wrote {out_dim_path}")
    for path in pair_paths:
        print(f"Wrote {path}")
    print(f"Wrote {protocol_path}")
    if warnings:
        print(f"Wrote {warnings_path}")


if __name__ == "__main__":
    main()