diff --git a/compare_models.py b/compare_models.py
new file mode 100644
index 0000000..a28e1c0
--- /dev/null
+++ b/compare_models.py
@@ -0,0 +1,427 @@
+"""Compare trained runs on the 8 pre-registered primary metrics.
+
+Reads horizon_capture.csv and workload_yield.csv from each run directory,
+scores models by rank points, and writes leaderboard, audit, and summary
+tables plus a protocol stamp to --out_dir.
+"""
+import argparse
+import os
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence, Tuple
+
+import numpy as np
+import pandas as pd
+
+from utils import load_train_config
+
+
+# -------------------------
+# Pre-registered panel (DO NOT CHANGE AFTER SEEING RESULTS)
+# -------------------------
+
+PRIMARY_CAPTURE_TAUS = (1.0, 2.0, 5.0)
+PRIMARY_CAPTURE_TOPKS = (20, 50)
+PRIMARY_WY_TAU = 1.0
+PRIMARY_WY_FRACS = (0.05, 0.10)
+
+PRIMARY_METRICS = (
+    ("C1", "capture", {"topk": 20, "tau": 1.0}),
+    ("C2", "capture", {"topk": 20, "tau": 2.0}),
+    ("C3", "capture", {"topk": 20, "tau": 5.0}),
+    ("C4", "capture", {"topk": 50, "tau": 1.0}),
+    ("C5", "capture", {"topk": 50, "tau": 2.0}),
+    ("C6", "capture", {"topk": 50, "tau": 5.0}),
+    ("W1", "workload", {"frac": 0.05, "tau": 1.0}),
+    ("W2", "workload", {"frac": 0.10, "tau": 1.0}),
+)
+
+EXPECTED_AGE_BINS = (
+    "[40.0, 45.0)",
+    "[45.0, 50.0)",
+    "[50.0, 55.0)",
+    "[55.0, 60.0)",
+    "[60.0, 65.0)",
+    "[65.0, 70.0)",
+    "[70.0, inf)",
+)
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser(
+        description="Compare trained models using 8 pre-registered primary metrics."
+    )
+    p.add_argument(
+        "--runs_root",
+        type=str,
+        default="runs",
+        help="Directory containing run subfolders (default: runs)",
+    )
+    p.add_argument(
+        "--pattern",
+        type=str,
+        default="*",
+        help="Glob pattern for run folder names under runs_root (default: *)",
+    )
+    p.add_argument(
+        "--run_dirs_file",
+        type=str,
+        default="",
+        help="Optional file with one run_dir per line (overrides runs_root/pattern)",
+    )
+    p.add_argument(
+        "--out_dir",
+        type=str,
+        default="comparison",
+        help="Output directory for tables (default: comparison)",
+    )
+    p.add_argument(
+        "--expect_n",
+        type=int,
+        default=12,
+        help="Expected number of models (default: 12). Set 0 to disable check.",
+    )
+    return p.parse_args()
+
+
+def _discover_run_dirs(runs_root: str, pattern: str) -> List[str]:
+    if not os.path.isdir(runs_root):
+        raise FileNotFoundError(f"runs_root not found: {runs_root}")
+
+    import glob
+
+    out: List[str] = []
+    for d in glob.glob(os.path.join(runs_root, pattern)):
+        if not os.path.isdir(d):
+            continue
+        if not (os.path.isfile(os.path.join(d, "best_model.pt")) and os.path.isfile(os.path.join(d, "train_config.json"))):
+            continue
+        out.append(d)
+    out.sort()
+    return out
+
+
+def _read_run_dirs_file(path: str) -> List[str]:
+    if not os.path.isfile(path):
+        raise FileNotFoundError(f"run_dirs_file not found: {path}")
+    out: List[str] = []
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            out.append(line)
+    return out
+
+
+def _safe_float_eq(a: float, b: float, tol: float = 1e-9) -> bool:
+    return abs(float(a) - float(b)) <= tol
+
+
+def _capture_values(df: pd.DataFrame, *, topk: int, tau: float) -> Tuple[float, float, int]:
+    """Returns (event_weighted, macro, total_denom_events)."""
+    if df.empty:
+        # Missing/empty evaluation file: report NaN instead of raising KeyError below.
+        return float("nan"), float("nan"), 0
+    sub = df[(df["topk"] == int(topk)) & (
+        df["tau_years"].astype(float) == float(tau))].copy()
+    if sub.empty:
+        return float("nan"), float("nan"), 0
+
+    denom = sub["denom_events"].astype(np.int64)
+    numer = sub["numer_events"].astype(np.int64)
+
+    total_denom = int(denom.sum())
+    if total_denom > 0:
+        event_weighted = float(numer.sum() / total_denom)
+    else:
+        event_weighted = float("nan")
+
+    sub_pos = sub[denom > 0]
+    macro = float(sub_pos["capture_at_k"].astype(
+        float).mean()) if not sub_pos.empty else float("nan")
+    return event_weighted, macro, total_denom
+
+
+def _workload_values(df: pd.DataFrame, *, frac: float, tau: float) -> Tuple[float, float, int]:
+    """Returns (event_weighted, macro, total_events)."""
+    if df.empty:
+        # Missing/empty evaluation file: report NaN instead of raising KeyError below.
+        return float("nan"), float("nan"), 0
+    sub = df[(df["tau_years"].astype(float) == float(tau)) & (
+        df["frac_selected"].astype(float) == float(frac))].copy()
+    if sub.empty:
+        return float("nan"), float("nan"), 0
+
+    total_events = sub["total_events"].astype(np.int64)
+    captured = sub["events_captured"].astype(np.int64)
+
+    denom = int(total_events.sum())
+    if denom > 0:
+        event_weighted = float(captured.sum() / denom)
+    else:
+        event_weighted = float("nan")
+
+    sub_pos = sub[total_events > 0]
+    macro = float(sub_pos["capture_rate"].astype(
+        float).mean()) if not sub_pos.empty else float("nan")
+    return event_weighted, macro, denom
+
+
+def _rank_points(values: pd.Series, model_names: pd.Series) -> pd.DataFrame:
+    """Ranks models descending by value; NaN ranked last. Deterministic tie-break on model_name."""
+    df = pd.DataFrame({"model_name": model_names,
+                       "value": values.astype(float)})
+    df["is_nan"] = ~np.isfinite(df["value"].to_numpy(dtype=np.float64))
+    df = df.sort_values(["is_nan", "value", "model_name"],
+                        ascending=[True, False, True], kind="mergesort")
+    df["rank"] = np.arange(1, len(df) + 1, dtype=np.int64)
+    n = len(df)
+    df["points"] = (n - df["rank"] + 1).astype(np.int64)
+    # Force NaNs to last (1 point)
+    df.loc[df["is_nan"], "points"] = 1
+    return df[["model_name", "rank", "points"]]
+
+
+def main() -> None:
+    args = parse_args()
+
+    if args.run_dirs_file:
+        run_dirs = _read_run_dirs_file(args.run_dirs_file)
+    else:
+        run_dirs = _discover_run_dirs(args.runs_root, args.pattern)
+
+    if args.expect_n and len(run_dirs) != int(args.expect_n):
+        raise ValueError(
+            f"Expected {args.expect_n} runs, found {len(run_dirs)}")
+
+    os.makedirs(args.out_dir, exist_ok=True)
+
+    rows: List[Dict[str, object]] = []
+    warnings: List[str] = []
+
+    for run_dir in run_dirs:
+        model_name = os.path.basename(os.path.normpath(run_dir))
+
+        cfg = load_train_config(run_dir)
+        loss_type = str(cfg.get("loss_type", ""))
+        age_encoder = str(
+            cfg.get("age_encoder", cfg.get("age_encoder_type", "")))
+        full_cov = bool(cfg.get("full_cov", False))
+        covariate_setting = "fullcov" if full_cov else "partcov"
+
+        cap_path = os.path.join(run_dir, "horizon_capture.csv")
+        wy_path = os.path.join(run_dir, "workload_yield.csv")
+
+        if not os.path.isfile(cap_path) or not os.path.isfile(wy_path):
+            warnings.append(
+                f"Missing evaluation outputs for {model_name}: horizon_capture.csv/workload_yield.csv")
+            cap_df = pd.DataFrame()
+            wy_df = pd.DataFrame()
+        else:
+            cap_df = pd.read_csv(cap_path)
+            wy_df = pd.read_csv(wy_path)
+
+        # Validate age bins (best-effort). Missing bins are allowed (no eligible records),
+        # but unexpected bins likely indicate protocol mismatch.
+        if not cap_df.empty and "age_bin" in cap_df.columns:
+            seen_bins = sorted(set(str(x)
+                               for x in cap_df["age_bin"].dropna().unique()))
+            unexpected = [b for b in seen_bins if b not in EXPECTED_AGE_BINS]
+            if unexpected:
+                warnings.append(
+                    f"{model_name}: unexpected age_bin labels in horizon_capture.csv: {unexpected}")
+
+        out: Dict[str, object] = {
+            "model_name": model_name,
+            "run_dir": run_dir,
+            "loss_type": loss_type,
+            "age_encoder": age_encoder,
+            "covariate_setting": covariate_setting,
+        }
+
+        # Primary panel: event-weighted metrics
+        for metric_id, kind, spec in PRIMARY_METRICS:
+            if kind == "capture":
+                v_w, v_macro, denom = _capture_values(
+                    cap_df, topk=spec["topk"], tau=spec["tau"])
+                out[metric_id] = v_w
+                out[f"{metric_id}_macro"] = v_macro
+                out[f"{metric_id}_denom"] = int(denom)
+            elif kind == "workload":
+                v_w, v_macro, denom = _workload_values(
+                    wy_df, frac=spec["frac"], tau=spec["tau"])
+                out[metric_id] = v_w
+                out[f"{metric_id}_macro"] = v_macro
+                out[f"{metric_id}_denom"] = int(denom)
+            else:
+                raise ValueError(f"Unknown metric kind: {kind}")
+
+        rows.append(out)
+
+    metrics_df = pd.DataFrame(rows)
+
+    # Scoring: ranks + points for the 8 primary metrics (event-weighted)
+    per_metric_rank_parts: List[pd.DataFrame] = []
+    total_points = np.zeros((len(metrics_df),), dtype=np.int64)
+
+    for metric_id, _, _ in PRIMARY_METRICS:
+        r = _rank_points(metrics_df[metric_id], metrics_df["model_name"])
+        r = r.rename(columns={"rank": f"{metric_id}_rank",
+                              "points": f"{metric_id}_points"})
+        per_metric_rank_parts.append(r)
+
+    ranks_df = per_metric_rank_parts[0]
+    for part in per_metric_rank_parts[1:]:
+        ranks_df = ranks_df.merge(part, on="model_name", how="left")
+
+    # Merge ranks back
+    merged = metrics_df.merge(ranks_df, on="model_name", how="left")
+
+    point_cols = [f"{mid}_points" for mid, _, _ in PRIMARY_METRICS]
+    merged["total_score"] = merged[point_cols].sum(axis=1).astype(np.int64)
+
+    # Macro robustness score (not used for ranking, but used to flag instability)
+    macro_point_parts: List[pd.DataFrame] = []
+    for metric_id, _, _ in PRIMARY_METRICS:
+        r = _rank_points(merged[f"{metric_id}_macro"], merged["model_name"])
+        r = r.rename(columns={"rank": f"{metric_id}_macro_rank",
+                              "points": f"{metric_id}_macro_points"})
+        macro_point_parts.append(r)
+
+    macro_ranks_df = macro_point_parts[0]
+    for part in macro_point_parts[1:]:
+        macro_ranks_df = macro_ranks_df.merge(part, on="model_name", how="left")
+
+    merged = merged.merge(macro_ranks_df, on="model_name", how="left")
+    macro_point_cols = [f"{mid}_macro_points" for mid, _, _ in PRIMARY_METRICS]
+    merged["macro_total_score"] = merged[macro_point_cols].sum(
+        axis=1).astype(np.int64)
+
+    # Final leaderboard ranks
+    merged = merged.sort_values(["total_score", "model_name"],
+                                ascending=[False, True], kind="mergesort")
+    merged["leaderboard_rank"] = np.arange(1, len(merged) + 1, dtype=np.int64)
+
+    merged_macro = merged.sort_values(["macro_total_score", "model_name"],
+                                      ascending=[False, True], kind="mergesort")
+    macro_rank_map = {
+        name: int(i + 1) for i, name in enumerate(merged_macro["model_name"].tolist())}
+    merged["macro_leaderboard_rank"] = merged["model_name"].map(
+        macro_rank_map).astype(np.int64)
+    merged["rank_delta_macro_minus_primary"] = (
+        merged["macro_leaderboard_rank"] - merged["leaderboard_rank"]).astype(np.int64)
+    merged["macro_collapse_flag"] = merged["rank_delta_macro_minus_primary"] >= 4
+
+    # Output tables
+    leaderboard_cols = [
+        "leaderboard_rank",
+        "model_name",
+        "loss_type",
+        "age_encoder",
+        "covariate_setting",
+        "total_score",
+        "macro_total_score",
+        "macro_leaderboard_rank",
+        "rank_delta_macro_minus_primary",
+        "macro_collapse_flag",
+    ]
+    for mid, _, _ in PRIMARY_METRICS:
+        leaderboard_cols.append(mid)
+    out_leaderboard = merged[leaderboard_cols]
+
+    out_leaderboard_path = os.path.join(args.out_dir, "leaderboard.csv")
+    out_leaderboard.to_csv(out_leaderboard_path, index=False)
+
+    # Audit table: per-metric ranks/points + values
+    audit_cols = ["model_name"]
+    for mid, _, _ in PRIMARY_METRICS:
+        audit_cols.extend([mid, f"{mid}_rank", f"{mid}_points", f"{mid}_denom"])
+    audit_cols.append("total_score")
+    out_audit = merged[audit_cols]
+
+    out_audit_path = os.path.join(args.out_dir, "per_metric_ranking.csv")
+    out_audit.to_csv(out_audit_path, index=False)
+
+    # Dimension summaries
+    dim_summary_parts: List[pd.DataFrame] = []
+    for dim in ["loss_type", "age_encoder", "covariate_setting"]:
+        g = merged.groupby(dim, dropna=False)
+        dim_summary_parts.append(
+            g.agg(
+                n_models=("model_name", "count"),
+                mean_total_score=("total_score", "mean"),
+                std_total_score=("total_score", "std"),
+                best_total_score=("total_score", "max"),
+            ).reset_index()
+        )
+
+    dim_summary = pd.concat(
+        [df.assign(dimension=name) for df, name in zip(
+            dim_summary_parts, ["loss_type", "age_encoder", "covariate_setting"])],
+        ignore_index=True,
+    )
+    out_dim_path = os.path.join(args.out_dir, "dimension_summary.csv")
+    dim_summary.to_csv(out_dim_path, index=False)
+
+    # Matched-pair deltas for interpretability
+    def _pair_deltas(fixed: Sequence[str], var: str) -> pd.DataFrame:
+        cols = list(fixed) + [var, "model_name", "total_score"]
+        df = merged[cols].copy()
+        pairs: List[Dict[str, object]] = []
+        for key, grp in df.groupby(list(fixed), dropna=False):
+            if len(grp) < 2:
+                continue
+            if grp[var].nunique() < 2:
+                continue
+            # deterministic ordering
+            grp = grp.sort_values([var, "model_name"], kind="mergesort")
+            base = grp.iloc[0]
+            for i in range(1, len(grp)):
+                other = grp.iloc[i]
+                pairs.append(
+                    {
+                        **({fixed[j]: key[j] for j in range(len(fixed))} if isinstance(key, tuple) else {fixed[0]: key}),
+                        "var_dim": var,
+                        "a_var": base[var],
+                        "b_var": other[var],
+                        "a_model": base["model_name"],
+                        "b_model": other["model_name"],
+                        "a_score": int(base["total_score"]),
+                        "b_score": int(other["total_score"]),
+                        "delta_b_minus_a": int(other["total_score"] - base["total_score"]),
+                    }
+                )
+        return pd.DataFrame(pairs)
+
+    out_pairs_cov = _pair_deltas(
+        ["loss_type", "age_encoder"], "covariate_setting")
+    out_pairs_age = _pair_deltas(
+        ["loss_type", "covariate_setting"], "age_encoder")
+    out_pairs_loss = _pair_deltas(
+        ["age_encoder", "covariate_setting"], "loss_type")
+
+    out_pairs_cov.to_csv(os.path.join(
+        args.out_dir, "pairs_covariates.csv"), index=False)
+    out_pairs_age.to_csv(os.path.join(
+        args.out_dir, "pairs_age_encoder.csv"), index=False)
+    out_pairs_loss.to_csv(os.path.join(
+        args.out_dir, "pairs_loss_type.csv"), index=False)
+
+    # Protocol stamp for auditability
+    protocol = {
+        "PRIMARY_METRICS": list(PRIMARY_METRICS),
+        "EXPECTED_AGE_BINS": list(EXPECTED_AGE_BINS),
+        "PRIMARY_CAPTURE_TAUS": list(PRIMARY_CAPTURE_TAUS),
+        "PRIMARY_CAPTURE_TOPKS": list(PRIMARY_CAPTURE_TOPKS),
+        "PRIMARY_WY_TAU": float(PRIMARY_WY_TAU),
+        "PRIMARY_WY_FRACS": list(PRIMARY_WY_FRACS),
+        "aggregation_primary": "event-weighted by denom_events / total_events",
+        "aggregation_macro": "macro-average over age bins with denom>0",
+        "nan_policy": "NaN ranked last (1 point)",
+        "tie_break": "model_name ascending (deterministic)",
+    }
+    pd.Series(protocol).to_json(os.path.join(args.out_dir,
+                                "protocol.json"), indent=2, force_ascii=False)
+
+    if warnings:
+        with open(os.path.join(args.out_dir, "warnings.txt"), "w", encoding="utf-8") as f:
+            for w in warnings:
+                f.write(w + "\n")
+
+    print(f"Wrote {out_leaderboard_path}")
+    print(f"Wrote {out_audit_path}")
+    print(f"Wrote {out_dim_path}")
+    if warnings:
+        print(f"Wrote {os.path.join(args.out_dir, 'warnings.txt')}")
+
+
+if __name__ == "__main__":
+    main()
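
A minimal sanity-check sketch of the scoring rule (not part of the diff above; it assumes the repo root is on PYTHONPATH so compare_models is importable, and that pandas is installed):

import pandas as pd

from compare_models import _rank_points

# Four toy models: one clear winner, a tie broken by model_name, and one NaN.
values = pd.Series([0.42, 0.42, float("nan"), 0.55])
names = pd.Series(["model_a", "model_b", "model_c", "model_d"])
print(_rank_points(values, names))
# model_d gets rank 1 (4 points); the model_a/model_b tie is broken
# alphabetically (3 and 2 points); NaN model_c is ranked last with 1 point.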