Add compare_models.py for model evaluation with primary metrics and age-bin validation

This commit is contained in:
2026-01-17 16:12:56 +08:00
parent fcd948818c
commit e56068e668

427
compare_models.py Normal file
View File

@@ -0,0 +1,427 @@
import argparse
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
from utils import load_train_config
# -------------------------
# Pre-registered panel (DO NOT CHANGE AFTER SEEING RESULTS)
# -------------------------
PRIMARY_CAPTURE_TAUS = (1.0, 2.0, 5.0)
PRIMARY_CAPTURE_TOPKS = (20, 50)
PRIMARY_WY_TAU = 1.0
PRIMARY_WY_FRACS = (0.05, 0.10)
PRIMARY_METRICS = (
("C1", "capture", {"topk": 20, "tau": 1.0}),
("C2", "capture", {"topk": 20, "tau": 2.0}),
("C3", "capture", {"topk": 20, "tau": 5.0}),
("C4", "capture", {"topk": 50, "tau": 1.0}),
("C5", "capture", {"topk": 50, "tau": 2.0}),
("C6", "capture", {"topk": 50, "tau": 5.0}),
("W1", "workload", {"frac": 0.05, "tau": 1.0}),
("W2", "workload", {"frac": 0.10, "tau": 1.0}),
)
EXPECTED_AGE_BINS = (
"[40.0, 45.0)",
"[45.0, 50.0)",
"[50.0, 55.0)",
"[55.0, 60.0)",
"[60.0, 65.0)",
"[65.0, 70.0)",
"[70.0, inf)",
)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Compare trained models using 8 pre-registered primary metrics."
)
p.add_argument(
"--runs_root",
type=str,
default="runs",
help="Directory containing run subfolders (default: runs)",
)
p.add_argument(
"--pattern",
type=str,
default="*",
help="Glob pattern for run folder names under runs_root (default: *)",
)
p.add_argument(
"--run_dirs_file",
type=str,
default="",
help="Optional file with one run_dir per line (overrides runs_root/pattern)",
)
p.add_argument(
"--out_dir",
type=str,
default="comparison",
help="Output directory for tables (default: comparison)",
)
p.add_argument(
"--expect_n",
type=int,
default=12,
help="Expected number of models (default: 12). Set 0 to disable check.",
)
return p.parse_args()
def _discover_run_dirs(runs_root: str, pattern: str) -> List[str]:
if not os.path.isdir(runs_root):
raise FileNotFoundError(f"runs_root not found: {runs_root}")
import glob
out: List[str] = []
for d in glob.glob(os.path.join(runs_root, pattern)):
if not os.path.isdir(d):
continue
if not (os.path.isfile(os.path.join(d, "best_model.pt")) and os.path.isfile(os.path.join(d, "train_config.json"))):
continue
out.append(d)
out.sort()
return out
def _read_run_dirs_file(path: str) -> List[str]:
if not os.path.isfile(path):
raise FileNotFoundError(f"run_dirs_file not found: {path}")
out: List[str] = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
out.append(line)
return out
def _safe_float_eq(a: float, b: float, tol: float = 1e-9) -> bool:
return abs(float(a) - float(b)) <= tol
def _capture_values(df: pd.DataFrame, *, topk: int, tau: float) -> Tuple[float, float, int]:
"""Returns (event_weighted, macro, total_denom_events)."""
sub = df[(df["topk"] == int(topk)) & (
df["tau_years"].astype(float) == float(tau))].copy()
if sub.empty:
return float("nan"), float("nan"), 0
denom = sub["denom_events"].astype(np.int64)
numer = sub["numer_events"].astype(np.int64)
total_denom = int(denom.sum())
if total_denom > 0:
event_weighted = float(numer.sum() / total_denom)
else:
event_weighted = float("nan")
sub_pos = sub[denom > 0]
macro = float(sub_pos["capture_at_k"].astype(
float).mean()) if not sub_pos.empty else float("nan")
return event_weighted, macro, total_denom
def _workload_values(df: pd.DataFrame, *, frac: float, tau: float) -> Tuple[float, float, int]:
"""Returns (event_weighted, macro, total_events)."""
sub = df[(df["tau_years"].astype(float) == float(tau)) & (
df["frac_selected"].astype(float) == float(frac))].copy()
if sub.empty:
return float("nan"), float("nan"), 0
total_events = sub["total_events"].astype(np.int64)
captured = sub["events_captured"].astype(np.int64)
denom = int(total_events.sum())
if denom > 0:
event_weighted = float(captured.sum() / denom)
else:
event_weighted = float("nan")
sub_pos = sub[total_events > 0]
macro = float(sub_pos["capture_rate"].astype(
float).mean()) if not sub_pos.empty else float("nan")
return event_weighted, macro, denom
def _rank_points(values: pd.Series, model_names: pd.Series) -> pd.DataFrame:
"""Ranks models descending by value; NaN ranked last. Deterministic tie-break on model_name."""
df = pd.DataFrame({"model_name": model_names,
"value": values.astype(float)})
df["is_nan"] = ~np.isfinite(df["value"].to_numpy(dtype=np.float64))
df = df.sort_values(["is_nan", "value", "model_name"], ascending=[
True, False, True], kind="mergesort")
df["rank"] = np.arange(1, len(df) + 1, dtype=np.int64)
n = len(df)
df["points"] = (n - df["rank"] + 1).astype(np.int64)
# Force NaNs to last (1 point)
df.loc[df["is_nan"], "points"] = 1
return df[["model_name", "rank", "points"]]
def main() -> None:
args = parse_args()
if args.run_dirs_file:
run_dirs = _read_run_dirs_file(args.run_dirs_file)
else:
run_dirs = _discover_run_dirs(args.runs_root, args.pattern)
if args.expect_n and len(run_dirs) != int(args.expect_n):
raise ValueError(
f"Expected {args.expect_n} runs, found {len(run_dirs)}")
os.makedirs(args.out_dir, exist_ok=True)
rows: List[Dict[str, object]] = []
warnings: List[str] = []
for run_dir in run_dirs:
model_name = os.path.basename(os.path.normpath(run_dir))
cfg = load_train_config(run_dir)
loss_type = str(cfg.get("loss_type", ""))
age_encoder = str(
cfg.get("age_encoder", cfg.get("age_encoder_type", "")))
full_cov = bool(cfg.get("full_cov", False))
covariate_setting = "fullcov" if full_cov else "partcov"
cap_path = os.path.join(run_dir, "horizon_capture.csv")
wy_path = os.path.join(run_dir, "workload_yield.csv")
if not os.path.isfile(cap_path) or not os.path.isfile(wy_path):
warnings.append(
f"Missing evaluation outputs for {model_name}: horizon_capture.csv/workload_yield.csv")
cap_df = pd.DataFrame()
wy_df = pd.DataFrame()
else:
cap_df = pd.read_csv(cap_path)
wy_df = pd.read_csv(wy_path)
# Validate age bins (best-effort). Missing bins are allowed (no eligible records),
# but unexpected bins likely indicate protocol mismatch.
if not cap_df.empty and "age_bin" in cap_df.columns:
seen_bins = sorted(set(str(x)
for x in cap_df["age_bin"].dropna().unique()))
unexpected = [b for b in seen_bins if b not in EXPECTED_AGE_BINS]
if unexpected:
warnings.append(
f"{model_name}: unexpected age_bin labels in horizon_capture.csv: {unexpected}")
out: Dict[str, object] = {
"model_name": model_name,
"run_dir": run_dir,
"loss_type": loss_type,
"age_encoder": age_encoder,
"covariate_setting": covariate_setting,
}
# Primary panel: event-weighted metrics
for metric_id, kind, spec in PRIMARY_METRICS:
if kind == "capture":
v_w, v_macro, denom = _capture_values(
cap_df, topk=spec["topk"], tau=spec["tau"])
out[metric_id] = v_w
out[f"{metric_id}_macro"] = v_macro
out[f"{metric_id}_denom"] = int(denom)
elif kind == "workload":
v_w, v_macro, denom = _workload_values(
wy_df, frac=spec["frac"], tau=spec["tau"])
out[metric_id] = v_w
out[f"{metric_id}_macro"] = v_macro
out[f"{metric_id}_denom"] = int(denom)
else:
raise ValueError(f"Unknown metric kind: {kind}")
rows.append(out)
metrics_df = pd.DataFrame(rows)
# Scoring: ranks + points for the 8 primary metrics (event-weighted)
per_metric_rank_parts: List[pd.DataFrame] = []
total_points = np.zeros((len(metrics_df),), dtype=np.int64)
for metric_id, _, _ in PRIMARY_METRICS:
r = _rank_points(metrics_df[metric_id], metrics_df["model_name"])
r = r.rename(columns={"rank": f"{metric_id}_rank",
"points": f"{metric_id}_points"})
per_metric_rank_parts.append(r)
ranks_df = per_metric_rank_parts[0]
for part in per_metric_rank_parts[1:]:
ranks_df = ranks_df.merge(part, on="model_name", how="left")
# Merge ranks back
merged = metrics_df.merge(ranks_df, on="model_name", how="left")
point_cols = [f"{mid}_points" for mid, _, _ in PRIMARY_METRICS]
merged["total_score"] = merged[point_cols].sum(axis=1).astype(np.int64)
# Macro robustness score (not used for ranking, but used to flag instability)
macro_point_parts: List[pd.DataFrame] = []
for metric_id, _, _ in PRIMARY_METRICS:
r = _rank_points(merged[f"{metric_id}_macro"], merged["model_name"])
r = r.rename(columns={"rank": f"{metric_id}_macro_rank",
"points": f"{metric_id}_macro_points"})
macro_point_parts.append(r)
macro_ranks_df = macro_point_parts[0]
for part in macro_point_parts[1:]:
macro_ranks_df = macro_ranks_df.merge(part, on="model_name", how="left")
merged = merged.merge(macro_ranks_df, on="model_name", how="left")
macro_point_cols = [f"{mid}_macro_points" for mid, _, _ in PRIMARY_METRICS]
merged["macro_total_score"] = merged[macro_point_cols].sum(
axis=1).astype(np.int64)
# Final leaderboard ranks
merged = merged.sort_values(["total_score", "model_name"], ascending=[
False, True], kind="mergesort")
merged["leaderboard_rank"] = np.arange(1, len(merged) + 1, dtype=np.int64)
merged_macro = merged.sort_values(["macro_total_score", "model_name"], ascending=[
False, True], kind="mergesort")
macro_rank_map = {
name: int(i + 1) for i, name in enumerate(merged_macro["model_name"].tolist())}
merged["macro_leaderboard_rank"] = merged["model_name"].map(
macro_rank_map).astype(np.int64)
merged["rank_delta_macro_minus_primary"] = (
merged["macro_leaderboard_rank"] - merged["leaderboard_rank"]).astype(np.int64)
merged["macro_collapse_flag"] = merged["rank_delta_macro_minus_primary"] >= 4
# Output tables
leaderboard_cols = [
"leaderboard_rank",
"model_name",
"loss_type",
"age_encoder",
"covariate_setting",
"total_score",
"macro_total_score",
"macro_leaderboard_rank",
"rank_delta_macro_minus_primary",
"macro_collapse_flag",
]
for mid, _, _ in PRIMARY_METRICS:
leaderboard_cols.append(mid)
out_leaderboard = merged[leaderboard_cols]
out_leaderboard_path = os.path.join(args.out_dir, "leaderboard.csv")
out_leaderboard.to_csv(out_leaderboard_path, index=False)
# Audit table: per-metric ranks/points + values
audit_cols = ["model_name"]
for mid, _, _ in PRIMARY_METRICS:
audit_cols.extend([mid, f"{mid}_rank", f"{mid}_points", f"{mid}_denom"])
audit_cols.append("total_score")
out_audit = merged[audit_cols]
out_audit_path = os.path.join(args.out_dir, "per_metric_ranking.csv")
out_audit.to_csv(out_audit_path, index=False)
# Dimension summaries
dim_summary_parts: List[pd.DataFrame] = []
for dim in ["loss_type", "age_encoder", "covariate_setting"]:
g = merged.groupby(dim, dropna=False)
dim_summary_parts.append(
g.agg(
n_models=("model_name", "count"),
mean_total_score=("total_score", "mean"),
std_total_score=("total_score", "std"),
best_total_score=("total_score", "max"),
).reset_index()
)
dim_summary = pd.concat(
[df.assign(dimension=name) for df, name in zip(dim_summary_parts, [
"loss_type", "age_encoder", "covariate_setting"])],
ignore_index=True,
)
out_dim_path = os.path.join(args.out_dir, "dimension_summary.csv")
dim_summary.to_csv(out_dim_path, index=False)
# Matched-pair deltas for interpretability
def _pair_deltas(fixed: Sequence[str], var: str) -> pd.DataFrame:
cols = list(fixed) + [var, "model_name", "total_score"]
df = merged[cols].copy()
pairs: List[Dict[str, object]] = []
for key, grp in df.groupby(list(fixed), dropna=False):
if len(grp) < 2:
continue
if grp[var].nunique() < 2:
continue
# deterministic ordering
grp = grp.sort_values([var, "model_name"], kind="mergesort")
base = grp.iloc[0]
for i in range(1, len(grp)):
other = grp.iloc[i]
pairs.append(
{
**({fixed[j]: key[j] for j in range(len(fixed))} if isinstance(key, tuple) else {fixed[0]: key}),
"var_dim": var,
"a_var": base[var],
"b_var": other[var],
"a_model": base["model_name"],
"b_model": other["model_name"],
"a_score": int(base["total_score"]),
"b_score": int(other["total_score"]),
"delta_b_minus_a": int(other["total_score"] - base["total_score"]),
}
)
return pd.DataFrame(pairs)
out_pairs_cov = _pair_deltas(
["loss_type", "age_encoder"], "covariate_setting")
out_pairs_age = _pair_deltas(
["loss_type", "covariate_setting"], "age_encoder")
out_pairs_loss = _pair_deltas(
["age_encoder", "covariate_setting"], "loss_type")
out_pairs_cov.to_csv(os.path.join(
args.out_dir, "pairs_covariates.csv"), index=False)
out_pairs_age.to_csv(os.path.join(
args.out_dir, "pairs_age_encoder.csv"), index=False)
out_pairs_loss.to_csv(os.path.join(
args.out_dir, "pairs_loss_type.csv"), index=False)
# Protocol stamp for auditability
protocol = {
"PRIMARY_METRICS": list(PRIMARY_METRICS),
"EXPECTED_AGE_BINS": list(EXPECTED_AGE_BINS),
"PRIMARY_CAPTURE_TAUS": list(PRIMARY_CAPTURE_TAUS),
"PRIMARY_CAPTURE_TOPKS": list(PRIMARY_CAPTURE_TOPKS),
"PRIMARY_WY_TAU": float(PRIMARY_WY_TAU),
"PRIMARY_WY_FRACS": list(PRIMARY_WY_FRACS),
"aggregation_primary": "event-weighted by denom_events / total_events",
"aggregation_macro": "macro-average over age bins with denom>0",
"nan_policy": "NaN ranked last (1 point)",
"tie_break": "model_name ascending (deterministic)",
}
pd.Series(protocol).to_json(os.path.join(args.out_dir,
"protocol.json"), indent=2, force_ascii=False)
if warnings:
with open(os.path.join(args.out_dir, "warnings.txt"), "w", encoding="utf-8") as f:
for w in warnings:
f.write(w + "\n")
print(f"Wrote {out_leaderboard_path}")
print(f"Wrote {out_audit_path}")
print(f"Wrote {out_dim_path}")
if warnings:
print(f"Wrote {os.path.join(args.out_dir, 'warnings.txt')}")
if __name__ == "__main__":
main()