# DeepHealth/compare_models.py
import argparse
import glob
import json
import os
from typing import Dict, List, Sequence, Tuple

import numpy as np
import pandas as pd

from utils import load_train_config

# -------------------------
# Pre-registered panel (DO NOT CHANGE AFTER SEEING RESULTS)
# -------------------------
PRIMARY_CAPTURE_TAUS = (1.0, 2.0, 5.0)
PRIMARY_CAPTURE_TOPKS = (20, 50)
PRIMARY_WY_TAU = 1.0
PRIMARY_WY_FRACS = (0.05, 0.10)

PRIMARY_METRICS = (
    ("C1", "capture", {"topk": 20, "tau": 1.0}),
    ("C2", "capture", {"topk": 20, "tau": 2.0}),
    ("C3", "capture", {"topk": 20, "tau": 5.0}),
    ("C4", "capture", {"topk": 50, "tau": 1.0}),
    ("C5", "capture", {"topk": 50, "tau": 2.0}),
    ("C6", "capture", {"topk": 50, "tau": 5.0}),
    ("W1", "workload", {"frac": 0.05, "tau": 1.0}),
    ("W2", "workload", {"frac": 0.10, "tau": 1.0}),
)
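
# Scoring sketch (illustrative numbers only): each of the 8 metrics ranks the
# n models by descending value and awards points n..1, so with n = 12 a model
# that ranks 1st on every metric totals 8 * 12 = 96 points, while a model that
# is last (or NaN) everywhere totals 8 * 1 = 8. See _rank_points() below.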

EXPECTED_AGE_BINS = (
    "[40.0, 45.0)",
    "[45.0, 50.0)",
    "[50.0, 55.0)",
    "[55.0, 60.0)",
    "[60.0, 65.0)",
    "[65.0, 70.0)",
    "[70.0, inf)",
)
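# Half-open 5-year age bins starting at 40; the last bin is unbounded above.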


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Compare trained models using 8 pre-registered primary metrics."
    )
    p.add_argument(
        "--runs_root",
        type=str,
        default="runs",
        help="Directory containing run subfolders (default: runs)",
    )
    p.add_argument(
        "--pattern",
        type=str,
        default="*",
        help="Glob pattern for run folder names under runs_root (default: *)",
    )
    p.add_argument(
        "--run_dirs_file",
        type=str,
        default="",
        help="Optional file with one run_dir per line (overrides runs_root/pattern)",
    )
    p.add_argument(
        "--out_dir",
        type=str,
        default="comparison",
        help="Output directory for tables (default: comparison)",
    )
    p.add_argument(
        "--expect_n",
        type=int,
        default=12,
        help="Expected number of models (default: 12). Set 0 to disable check.",
    )
    return p.parse_args()
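
# Example invocations (illustrative paths; assumes the default run layout):
#   python compare_models.py --runs_root runs --pattern "exp_*" --out_dir comparison
#   python compare_models.py --run_dirs_file my_runs.txt --expect_n 0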


def _discover_run_dirs(runs_root: str, pattern: str) -> List[str]:
    if not os.path.isdir(runs_root):
        raise FileNotFoundError(f"runs_root not found: {runs_root}")
    out: List[str] = []
    for d in glob.glob(os.path.join(runs_root, pattern)):
        if not os.path.isdir(d):
            continue
        # Keep only folders that look like completed runs.
        if not (
            os.path.isfile(os.path.join(d, "best_model.pt"))
            and os.path.isfile(os.path.join(d, "train_config.json"))
        ):
            continue
        out.append(d)
    out.sort()
    return out


def _read_run_dirs_file(path: str) -> List[str]:
    if not os.path.isfile(path):
        raise FileNotFoundError(f"run_dirs_file not found: {path}")
    out: List[str] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            out.append(line)
    return out


def _safe_float_eq(a: pd.Series, b: float, tol: float = 1e-9) -> np.ndarray:
    """Vectorized |a - b| <= tol match; guards against float round-trip noise
    when tau/frac values are read back from CSV."""
    return np.isclose(a.astype(float).to_numpy(dtype=np.float64), float(b),
                      rtol=0.0, atol=tol)


def _capture_values(df: pd.DataFrame, *, topk: int, tau: float) -> Tuple[float, float, int]:
    """Returns (event_weighted, macro, total_denom_events)."""
    if df.empty:
        # Evaluation outputs were missing for this run; metrics stay NaN.
        return float("nan"), float("nan"), 0
    sub = df[(df["topk"] == int(topk)) & _safe_float_eq(df["tau_years"], tau)].copy()
    if sub.empty:
        return float("nan"), float("nan"), 0
    denom = sub["denom_events"].astype(np.int64)
    numer = sub["numer_events"].astype(np.int64)
    total_denom = int(denom.sum())
    if total_denom > 0:
        event_weighted = float(numer.sum() / total_denom)
    else:
        event_weighted = float("nan")
    # Macro average over the age bins that actually contain events.
    sub_pos = sub[denom > 0]
    macro = float(sub_pos["capture_at_k"].astype(float).mean()) if not sub_pos.empty else float("nan")
    return event_weighted, macro, total_denom
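
# Assumed horizon_capture.csv layout (one row per (age_bin, topk, tau_years)
# cell; inferred from the column accesses above, with the upstream evaluation
# script as the source of truth):
#   age_bin, topk, tau_years, numer_events, denom_events, capture_at_k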


def _workload_values(df: pd.DataFrame, *, frac: float, tau: float) -> Tuple[float, float, int]:
    """Returns (event_weighted, macro, total_events)."""
    if df.empty:
        # Evaluation outputs were missing for this run; metrics stay NaN.
        return float("nan"), float("nan"), 0
    sub = df[_safe_float_eq(df["tau_years"], tau)
             & _safe_float_eq(df["frac_selected"], frac)].copy()
    if sub.empty:
        return float("nan"), float("nan"), 0
    total_events = sub["total_events"].astype(np.int64)
    captured = sub["events_captured"].astype(np.int64)
    denom = int(total_events.sum())
    if denom > 0:
        event_weighted = float(captured.sum() / denom)
    else:
        event_weighted = float("nan")
    # Macro average over the age bins that actually contain events.
    sub_pos = sub[total_events > 0]
    macro = float(sub_pos["capture_rate"].astype(float).mean()) if not sub_pos.empty else float("nan")
    return event_weighted, macro, denom
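
# Assumed workload_yield.csv layout (one row per stratum, presumably per age
# bin given the macro aggregation; same caveat as above):
#   tau_years, frac_selected, events_captured, total_events, capture_rate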


def _rank_points(values: pd.Series, model_names: pd.Series) -> pd.DataFrame:
    """Ranks models descending by value; NaN ranked last. Deterministic tie-break on model_name."""
    df = pd.DataFrame({"model_name": model_names, "value": values.astype(float)})
    df["is_nan"] = ~np.isfinite(df["value"].to_numpy(dtype=np.float64))
    df = df.sort_values(["is_nan", "value", "model_name"],
                        ascending=[True, False, True], kind="mergesort")
    df["rank"] = np.arange(1, len(df) + 1, dtype=np.int64)
    n = len(df)
    df["points"] = (n - df["rank"] + 1).astype(np.int64)
    # Force NaNs to last (1 point)
    df.loc[df["is_nan"], "points"] = 1
    return df[["model_name", "rank", "points"]]
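
# Toy example (illustrative values): for A=0.50, B=0.70, C=NaN the ranking is
# B (rank 1, 3 points), A (rank 2, 2 points), C (rank 3, 1 point); an all-NaN
# metric therefore awards every model the same single point.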


def main() -> None:
    args = parse_args()
    if args.run_dirs_file:
        run_dirs = _read_run_dirs_file(args.run_dirs_file)
    else:
        run_dirs = _discover_run_dirs(args.runs_root, args.pattern)
    if args.expect_n and len(run_dirs) != int(args.expect_n):
        raise ValueError(f"Expected {args.expect_n} runs, found {len(run_dirs)}")
    os.makedirs(args.out_dir, exist_ok=True)

    rows: List[Dict[str, object]] = []
    warnings: List[str] = []
    for run_dir in run_dirs:
        model_name = os.path.basename(os.path.normpath(run_dir))
        cfg = load_train_config(run_dir)
        loss_type = str(cfg.get("loss_type", ""))
        age_encoder = str(cfg.get("age_encoder", cfg.get("age_encoder_type", "")))
        full_cov = bool(cfg.get("full_cov", False))
        covariate_setting = "fullcov" if full_cov else "partcov"

        cap_path = os.path.join(run_dir, "horizon_capture.csv")
        wy_path = os.path.join(run_dir, "workload_yield.csv")
        if not os.path.isfile(cap_path) or not os.path.isfile(wy_path):
            warnings.append(
                f"Missing evaluation outputs for {model_name}: "
                "horizon_capture.csv/workload_yield.csv")
            cap_df = pd.DataFrame()
            wy_df = pd.DataFrame()
        else:
            cap_df = pd.read_csv(cap_path)
            wy_df = pd.read_csv(wy_path)

        # Validate age bins (best-effort). Missing bins are allowed (no eligible
        # records), but unexpected bins likely indicate a protocol mismatch.
        if not cap_df.empty and "age_bin" in cap_df.columns:
            seen_bins = sorted(set(str(x) for x in cap_df["age_bin"].dropna().unique()))
            unexpected = [b for b in seen_bins if b not in EXPECTED_AGE_BINS]
            if unexpected:
                warnings.append(
                    f"{model_name}: unexpected age_bin labels in horizon_capture.csv: {unexpected}")

        out: Dict[str, object] = {
            "model_name": model_name,
            "run_dir": run_dir,
            "loss_type": loss_type,
            "age_encoder": age_encoder,
            "covariate_setting": covariate_setting,
        }
        # Primary panel: event-weighted metrics.
        for metric_id, kind, spec in PRIMARY_METRICS:
            if kind == "capture":
                v_w, v_macro, denom = _capture_values(cap_df, topk=spec["topk"], tau=spec["tau"])
            elif kind == "workload":
                v_w, v_macro, denom = _workload_values(wy_df, frac=spec["frac"], tau=spec["tau"])
            else:
                raise ValueError(f"Unknown metric kind: {kind}")
            out[metric_id] = v_w
            out[f"{metric_id}_macro"] = v_macro
            out[f"{metric_id}_denom"] = int(denom)
        rows.append(out)

    metrics_df = pd.DataFrame(rows)

    # Scoring: ranks + points for the 8 primary metrics (event-weighted).
    per_metric_rank_parts: List[pd.DataFrame] = []
    for metric_id, _, _ in PRIMARY_METRICS:
        r = _rank_points(metrics_df[metric_id], metrics_df["model_name"])
        r = r.rename(columns={"rank": f"{metric_id}_rank",
                              "points": f"{metric_id}_points"})
        per_metric_rank_parts.append(r)
    ranks_df = per_metric_rank_parts[0]
    for part in per_metric_rank_parts[1:]:
        ranks_df = ranks_df.merge(part, on="model_name", how="left")

    # Merge ranks back onto the metric values.
    merged = metrics_df.merge(ranks_df, on="model_name", how="left")
    point_cols = [f"{mid}_points" for mid, _, _ in PRIMARY_METRICS]
    merged["total_score"] = merged[point_cols].sum(axis=1).astype(np.int64)

    # Macro robustness score (not used for ranking, but used to flag instability).
    macro_point_parts: List[pd.DataFrame] = []
    for metric_id, _, _ in PRIMARY_METRICS:
        r = _rank_points(merged[f"{metric_id}_macro"], merged["model_name"])
        r = r.rename(columns={"rank": f"{metric_id}_macro_rank",
                              "points": f"{metric_id}_macro_points"})
        macro_point_parts.append(r)
    macro_ranks_df = macro_point_parts[0]
    for part in macro_point_parts[1:]:
        macro_ranks_df = macro_ranks_df.merge(part, on="model_name", how="left")
    merged = merged.merge(macro_ranks_df, on="model_name", how="left")
    macro_point_cols = [f"{mid}_macro_points" for mid, _, _ in PRIMARY_METRICS]
    merged["macro_total_score"] = merged[macro_point_cols].sum(axis=1).astype(np.int64)

    # Final leaderboard ranks.
    merged = merged.sort_values(["total_score", "model_name"],
                                ascending=[False, True], kind="mergesort")
    merged["leaderboard_rank"] = np.arange(1, len(merged) + 1, dtype=np.int64)
    merged_macro = merged.sort_values(["macro_total_score", "model_name"],
                                      ascending=[False, True], kind="mergesort")
    macro_rank_map = {name: int(i + 1)
                      for i, name in enumerate(merged_macro["model_name"].tolist())}
    merged["macro_leaderboard_rank"] = merged["model_name"].map(macro_rank_map).astype(np.int64)
    merged["rank_delta_macro_minus_primary"] = (
        merged["macro_leaderboard_rank"] - merged["leaderboard_rank"]).astype(np.int64)
    merged["macro_collapse_flag"] = merged["rank_delta_macro_minus_primary"] >= 4

    # Output tables.
    leaderboard_cols = [
        "leaderboard_rank",
        "model_name",
        "loss_type",
        "age_encoder",
        "covariate_setting",
        "total_score",
        "macro_total_score",
        "macro_leaderboard_rank",
        "rank_delta_macro_minus_primary",
        "macro_collapse_flag",
    ]
    leaderboard_cols.extend(mid for mid, _, _ in PRIMARY_METRICS)
    out_leaderboard = merged[leaderboard_cols]
    out_leaderboard_path = os.path.join(args.out_dir, "leaderboard.csv")
    out_leaderboard.to_csv(out_leaderboard_path, index=False)

    # Audit table: per-metric ranks/points + values.
    audit_cols = ["model_name"]
    for mid, _, _ in PRIMARY_METRICS:
        audit_cols.extend([mid, f"{mid}_rank", f"{mid}_points", f"{mid}_denom"])
    audit_cols.append("total_score")
    out_audit = merged[audit_cols]
    out_audit_path = os.path.join(args.out_dir, "per_metric_ranking.csv")
    out_audit.to_csv(out_audit_path, index=False)

    # Dimension summaries.
    dim_summary_parts: List[pd.DataFrame] = []
    for dim in ["loss_type", "age_encoder", "covariate_setting"]:
        g = merged.groupby(dim, dropna=False)
        dim_summary_parts.append(
            g.agg(
                n_models=("model_name", "count"),
                mean_total_score=("total_score", "mean"),
                std_total_score=("total_score", "std"),
                best_total_score=("total_score", "max"),
            ).reset_index().assign(dimension=dim)
        )
    dim_summary = pd.concat(dim_summary_parts, ignore_index=True)
    out_dim_path = os.path.join(args.out_dir, "dimension_summary.csv")
    dim_summary.to_csv(out_dim_path, index=False)

    # Matched-pair deltas for interpretability.
    def _pair_deltas(fixed: Sequence[str], var: str) -> pd.DataFrame:
        cols = list(fixed) + [var, "model_name", "total_score"]
        df = merged[cols].copy()
        pairs: List[Dict[str, object]] = []
        for key, grp in df.groupby(list(fixed), dropna=False):
            if len(grp) < 2:
                continue
            if grp[var].nunique() < 2:
                continue
            # Deterministic ordering: the first row becomes the baseline "a".
            grp = grp.sort_values([var, "model_name"], kind="mergesort")
            base = grp.iloc[0]
            for i in range(1, len(grp)):
                other = grp.iloc[i]
                pairs.append(
                    {
                        **({fixed[j]: key[j] for j in range(len(fixed))}
                           if isinstance(key, tuple) else {fixed[0]: key}),
                        "var_dim": var,
                        "a_var": base[var],
                        "b_var": other[var],
                        "a_model": base["model_name"],
                        "b_model": other["model_name"],
                        "a_score": int(base["total_score"]),
                        "b_score": int(other["total_score"]),
                        "delta_b_minus_a": int(other["total_score"] - base["total_score"]),
                    }
                )
        return pd.DataFrame(pairs)
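
    # Example reading (illustrative numbers): a pairs_covariates.csv row with
    # a_var="partcov", b_var="fullcov", delta_b_minus_a=7 says the full-covariate
    # model outscored its matched partial-covariate twin (same loss_type and
    # age_encoder) by 7 total points.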
    out_pairs_cov = _pair_deltas(["loss_type", "age_encoder"], "covariate_setting")
    out_pairs_age = _pair_deltas(["loss_type", "covariate_setting"], "age_encoder")
    out_pairs_loss = _pair_deltas(["age_encoder", "covariate_setting"], "loss_type")
    out_pairs_cov.to_csv(os.path.join(args.out_dir, "pairs_covariates.csv"), index=False)
    out_pairs_age.to_csv(os.path.join(args.out_dir, "pairs_age_encoder.csv"), index=False)
    out_pairs_loss.to_csv(os.path.join(args.out_dir, "pairs_loss_type.csv"), index=False)

    # Protocol stamp for auditability.
    protocol = {
        "PRIMARY_METRICS": list(PRIMARY_METRICS),
        "EXPECTED_AGE_BINS": list(EXPECTED_AGE_BINS),
        "PRIMARY_CAPTURE_TAUS": list(PRIMARY_CAPTURE_TAUS),
        "PRIMARY_CAPTURE_TOPKS": list(PRIMARY_CAPTURE_TOPKS),
        "PRIMARY_WY_TAU": float(PRIMARY_WY_TAU),
        "PRIMARY_WY_FRACS": list(PRIMARY_WY_FRACS),
        "aggregation_primary": "event-weighted by denom_events / total_events",
        "aggregation_macro": "macro-average over age bins with denom>0",
        "nan_policy": "NaN ranked last (1 point)",
        "tie_break": "model_name ascending (deterministic)",
    }
    # Plain json.dump keeps the stamp a valid JSON object (tuples serialize as
    # arrays); no pandas detour needed.
    with open(os.path.join(args.out_dir, "protocol.json"), "w", encoding="utf-8") as f:
        json.dump(protocol, f, indent=2, ensure_ascii=False)

    if warnings:
        with open(os.path.join(args.out_dir, "warnings.txt"), "w", encoding="utf-8") as f:
            for w in warnings:
                f.write(w + "\n")

    print(f"Wrote {out_leaderboard_path}")
    print(f"Wrote {out_audit_path}")
    print(f"Wrote {out_dim_path}")
    if warnings:
        print(f"Wrote {os.path.join(args.out_dir, 'warnings.txt')}")


if __name__ == "__main__":
    main()