Add compare_models.py for model evaluation with primary metrics and age-bin validation
This commit is contained in:
427
compare_models.py
Normal file
427
compare_models.py
Normal file
@@ -0,0 +1,427 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from utils import load_train_config
|
||||
|
||||
|
||||
# -------------------------
|
||||
# Pre-registered panel (DO NOT CHANGE AFTER SEEING RESULTS)
|
||||
# -------------------------
|
||||
|
||||
PRIMARY_CAPTURE_TAUS = (1.0, 2.0, 5.0)
|
||||
PRIMARY_CAPTURE_TOPKS = (20, 50)
|
||||
PRIMARY_WY_TAU = 1.0
|
||||
PRIMARY_WY_FRACS = (0.05, 0.10)
|
||||
|
||||
PRIMARY_METRICS = (
|
||||
("C1", "capture", {"topk": 20, "tau": 1.0}),
|
||||
("C2", "capture", {"topk": 20, "tau": 2.0}),
|
||||
("C3", "capture", {"topk": 20, "tau": 5.0}),
|
||||
("C4", "capture", {"topk": 50, "tau": 1.0}),
|
||||
("C5", "capture", {"topk": 50, "tau": 2.0}),
|
||||
("C6", "capture", {"topk": 50, "tau": 5.0}),
|
||||
("W1", "workload", {"frac": 0.05, "tau": 1.0}),
|
||||
("W2", "workload", {"frac": 0.10, "tau": 1.0}),
|
||||
)
|
||||
|
||||
EXPECTED_AGE_BINS = (
|
||||
"[40.0, 45.0)",
|
||||
"[45.0, 50.0)",
|
||||
"[50.0, 55.0)",
|
||||
"[55.0, 60.0)",
|
||||
"[60.0, 65.0)",
|
||||
"[65.0, 70.0)",
|
||||
"[70.0, inf)",
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(
|
||||
description="Compare trained models using 8 pre-registered primary metrics."
|
||||
)
|
||||
p.add_argument(
|
||||
"--runs_root",
|
||||
type=str,
|
||||
default="runs",
|
||||
help="Directory containing run subfolders (default: runs)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--pattern",
|
||||
type=str,
|
||||
default="*",
|
||||
help="Glob pattern for run folder names under runs_root (default: *)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--run_dirs_file",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional file with one run_dir per line (overrides runs_root/pattern)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--out_dir",
|
||||
type=str,
|
||||
default="comparison",
|
||||
help="Output directory for tables (default: comparison)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--expect_n",
|
||||
type=int,
|
||||
default=12,
|
||||
help="Expected number of models (default: 12). Set 0 to disable check.",
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def _discover_run_dirs(runs_root: str, pattern: str) -> List[str]:
|
||||
if not os.path.isdir(runs_root):
|
||||
raise FileNotFoundError(f"runs_root not found: {runs_root}")
|
||||
|
||||
import glob
|
||||
|
||||
out: List[str] = []
|
||||
for d in glob.glob(os.path.join(runs_root, pattern)):
|
||||
if not os.path.isdir(d):
|
||||
continue
|
||||
if not (os.path.isfile(os.path.join(d, "best_model.pt")) and os.path.isfile(os.path.join(d, "train_config.json"))):
|
||||
continue
|
||||
out.append(d)
|
||||
out.sort()
|
||||
return out
|
||||
|
||||
|
||||
def _read_run_dirs_file(path: str) -> List[str]:
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"run_dirs_file not found: {path}")
|
||||
out: List[str] = []
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
out.append(line)
|
||||
return out
|
||||
|
||||
|
||||
def _safe_float_eq(a: float, b: float, tol: float = 1e-9) -> bool:
|
||||
return abs(float(a) - float(b)) <= tol
|
||||
|
||||
|
||||
def _capture_values(df: pd.DataFrame, *, topk: int, tau: float) -> Tuple[float, float, int]:
|
||||
"""Returns (event_weighted, macro, total_denom_events)."""
|
||||
sub = df[(df["topk"] == int(topk)) & (
|
||||
df["tau_years"].astype(float) == float(tau))].copy()
|
||||
if sub.empty:
|
||||
return float("nan"), float("nan"), 0
|
||||
|
||||
denom = sub["denom_events"].astype(np.int64)
|
||||
numer = sub["numer_events"].astype(np.int64)
|
||||
|
||||
total_denom = int(denom.sum())
|
||||
if total_denom > 0:
|
||||
event_weighted = float(numer.sum() / total_denom)
|
||||
else:
|
||||
event_weighted = float("nan")
|
||||
|
||||
sub_pos = sub[denom > 0]
|
||||
macro = float(sub_pos["capture_at_k"].astype(
|
||||
float).mean()) if not sub_pos.empty else float("nan")
|
||||
return event_weighted, macro, total_denom
|
||||
|
||||
|
||||
def _workload_values(df: pd.DataFrame, *, frac: float, tau: float) -> Tuple[float, float, int]:
|
||||
"""Returns (event_weighted, macro, total_events)."""
|
||||
sub = df[(df["tau_years"].astype(float) == float(tau)) & (
|
||||
df["frac_selected"].astype(float) == float(frac))].copy()
|
||||
if sub.empty:
|
||||
return float("nan"), float("nan"), 0
|
||||
|
||||
total_events = sub["total_events"].astype(np.int64)
|
||||
captured = sub["events_captured"].astype(np.int64)
|
||||
|
||||
denom = int(total_events.sum())
|
||||
if denom > 0:
|
||||
event_weighted = float(captured.sum() / denom)
|
||||
else:
|
||||
event_weighted = float("nan")
|
||||
|
||||
sub_pos = sub[total_events > 0]
|
||||
macro = float(sub_pos["capture_rate"].astype(
|
||||
float).mean()) if not sub_pos.empty else float("nan")
|
||||
return event_weighted, macro, denom
|
||||
|
||||
|
||||
def _rank_points(values: pd.Series, model_names: pd.Series) -> pd.DataFrame:
|
||||
"""Ranks models descending by value; NaN ranked last. Deterministic tie-break on model_name."""
|
||||
df = pd.DataFrame({"model_name": model_names,
|
||||
"value": values.astype(float)})
|
||||
df["is_nan"] = ~np.isfinite(df["value"].to_numpy(dtype=np.float64))
|
||||
df = df.sort_values(["is_nan", "value", "model_name"], ascending=[
|
||||
True, False, True], kind="mergesort")
|
||||
df["rank"] = np.arange(1, len(df) + 1, dtype=np.int64)
|
||||
n = len(df)
|
||||
df["points"] = (n - df["rank"] + 1).astype(np.int64)
|
||||
# Force NaNs to last (1 point)
|
||||
df.loc[df["is_nan"], "points"] = 1
|
||||
return df[["model_name", "rank", "points"]]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if args.run_dirs_file:
|
||||
run_dirs = _read_run_dirs_file(args.run_dirs_file)
|
||||
else:
|
||||
run_dirs = _discover_run_dirs(args.runs_root, args.pattern)
|
||||
|
||||
if args.expect_n and len(run_dirs) != int(args.expect_n):
|
||||
raise ValueError(
|
||||
f"Expected {args.expect_n} runs, found {len(run_dirs)}")
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
rows: List[Dict[str, object]] = []
|
||||
warnings: List[str] = []
|
||||
|
||||
for run_dir in run_dirs:
|
||||
model_name = os.path.basename(os.path.normpath(run_dir))
|
||||
|
||||
cfg = load_train_config(run_dir)
|
||||
loss_type = str(cfg.get("loss_type", ""))
|
||||
age_encoder = str(
|
||||
cfg.get("age_encoder", cfg.get("age_encoder_type", "")))
|
||||
full_cov = bool(cfg.get("full_cov", False))
|
||||
covariate_setting = "fullcov" if full_cov else "partcov"
|
||||
|
||||
cap_path = os.path.join(run_dir, "horizon_capture.csv")
|
||||
wy_path = os.path.join(run_dir, "workload_yield.csv")
|
||||
|
||||
if not os.path.isfile(cap_path) or not os.path.isfile(wy_path):
|
||||
warnings.append(
|
||||
f"Missing evaluation outputs for {model_name}: horizon_capture.csv/workload_yield.csv")
|
||||
cap_df = pd.DataFrame()
|
||||
wy_df = pd.DataFrame()
|
||||
else:
|
||||
cap_df = pd.read_csv(cap_path)
|
||||
wy_df = pd.read_csv(wy_path)
|
||||
|
||||
# Validate age bins (best-effort). Missing bins are allowed (no eligible records),
|
||||
# but unexpected bins likely indicate protocol mismatch.
|
||||
if not cap_df.empty and "age_bin" in cap_df.columns:
|
||||
seen_bins = sorted(set(str(x)
|
||||
for x in cap_df["age_bin"].dropna().unique()))
|
||||
unexpected = [b for b in seen_bins if b not in EXPECTED_AGE_BINS]
|
||||
if unexpected:
|
||||
warnings.append(
|
||||
f"{model_name}: unexpected age_bin labels in horizon_capture.csv: {unexpected}")
|
||||
|
||||
out: Dict[str, object] = {
|
||||
"model_name": model_name,
|
||||
"run_dir": run_dir,
|
||||
"loss_type": loss_type,
|
||||
"age_encoder": age_encoder,
|
||||
"covariate_setting": covariate_setting,
|
||||
}
|
||||
|
||||
# Primary panel: event-weighted metrics
|
||||
for metric_id, kind, spec in PRIMARY_METRICS:
|
||||
if kind == "capture":
|
||||
v_w, v_macro, denom = _capture_values(
|
||||
cap_df, topk=spec["topk"], tau=spec["tau"])
|
||||
out[metric_id] = v_w
|
||||
out[f"{metric_id}_macro"] = v_macro
|
||||
out[f"{metric_id}_denom"] = int(denom)
|
||||
elif kind == "workload":
|
||||
v_w, v_macro, denom = _workload_values(
|
||||
wy_df, frac=spec["frac"], tau=spec["tau"])
|
||||
out[metric_id] = v_w
|
||||
out[f"{metric_id}_macro"] = v_macro
|
||||
out[f"{metric_id}_denom"] = int(denom)
|
||||
else:
|
||||
raise ValueError(f"Unknown metric kind: {kind}")
|
||||
|
||||
rows.append(out)
|
||||
|
||||
metrics_df = pd.DataFrame(rows)
|
||||
|
||||
# Scoring: ranks + points for the 8 primary metrics (event-weighted)
|
||||
per_metric_rank_parts: List[pd.DataFrame] = []
|
||||
total_points = np.zeros((len(metrics_df),), dtype=np.int64)
|
||||
|
||||
for metric_id, _, _ in PRIMARY_METRICS:
|
||||
r = _rank_points(metrics_df[metric_id], metrics_df["model_name"])
|
||||
r = r.rename(columns={"rank": f"{metric_id}_rank",
|
||||
"points": f"{metric_id}_points"})
|
||||
per_metric_rank_parts.append(r)
|
||||
|
||||
ranks_df = per_metric_rank_parts[0]
|
||||
for part in per_metric_rank_parts[1:]:
|
||||
ranks_df = ranks_df.merge(part, on="model_name", how="left")
|
||||
|
||||
# Merge ranks back
|
||||
merged = metrics_df.merge(ranks_df, on="model_name", how="left")
|
||||
|
||||
point_cols = [f"{mid}_points" for mid, _, _ in PRIMARY_METRICS]
|
||||
merged["total_score"] = merged[point_cols].sum(axis=1).astype(np.int64)
|
||||
|
||||
# Macro robustness score (not used for ranking, but used to flag instability)
|
||||
macro_point_parts: List[pd.DataFrame] = []
|
||||
for metric_id, _, _ in PRIMARY_METRICS:
|
||||
r = _rank_points(merged[f"{metric_id}_macro"], merged["model_name"])
|
||||
r = r.rename(columns={"rank": f"{metric_id}_macro_rank",
|
||||
"points": f"{metric_id}_macro_points"})
|
||||
macro_point_parts.append(r)
|
||||
|
||||
macro_ranks_df = macro_point_parts[0]
|
||||
for part in macro_point_parts[1:]:
|
||||
macro_ranks_df = macro_ranks_df.merge(part, on="model_name", how="left")
|
||||
|
||||
merged = merged.merge(macro_ranks_df, on="model_name", how="left")
|
||||
macro_point_cols = [f"{mid}_macro_points" for mid, _, _ in PRIMARY_METRICS]
|
||||
merged["macro_total_score"] = merged[macro_point_cols].sum(
|
||||
axis=1).astype(np.int64)
|
||||
|
||||
# Final leaderboard ranks
|
||||
merged = merged.sort_values(["total_score", "model_name"], ascending=[
|
||||
False, True], kind="mergesort")
|
||||
merged["leaderboard_rank"] = np.arange(1, len(merged) + 1, dtype=np.int64)
|
||||
|
||||
merged_macro = merged.sort_values(["macro_total_score", "model_name"], ascending=[
|
||||
False, True], kind="mergesort")
|
||||
macro_rank_map = {
|
||||
name: int(i + 1) for i, name in enumerate(merged_macro["model_name"].tolist())}
|
||||
merged["macro_leaderboard_rank"] = merged["model_name"].map(
|
||||
macro_rank_map).astype(np.int64)
|
||||
merged["rank_delta_macro_minus_primary"] = (
|
||||
merged["macro_leaderboard_rank"] - merged["leaderboard_rank"]).astype(np.int64)
|
||||
merged["macro_collapse_flag"] = merged["rank_delta_macro_minus_primary"] >= 4
|
||||
|
||||
# Output tables
|
||||
leaderboard_cols = [
|
||||
"leaderboard_rank",
|
||||
"model_name",
|
||||
"loss_type",
|
||||
"age_encoder",
|
||||
"covariate_setting",
|
||||
"total_score",
|
||||
"macro_total_score",
|
||||
"macro_leaderboard_rank",
|
||||
"rank_delta_macro_minus_primary",
|
||||
"macro_collapse_flag",
|
||||
]
|
||||
for mid, _, _ in PRIMARY_METRICS:
|
||||
leaderboard_cols.append(mid)
|
||||
out_leaderboard = merged[leaderboard_cols]
|
||||
|
||||
out_leaderboard_path = os.path.join(args.out_dir, "leaderboard.csv")
|
||||
out_leaderboard.to_csv(out_leaderboard_path, index=False)
|
||||
|
||||
# Audit table: per-metric ranks/points + values
|
||||
audit_cols = ["model_name"]
|
||||
for mid, _, _ in PRIMARY_METRICS:
|
||||
audit_cols.extend([mid, f"{mid}_rank", f"{mid}_points", f"{mid}_denom"])
|
||||
audit_cols.append("total_score")
|
||||
out_audit = merged[audit_cols]
|
||||
|
||||
out_audit_path = os.path.join(args.out_dir, "per_metric_ranking.csv")
|
||||
out_audit.to_csv(out_audit_path, index=False)
|
||||
|
||||
# Dimension summaries
|
||||
dim_summary_parts: List[pd.DataFrame] = []
|
||||
for dim in ["loss_type", "age_encoder", "covariate_setting"]:
|
||||
g = merged.groupby(dim, dropna=False)
|
||||
dim_summary_parts.append(
|
||||
g.agg(
|
||||
n_models=("model_name", "count"),
|
||||
mean_total_score=("total_score", "mean"),
|
||||
std_total_score=("total_score", "std"),
|
||||
best_total_score=("total_score", "max"),
|
||||
).reset_index()
|
||||
)
|
||||
|
||||
dim_summary = pd.concat(
|
||||
[df.assign(dimension=name) for df, name in zip(dim_summary_parts, [
|
||||
"loss_type", "age_encoder", "covariate_setting"])],
|
||||
ignore_index=True,
|
||||
)
|
||||
out_dim_path = os.path.join(args.out_dir, "dimension_summary.csv")
|
||||
dim_summary.to_csv(out_dim_path, index=False)
|
||||
|
||||
# Matched-pair deltas for interpretability
|
||||
def _pair_deltas(fixed: Sequence[str], var: str) -> pd.DataFrame:
|
||||
cols = list(fixed) + [var, "model_name", "total_score"]
|
||||
df = merged[cols].copy()
|
||||
pairs: List[Dict[str, object]] = []
|
||||
for key, grp in df.groupby(list(fixed), dropna=False):
|
||||
if len(grp) < 2:
|
||||
continue
|
||||
if grp[var].nunique() < 2:
|
||||
continue
|
||||
# deterministic ordering
|
||||
grp = grp.sort_values([var, "model_name"], kind="mergesort")
|
||||
base = grp.iloc[0]
|
||||
for i in range(1, len(grp)):
|
||||
other = grp.iloc[i]
|
||||
pairs.append(
|
||||
{
|
||||
**({fixed[j]: key[j] for j in range(len(fixed))} if isinstance(key, tuple) else {fixed[0]: key}),
|
||||
"var_dim": var,
|
||||
"a_var": base[var],
|
||||
"b_var": other[var],
|
||||
"a_model": base["model_name"],
|
||||
"b_model": other["model_name"],
|
||||
"a_score": int(base["total_score"]),
|
||||
"b_score": int(other["total_score"]),
|
||||
"delta_b_minus_a": int(other["total_score"] - base["total_score"]),
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(pairs)
|
||||
|
||||
out_pairs_cov = _pair_deltas(
|
||||
["loss_type", "age_encoder"], "covariate_setting")
|
||||
out_pairs_age = _pair_deltas(
|
||||
["loss_type", "covariate_setting"], "age_encoder")
|
||||
out_pairs_loss = _pair_deltas(
|
||||
["age_encoder", "covariate_setting"], "loss_type")
|
||||
|
||||
out_pairs_cov.to_csv(os.path.join(
|
||||
args.out_dir, "pairs_covariates.csv"), index=False)
|
||||
out_pairs_age.to_csv(os.path.join(
|
||||
args.out_dir, "pairs_age_encoder.csv"), index=False)
|
||||
out_pairs_loss.to_csv(os.path.join(
|
||||
args.out_dir, "pairs_loss_type.csv"), index=False)
|
||||
|
||||
# Protocol stamp for auditability
|
||||
protocol = {
|
||||
"PRIMARY_METRICS": list(PRIMARY_METRICS),
|
||||
"EXPECTED_AGE_BINS": list(EXPECTED_AGE_BINS),
|
||||
"PRIMARY_CAPTURE_TAUS": list(PRIMARY_CAPTURE_TAUS),
|
||||
"PRIMARY_CAPTURE_TOPKS": list(PRIMARY_CAPTURE_TOPKS),
|
||||
"PRIMARY_WY_TAU": float(PRIMARY_WY_TAU),
|
||||
"PRIMARY_WY_FRACS": list(PRIMARY_WY_FRACS),
|
||||
"aggregation_primary": "event-weighted by denom_events / total_events",
|
||||
"aggregation_macro": "macro-average over age bins with denom>0",
|
||||
"nan_policy": "NaN ranked last (1 point)",
|
||||
"tie_break": "model_name ascending (deterministic)",
|
||||
}
|
||||
pd.Series(protocol).to_json(os.path.join(args.out_dir,
|
||||
"protocol.json"), indent=2, force_ascii=False)
|
||||
|
||||
if warnings:
|
||||
with open(os.path.join(args.out_dir, "warnings.txt"), "w", encoding="utf-8") as f:
|
||||
for w in warnings:
|
||||
f.write(w + "\n")
|
||||
|
||||
print(f"Wrote {out_leaderboard_path}")
|
||||
print(f"Wrote {out_audit_path}")
|
||||
print(f"Wrote {out_dim_path}")
|
||||
if warnings:
|
||||
print(f"Wrote {os.path.join(args.out_dir, 'warnings.txt')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user