Add compare_models.py for model evaluation with primary metrics and age-bin validation
compare_models.py (new file, 427 lines)
@@ -0,0 +1,427 @@
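"""Compare trained models on the 8 pre-registered primary metrics.

Reads horizon_capture.csv and workload_yield.csv from each run directory,
scores models with rank-based points, and writes leaderboard, audit, and
summary tables plus a protocol stamp to --out_dir.

Example invocation (a sketch, assuming run folders containing best_model.pt
and train_config.json live under ./runs):

    python compare_models.py --runs_root runs --pattern "*" --out_dir comparison
"""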
import argparse
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd

from utils import load_train_config


# -------------------------
# Pre-registered panel (DO NOT CHANGE AFTER SEEING RESULTS)
# -------------------------

PRIMARY_CAPTURE_TAUS = (1.0, 2.0, 5.0)
PRIMARY_CAPTURE_TOPKS = (20, 50)
PRIMARY_WY_TAU = 1.0
PRIMARY_WY_FRACS = (0.05, 0.10)

PRIMARY_METRICS = (
    ("C1", "capture", {"topk": 20, "tau": 1.0}),
    ("C2", "capture", {"topk": 20, "tau": 2.0}),
    ("C3", "capture", {"topk": 20, "tau": 5.0}),
    ("C4", "capture", {"topk": 50, "tau": 1.0}),
    ("C5", "capture", {"topk": 50, "tau": 2.0}),
    ("C6", "capture", {"topk": 50, "tau": 5.0}),
    ("W1", "workload", {"frac": 0.05, "tau": 1.0}),
    ("W2", "workload", {"frac": 0.10, "tau": 1.0}),
)

EXPECTED_AGE_BINS = (
    "[40.0, 45.0)",
    "[45.0, 50.0)",
    "[50.0, 55.0)",
    "[55.0, 60.0)",
    "[60.0, 65.0)",
    "[65.0, 70.0)",
    "[70.0, inf)",
)
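# NOTE: left-closed, right-open five-year bins starting at age 40, with an
# open-ended final bin for ages 70 and above.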


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Compare trained models using 8 pre-registered primary metrics."
    )
    p.add_argument(
        "--runs_root",
        type=str,
        default="runs",
        help="Directory containing run subfolders (default: runs)",
    )
    p.add_argument(
        "--pattern",
        type=str,
        default="*",
        help="Glob pattern for run folder names under runs_root (default: *)",
    )
    p.add_argument(
        "--run_dirs_file",
        type=str,
        default="",
        help="Optional file with one run_dir per line (overrides runs_root/pattern)",
    )
    p.add_argument(
        "--out_dir",
        type=str,
        default="comparison",
        help="Output directory for tables (default: comparison)",
    )
    p.add_argument(
        "--expect_n",
        type=int,
        default=12,
        help="Expected number of models (default: 12). Set 0 to disable check.",
    )
    return p.parse_args()


def _discover_run_dirs(runs_root: str, pattern: str) -> List[str]:
    if not os.path.isdir(runs_root):
        raise FileNotFoundError(f"runs_root not found: {runs_root}")

    import glob

    out: List[str] = []
    for d in glob.glob(os.path.join(runs_root, pattern)):
        if not os.path.isdir(d):
            continue
        if not (os.path.isfile(os.path.join(d, "best_model.pt")) and os.path.isfile(os.path.join(d, "train_config.json"))):
            continue
        out.append(d)
    out.sort()
    return out


def _read_run_dirs_file(path: str) -> List[str]:
    if not os.path.isfile(path):
        raise FileNotFoundError(f"run_dirs_file not found: {path}")
    out: List[str] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            out.append(line)
    return out


def _safe_float_eq(a: float, b: float, tol: float = 1e-9) -> bool:
    return abs(float(a) - float(b)) <= tol


def _capture_values(df: pd.DataFrame, *, topk: int, tau: float) -> Tuple[float, float, int]:
    """Returns (event_weighted, macro, total_denom_events)."""
    sub = df[(df["topk"] == int(topk)) & (
        df["tau_years"].astype(float) == float(tau))].copy()
    if sub.empty:
        return float("nan"), float("nan"), 0

    denom = sub["denom_events"].astype(np.int64)
    numer = sub["numer_events"].astype(np.int64)

    total_denom = int(denom.sum())
    if total_denom > 0:
        event_weighted = float(numer.sum() / total_denom)
    else:
        event_weighted = float("nan")

    sub_pos = sub[denom > 0]
    macro = float(sub_pos["capture_at_k"].astype(
        float).mean()) if not sub_pos.empty else float("nan")
    return event_weighted, macro, total_denom
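# Aggregation note (also applies to _workload_values below): the event-weighted
# value pools counts across age bins, sum(numer_events) / sum(denom_events),
# while the macro value is the unweighted mean of per-bin rates over bins with
# a positive denominator.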


def _workload_values(df: pd.DataFrame, *, frac: float, tau: float) -> Tuple[float, float, int]:
    """Returns (event_weighted, macro, total_events)."""
    sub = df[(df["tau_years"].astype(float) == float(tau)) & (
        df["frac_selected"].astype(float) == float(frac))].copy()
    if sub.empty:
        return float("nan"), float("nan"), 0

    total_events = sub["total_events"].astype(np.int64)
    captured = sub["events_captured"].astype(np.int64)

    denom = int(total_events.sum())
    if denom > 0:
        event_weighted = float(captured.sum() / denom)
    else:
        event_weighted = float("nan")

    sub_pos = sub[total_events > 0]
    macro = float(sub_pos["capture_rate"].astype(
        float).mean()) if not sub_pos.empty else float("nan")
    return event_weighted, macro, denom


def _rank_points(values: pd.Series, model_names: pd.Series) -> pd.DataFrame:
    """Ranks models descending by value; NaN ranked last. Deterministic tie-break on model_name."""
    df = pd.DataFrame({"model_name": model_names,
                       "value": values.astype(float)})
    df["is_nan"] = ~np.isfinite(df["value"].to_numpy(dtype=np.float64))
    df = df.sort_values(["is_nan", "value", "model_name"], ascending=[
                        True, False, True], kind="mergesort")
    df["rank"] = np.arange(1, len(df) + 1, dtype=np.int64)
    n = len(df)
    df["points"] = (n - df["rank"] + 1).astype(np.int64)
    # Force NaNs to last (1 point)
    df.loc[df["is_nan"], "points"] = 1
    return df[["model_name", "rank", "points"]]
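# Worked example of the scoring scheme (illustrative values only): with n=4
# models and metric values [0.9, 0.7, 0.7, NaN], ranks are 1..4 and points are
# n - rank + 1, i.e. [4, 3, 2, 1]; the NaN entry sorts last and is forced to
# 1 point, and the 0.7 tie is broken by ascending model_name.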


def main() -> None:
    args = parse_args()

    if args.run_dirs_file:
        run_dirs = _read_run_dirs_file(args.run_dirs_file)
    else:
        run_dirs = _discover_run_dirs(args.runs_root, args.pattern)

    if args.expect_n and len(run_dirs) != int(args.expect_n):
        raise ValueError(
            f"Expected {args.expect_n} runs, found {len(run_dirs)}")

    os.makedirs(args.out_dir, exist_ok=True)

    rows: List[Dict[str, object]] = []
    warnings: List[str] = []

    for run_dir in run_dirs:
        model_name = os.path.basename(os.path.normpath(run_dir))

        cfg = load_train_config(run_dir)
        loss_type = str(cfg.get("loss_type", ""))
        age_encoder = str(
            cfg.get("age_encoder", cfg.get("age_encoder_type", "")))
        full_cov = bool(cfg.get("full_cov", False))
        covariate_setting = "fullcov" if full_cov else "partcov"

        cap_path = os.path.join(run_dir, "horizon_capture.csv")
        wy_path = os.path.join(run_dir, "workload_yield.csv")

        if not os.path.isfile(cap_path) or not os.path.isfile(wy_path):
            warnings.append(
                f"Missing evaluation outputs for {model_name}: horizon_capture.csv/workload_yield.csv")
            cap_df = pd.DataFrame()
            wy_df = pd.DataFrame()
        else:
            cap_df = pd.read_csv(cap_path)
            wy_df = pd.read_csv(wy_path)

        # Validate age bins (best-effort). Missing bins are allowed (no eligible records),
        # but unexpected bins likely indicate protocol mismatch.
        if not cap_df.empty and "age_bin" in cap_df.columns:
            seen_bins = sorted(set(str(x)
                               for x in cap_df["age_bin"].dropna().unique()))
            unexpected = [b for b in seen_bins if b not in EXPECTED_AGE_BINS]
            if unexpected:
                warnings.append(
                    f"{model_name}: unexpected age_bin labels in horizon_capture.csv: {unexpected}")

        out: Dict[str, object] = {
            "model_name": model_name,
            "run_dir": run_dir,
            "loss_type": loss_type,
            "age_encoder": age_encoder,
            "covariate_setting": covariate_setting,
        }

        # Primary panel: event-weighted metrics
        for metric_id, kind, spec in PRIMARY_METRICS:
            if kind == "capture":
                v_w, v_macro, denom = _capture_values(
                    cap_df, topk=spec["topk"], tau=spec["tau"])
                out[metric_id] = v_w
                out[f"{metric_id}_macro"] = v_macro
                out[f"{metric_id}_denom"] = int(denom)
            elif kind == "workload":
                v_w, v_macro, denom = _workload_values(
                    wy_df, frac=spec["frac"], tau=spec["tau"])
                out[metric_id] = v_w
                out[f"{metric_id}_macro"] = v_macro
                out[f"{metric_id}_denom"] = int(denom)
            else:
                raise ValueError(f"Unknown metric kind: {kind}")

        rows.append(out)

    metrics_df = pd.DataFrame(rows)

    # Scoring: ranks + points for the 8 primary metrics (event-weighted)
    per_metric_rank_parts: List[pd.DataFrame] = []
    total_points = np.zeros((len(metrics_df),), dtype=np.int64)

    for metric_id, _, _ in PRIMARY_METRICS:
        r = _rank_points(metrics_df[metric_id], metrics_df["model_name"])
        r = r.rename(columns={"rank": f"{metric_id}_rank",
                              "points": f"{metric_id}_points"})
        per_metric_rank_parts.append(r)

    ranks_df = per_metric_rank_parts[0]
    for part in per_metric_rank_parts[1:]:
        ranks_df = ranks_df.merge(part, on="model_name", how="left")

    # Merge ranks back
    merged = metrics_df.merge(ranks_df, on="model_name", how="left")

    point_cols = [f"{mid}_points" for mid, _, _ in PRIMARY_METRICS]
    merged["total_score"] = merged[point_cols].sum(axis=1).astype(np.int64)

    # Macro robustness score (not used for ranking, but used to flag instability)
    macro_point_parts: List[pd.DataFrame] = []
    for metric_id, _, _ in PRIMARY_METRICS:
        r = _rank_points(merged[f"{metric_id}_macro"], merged["model_name"])
        r = r.rename(columns={"rank": f"{metric_id}_macro_rank",
                              "points": f"{metric_id}_macro_points"})
        macro_point_parts.append(r)

    macro_ranks_df = macro_point_parts[0]
    for part in macro_point_parts[1:]:
        macro_ranks_df = macro_ranks_df.merge(part, on="model_name", how="left")

    merged = merged.merge(macro_ranks_df, on="model_name", how="left")
    macro_point_cols = [f"{mid}_macro_points" for mid, _, _ in PRIMARY_METRICS]
    merged["macro_total_score"] = merged[macro_point_cols].sum(
        axis=1).astype(np.int64)

    # Final leaderboard ranks
    merged = merged.sort_values(["total_score", "model_name"], ascending=[
                                False, True], kind="mergesort")
    merged["leaderboard_rank"] = np.arange(1, len(merged) + 1, dtype=np.int64)

    merged_macro = merged.sort_values(["macro_total_score", "model_name"], ascending=[
                                      False, True], kind="mergesort")
    macro_rank_map = {
        name: int(i + 1) for i, name in enumerate(merged_macro["model_name"].tolist())}
    merged["macro_leaderboard_rank"] = merged["model_name"].map(
        macro_rank_map).astype(np.int64)
    merged["rank_delta_macro_minus_primary"] = (
        merged["macro_leaderboard_rank"] - merged["leaderboard_rank"]).astype(np.int64)
    merged["macro_collapse_flag"] = merged["rank_delta_macro_minus_primary"] >= 4

    # Output tables
    leaderboard_cols = [
        "leaderboard_rank",
        "model_name",
        "loss_type",
        "age_encoder",
        "covariate_setting",
        "total_score",
        "macro_total_score",
        "macro_leaderboard_rank",
        "rank_delta_macro_minus_primary",
        "macro_collapse_flag",
    ]
    for mid, _, _ in PRIMARY_METRICS:
        leaderboard_cols.append(mid)
    out_leaderboard = merged[leaderboard_cols]

    out_leaderboard_path = os.path.join(args.out_dir, "leaderboard.csv")
    out_leaderboard.to_csv(out_leaderboard_path, index=False)

    # Audit table: per-metric ranks/points + values
    audit_cols = ["model_name"]
    for mid, _, _ in PRIMARY_METRICS:
        audit_cols.extend([mid, f"{mid}_rank", f"{mid}_points", f"{mid}_denom"])
    audit_cols.append("total_score")
    out_audit = merged[audit_cols]

    out_audit_path = os.path.join(args.out_dir, "per_metric_ranking.csv")
    out_audit.to_csv(out_audit_path, index=False)

    # Dimension summaries
    dim_summary_parts: List[pd.DataFrame] = []
    for dim in ["loss_type", "age_encoder", "covariate_setting"]:
        g = merged.groupby(dim, dropna=False)
        dim_summary_parts.append(
            g.agg(
                n_models=("model_name", "count"),
                mean_total_score=("total_score", "mean"),
                std_total_score=("total_score", "std"),
                best_total_score=("total_score", "max"),
            ).reset_index()
        )

    dim_summary = pd.concat(
        [df.assign(dimension=name) for df, name in zip(dim_summary_parts, [
            "loss_type", "age_encoder", "covariate_setting"])],
        ignore_index=True,
    )
    out_dim_path = os.path.join(args.out_dir, "dimension_summary.csv")
    dim_summary.to_csv(out_dim_path, index=False)

    # Matched-pair deltas for interpretability
    def _pair_deltas(fixed: Sequence[str], var: str) -> pd.DataFrame:
        cols = list(fixed) + [var, "model_name", "total_score"]
        df = merged[cols].copy()
        pairs: List[Dict[str, object]] = []
        for key, grp in df.groupby(list(fixed), dropna=False):
            if len(grp) < 2:
                continue
            if grp[var].nunique() < 2:
                continue
            # deterministic ordering
            grp = grp.sort_values([var, "model_name"], kind="mergesort")
            base = grp.iloc[0]
            for i in range(1, len(grp)):
                other = grp.iloc[i]
                pairs.append(
                    {
                        **({fixed[j]: key[j] for j in range(len(fixed))} if isinstance(key, tuple) else {fixed[0]: key}),
                        "var_dim": var,
                        "a_var": base[var],
                        "b_var": other[var],
                        "a_model": base["model_name"],
                        "b_model": other["model_name"],
                        "a_score": int(base["total_score"]),
                        "b_score": int(other["total_score"]),
                        "delta_b_minus_a": int(other["total_score"] - base["total_score"]),
                    }
                )
        return pd.DataFrame(pairs)
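
    # Illustration (hypothetical model names): in pairs_covariates.csv each row
    # holds loss_type and age_encoder fixed, so delta_b_minus_a is the
    # total_score gained (or lost) by switching covariate_setting alone, e.g.
    # from a_model="gauss_sin_partcov" to b_model="gauss_sin_fullcov".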

    out_pairs_cov = _pair_deltas(
        ["loss_type", "age_encoder"], "covariate_setting")
    out_pairs_age = _pair_deltas(
        ["loss_type", "covariate_setting"], "age_encoder")
    out_pairs_loss = _pair_deltas(
        ["age_encoder", "covariate_setting"], "loss_type")

    out_pairs_cov.to_csv(os.path.join(
        args.out_dir, "pairs_covariates.csv"), index=False)
    out_pairs_age.to_csv(os.path.join(
        args.out_dir, "pairs_age_encoder.csv"), index=False)
    out_pairs_loss.to_csv(os.path.join(
        args.out_dir, "pairs_loss_type.csv"), index=False)

    # Protocol stamp for auditability
    protocol = {
        "PRIMARY_METRICS": list(PRIMARY_METRICS),
        "EXPECTED_AGE_BINS": list(EXPECTED_AGE_BINS),
        "PRIMARY_CAPTURE_TAUS": list(PRIMARY_CAPTURE_TAUS),
        "PRIMARY_CAPTURE_TOPKS": list(PRIMARY_CAPTURE_TOPKS),
        "PRIMARY_WY_TAU": float(PRIMARY_WY_TAU),
        "PRIMARY_WY_FRACS": list(PRIMARY_WY_FRACS),
        "aggregation_primary": "event-weighted by denom_events / total_events",
        "aggregation_macro": "macro-average over age bins with denom>0",
        "nan_policy": "NaN ranked last (1 point)",
        "tie_break": "model_name ascending (deterministic)",
    }
    pd.Series(protocol).to_json(os.path.join(args.out_dir,
                                "protocol.json"), indent=2, force_ascii=False)

    if warnings:
        with open(os.path.join(args.out_dir, "warnings.txt"), "w", encoding="utf-8") as f:
            for w in warnings:
                f.write(w + "\n")

    print(f"Wrote {out_leaderboard_path}")
    print(f"Wrote {out_audit_path}")
    print(f"Wrote {out_dim_path}")
    if warnings:
        print(f"Wrote {os.path.join(args.out_dir, 'warnings.txt')}")


if __name__ == "__main__":
    main()