diff --git a/evaluate_models.py b/evaluate_models.py
index 8a1ff36..83b783c 100644
--- a/evaluate_models.py
+++ b/evaluate_models.py
@@ -5,6 +5,9 @@ import math
 import os
 import random
 import statistics
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
 
@@ -522,20 +525,32 @@ def check_cif_integrity(
 # --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
 def compute_midrank(x: np.ndarray) -> np.ndarray:
+    """Vectorized midrank computation (ties -> average ranks)."""
     x = np.asarray(x, dtype=float)
-    order = np.argsort(x)
+    n = int(x.shape[0])
+    if n == 0:
+        return np.asarray([], dtype=float)
+
+    order = np.argsort(x, kind="mergesort")
     z = x[order]
-    n = x.shape[0]
-    t = np.zeros(n, dtype=float)
-    i = 0
-    while i < n:
-        j = i
-        while j < n and z[j] == z[i]:
-            j += 1
-        t[i:j] = 0.5 * (i + j - 1) + 1.0
-        i = j
+
+    # Find tie groups in sorted order.
+    diff = np.diff(z)
+    # boundaries includes 0 and n
+    boundaries = np.concatenate(
+        [np.array([0], dtype=int), np.nonzero(diff != 0)[0] + 1, np.array([n], dtype=int)]
+    )
+    starts = boundaries[:-1]
+    ends = boundaries[1:]
+    lens = ends - starts
+
+    # Midrank for each group in 1-based rank space.
+    mids = 0.5 * (starts + ends - 1) + 1.0
+    t_sorted = np.repeat(mids, lens).astype(float, copy=False)
+
     out = np.empty(n, dtype=float)
-    out[order] = t
+    out[order] = t_sorted
     return out
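+
+# Sanity sketch (illustrative, not exercised by tests): ties share the
+# average of the 1-based ranks they span, e.g.
+#   compute_midrank(np.array([0.2, 0.5, 0.2]))  -> array([1.5, 3.0, 1.5])
+#   compute_midrank(np.array([1.0, 2.0, 3.0]))  -> array([1.0, 2.0, 3.0])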
@@ -678,34 +693,98 @@ def brier_score(p: np.ndarray, y: np.ndarray) -> float:
 
 def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
+    bins, ici = _calibration_bins_and_ici(
+        p, y, n_bins=int(n_bins), return_bins=True)
+    return {"bins": bins, "ici": float(ici)}
+
+
+def calibration_ici_only(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float:
+    """Fast ICI only (no per-bin point export)."""
+    _, ici = _calibration_bins_and_ici(
+        p, y, n_bins=int(n_bins), return_bins=False)
+    return float(ici)
+
+
+def _calibration_bins_and_ici(
+    p: np.ndarray,
+    y: np.ndarray,
+    *,
+    n_bins: int,
+    return_bins: bool,
+) -> Tuple[List[Dict[str, Any]], float]:
+    """Vectorized quantile binning for calibration + ICI."""
     p = np.asarray(p, dtype=float)
     y = np.asarray(y, dtype=float)
-
-    # guard
     if p.size == 0:
-        return {"bins": [], "ici": float("nan")}
+        return [], float("nan")
 
-    edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1))
-    # make strictly increasing where possible
+    q = np.linspace(0.0, 1.0, int(n_bins) + 1)
+    edges = np.quantile(p, q)
+    edges = np.asarray(edges, dtype=float)
     edges[0] = -np.inf
     edges[-1] = np.inf
 
-    bins = []
-    ici_accum = 0.0
-    n = p.shape[0]
+    # Bin assignment: i if edges[i] <= p < edges[i+1] (side="right" puts
+    # values equal to an edge into the bin above that edge).
+    bin_idx = np.searchsorted(edges, p, side="right") - 1
+    bin_idx = np.clip(bin_idx, 0, int(n_bins) - 1)
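+    # Assignment sketch (n_bins=2, illustrative): with edges = [-inf, 0.5, inf]
+    # and p = [0.2, 0.5, 0.9], the two lines above yield bin_idx = [0, 1, 1].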
-    for i in range(n_bins):
-        mask = (p > edges[i]) & (p <= edges[i + 1])
-        if not np.any(mask):
-            continue
-        p_mean = float(np.mean(p[mask]))
-        y_mean = float(np.mean(y[mask]))
-        bins.append({"bin": i, "p_mean": p_mean,
-                     "y_mean": y_mean, "n": int(mask.sum())})
-        ici_accum += abs(p_mean - y_mean)
+    counts = np.bincount(bin_idx, minlength=int(n_bins)).astype(float)
+    sum_p = np.bincount(bin_idx, weights=p, minlength=int(n_bins)).astype(float)
+    sum_y = np.bincount(bin_idx, weights=y, minlength=int(n_bins)).astype(float)
 
-    ici = ici_accum / max(len(bins), 1)
-    return {"bins": bins, "ici": float(ici)}
+    nonempty = counts > 0
+    if not np.any(nonempty):
+        return [], float("nan")
+
+    p_mean = np.zeros(int(n_bins), dtype=float)
+    y_mean = np.zeros(int(n_bins), dtype=float)
+    p_mean[nonempty] = sum_p[nonempty] / counts[nonempty]
+    y_mean[nonempty] = sum_y[nonempty] / counts[nonempty]
+
+    diffs = np.abs(p_mean[nonempty] - y_mean[nonempty])
+    ici = float(np.mean(diffs)) if diffs.size else float("nan")
+
+    if not return_bins:
+        return [], ici
+
+    bins: List[Dict[str, Any]] = []
+    idxs = np.nonzero(nonempty)[0]
+    for i in idxs.tolist():
+        bins.append(
+            {
+                "bin": int(i),
+                "p_mean": float(p_mean[i]),
+                "y_mean": float(y_mean[i]),
+                "n": int(counts[i]),
+            }
+        )
+    return bins, ici
+
+
+def _progress_line(done: int, total: int, prefix: str = "") -> str:
+    total_i = max(1, int(total))
+    done_i = max(0, min(int(done), total_i))
+    width = 28
+    frac = done_i / total_i
+    filled = int(round(width * frac))
+    bar = "#" * filled + "-" * (width - filled)
+    pct = 100.0 * frac
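+    # Rendering sketch (illustrative): _progress_line(7, 10, prefix="[m] ")
+    # -> "[m] [####################--------] 7/10 ( 70.0%)"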
+    return f"{prefix}[{bar}] {done_i}/{total_i} ({pct:5.1f}%)"
+
+
+def _should_show_progress(mode: str) -> bool:
+    m = str(mode).strip().lower()
+    if m in {"0", "false", "no", "none", "off"}:
+        return False
+    # Default: show if interactive.
+    if m in {"auto", "1", "true", "yes", "on", "bar"}:
+        try:
+            return bool(sys.stdout.isatty())
+        except Exception:
+            return True
+    return True
 
 
 def _safe_float(x: Any, default: float = float("nan")) -> float:
@@ -1116,15 +1195,14 @@ def predict_cifs_for_model(
     device: str,
     offset_years: float,
     eval_horizons: Sequence[float],
-    top_cause_ids: np.ndarray,
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    n_disease: int,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     """Run model and produce cause-specific, time-dependent CIF outputs.
 
     Returns:
-      cause_cif: (N, topK, H)
       cif_full: (N, K, H)
       survival: (N, H)
-      y_cause_within_tau: (N, topK, H)
+      y_cause_within_tau: (N, K, H)
+      sex: (N,)
 
     NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
     """
@@ -1132,13 +1210,10 @@ def predict_cifs_for_model(
     head.eval()
 
     # We will accumulate in CPU lists, then concat.
-    cause_cif_list: List[np.ndarray] = []
     cif_full_list: List[np.ndarray] = []
     survival_list: List[np.ndarray] = []
     y_cause_within_list: List[np.ndarray] = []
     sex_list: List[np.ndarray] = []
-    top_cause_ids_t = torch.tensor(
-        top_cause_ids, dtype=torch.long, device=device)
 
     for batch in loader:
         event_seq, time_seq, cont_feats, cate_feats, sexes = batch
@@ -1177,95 +1252,124 @@ def predict_cifs_for_model(
         else:
             raise ValueError(f"Unsupported loss_type: {loss_type}")
 
-        cause_cif = cif_full.index_select(
-            dim=1, index=top_cause_ids_t)  # (B,topK,H)
-
-        # Within-horizon labels for cause-specific CIF quality + discrimination.
-        n_disease = int(cif_full.size(1))
-        y_within_top = torch.stack(
+        # Within-horizon labels for all causes: disease k occurs within tau after context.
+        y_within_full = torch.stack(
             [
-                multi_hot_selected_causes_within_horizon(
+                multi_hot_ever_within_horizon(
                     event_seq=event_seq,
                     time_seq=time_seq,
                     t_ctx=t_ctx,
                     tau_years=float(tau),
-                    cause_ids=top_cause_ids_t,
-                    n_disease=n_disease,
+                    n_disease=int(n_disease),
                 ).to(torch.float32)
                 for tau in eval_horizons
             ],
             dim=2,
-        )  # (B,topK,H)
+        )  # (B,K,H)
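+        # Assumed semantics (sketch): y_within_full[b, k, h] == 1.0 iff subject
+        # b develops disease k within eval_horizons[h] years after the context
+        # time t_ctx (per multi_hot_ever_within_horizon above).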
 
-        cause_cif_list.append(cause_cif.detach().cpu().numpy())
         cif_full_list.append(cif_full.detach().cpu().numpy())
         survival_list.append(survival.detach().cpu().numpy())
-        y_cause_within_list.append(y_within_top.detach().cpu().numpy())
+        y_cause_within_list.append(y_within_full.detach().cpu().numpy())
         sex_list.append(sexes_k.detach().cpu().numpy())
 
-    if not cause_cif_list:
+    if not cif_full_list:
         raise RuntimeError(
             "No valid samples for evaluation (all batches filtered out by offset).")
 
-    cause_cif = np.concatenate(cause_cif_list, axis=0)
     cif_full = np.concatenate(cif_full_list, axis=0)
     survival = np.concatenate(survival_list, axis=0)
     y_cause_within = np.concatenate(y_cause_within_list, axis=0)
     sex = np.concatenate(
         sex_list, axis=0) if sex_list else np.array([], dtype=int)
 
-    return cause_cif, cif_full, survival, y_cause_within, sex
-
-
-def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
-    counts = y_ever.sum(axis=0)
-    order = np.argsort(-counts)
-    order = order[counts[order] > 0]
-    return order[:top_k]
+    return cif_full, survival, y_cause_within, sex
 
 
 def evaluate_one_model(
     model_name: str,
-    cause_cif: np.ndarray,
+    cif_full: np.ndarray,
     y_cause_within_tau: np.ndarray,
     eval_horizons: Sequence[float],
-    top_cause_ids: np.ndarray,
     out_rows: List[Dict[str, Any]],
     calib_rows: List[Dict[str, Any]],
+    calib_cause_ids: Optional[Sequence[int]],
     auc_ci_method: str,
     bootstrap_n: int,
    n_calib_bins: int = 10,
+    metric_workers: int = 0,
+    progress: str = "auto",
 ) -> None:
-    # Cause-specific, time-dependent metrics per horizon.
-    for h_i, tau in enumerate(eval_horizons):
-        p_tau = cause_cif[:, :, h_i]  # (N, topK)
-        y_tau = y_cause_within_tau[:, :, h_i]  # (N, topK)
+    """Compute per-cause metrics for ALL diseases.
 
-        for j, cause_id in enumerate(top_cause_ids.tolist()):
-            p = p_tau[:, j]
-            y = y_tau[:, j]
+    Notes:
+      - Writes scalar metrics for all causes into out_rows.
+      - Writes calibration-bin points only for calib_cause_ids (to keep outputs tractable).
+    """
+    cif_full = np.asarray(cif_full, dtype=float)
+    y_cause_within_tau = np.asarray(y_cause_within_tau, dtype=float)
+    if cif_full.ndim != 3 or y_cause_within_tau.ndim != 3:
+        raise ValueError(
+            "Expected cif_full and y_cause_within_tau with shape (N, K, H)")
+    if cif_full.shape != y_cause_within_tau.shape:
+        raise ValueError(
+            f"Shape mismatch: cif_full {cif_full.shape} vs y_cause_within_tau {y_cause_within_tau.shape}"
+        )
 
-            # Primary: CIF-based Brier score + ICI (calibration).
-            out_rows.append(
+    N, K, H = cif_full.shape
+    if H != len(eval_horizons):
+        raise ValueError("H mismatch between cif_full and eval_horizons")
+
+    calib_set = set(int(x) for x in calib_cause_ids) if calib_cause_ids is not None else set()
+
+    workers = int(metric_workers)
+    if workers <= 0:
+        workers = int(min(8, os.cpu_count() or 1))
+    workers = max(1, workers)
+    show_progress = _should_show_progress(progress)
+
+    def _eval_chunk(
+        *,
+        tau: float,
+        p_tau: np.ndarray,
+        y_tau: np.ndarray,
+        brier_by_cause: np.ndarray,
+        cause_ids: np.ndarray,
+    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int]:
+        local_rows: List[Dict[str, Any]] = []
+        local_calib: List[Dict[str, Any]] = []
+        for cid in cause_ids.tolist():
+            p = p_tau[:, cid]
+            y = y_tau[:, cid]
+
+            local_rows.append(
                 {
                     "model_name": model_name,
                     "metric_name": "cause_brier",
                     "horizon": float(tau),
-                    "cause": int(cause_id),
-                    "value": brier_score(p, y),
+                    "cause": int(cid),
+                    "value": float(brier_by_cause[cid]),
                     "ci_low": "",
                     "ci_high": "",
                 }
             )
 
-            cal = calibration_deciles(p, y, n_bins=n_calib_bins)
-            out_rows.append(
+            # ICI: compute bins only if we will export them.
+            need_bins = (not calib_set) or (int(cid) in calib_set)
+            if need_bins:
+                cal = calibration_deciles(p, y, n_bins=n_calib_bins)
+                ici = float(cal["ici"])
+            else:
+                cal = None
+                ici = calibration_ici_only(p, y, n_bins=n_calib_bins)
+
+            local_rows.append(
                 {
                     "model_name": model_name,
                     "metric_name": "cause_ici",
                     "horizon": float(tau),
-                    "cause": int(cause_id),
-                    "value": cal["ici"],
+                    "cause": int(cid),
+                    "value": float(ici),
                     "ci_low": "",
                     "ci_high": "",
                 }
             )
@@ -1276,35 +1380,150 @@
                 auc, lo, hi = float("nan"), float("nan"), float("nan")
             elif auc_ci_method == "bootstrap":
                 auc, lo, hi = bootstrap_auc_ci(
-                    p, y, n_bootstrap=bootstrap_n, alpha=0.95)
+                    p, y, n_bootstrap=bootstrap_n, alpha=0.95
+                )
             else:
                 auc, lo, hi = delong_ci(y, p, alpha=0.95)
-            out_rows.append(
+
+            local_rows.append(
                 {
                     "model_name": model_name,
                     "metric_name": "cause_auc",
                     "horizon": float(tau),
-                    "cause": int(cause_id),
-                    "value": auc,
+                    "cause": int(cid),
+                    "value": float(auc),
                     "ci_low": lo,
                     "ci_high": hi,
                 }
             )
 
-            # Calibration curve bins for this cause + horizon.
-            for binfo in cal.get("bins", []):
-                calib_rows.append(
-                    {
-                        "model_name": model_name,
-                        "task": "cause_k",
-                        "horizon": float(tau),
-                        "cause_id": int(cause_id),
-                        "bin_index": int(binfo["bin"]),
-                        "p_mean": float(binfo["p_mean"]),
-                        "y_mean": float(binfo["y_mean"]),
-                        "n_in_bin": int(binfo["n"]),
-                    }
-                )
+            if need_bins and cal is not None:
+                for binfo in cal.get("bins", []):
+                    local_calib.append(
+                        {
+                            "model_name": model_name,
+                            "task": "cause_k",
+                            "horizon": float(tau),
+                            "cause_id": int(cid),
+                            "bin_index": int(binfo["bin"]),
+                            "p_mean": float(binfo["p_mean"]),
+                            "y_mean": float(binfo["y_mean"]),
+                            "n_in_bin": int(binfo["n"]),
+                        }
+                    )
+        return local_rows, local_calib, int(cause_ids.size)
+
+    # Cause-specific, time-dependent metrics per horizon.
+    for h_i, tau in enumerate(eval_horizons):
+        p_tau = cif_full[:, :, h_i]  # (N, K)
+        y_tau = y_cause_within_tau[:, :, h_i]  # (N, K)
+
+        # Vectorized Brier for speed.
+        brier_by_cause = np.mean((p_tau - y_tau) ** 2, axis=0)  # (K,)
+
+        # Parallelize disease-level metrics; chunk to avoid millions of futures.
+        all_ids = np.arange(int(K), dtype=int)
+        chunks = np.array_split(all_ids, workers)
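+        # Chunking sketch: np.array_split keeps chunk sizes within one of each
+        # other, e.g. np.array_split(np.arange(10), 3)
+        # -> [array([0, 1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])].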
+        done = 0
+        prefix = f"[{model_name}] tau={float(tau)}y "
+        t0 = time.time()
+
+        if workers <= 1:
+            for ch in chunks:
+                r_chunk, c_chunk, n_done = _eval_chunk(
+                    tau=float(tau),
+                    p_tau=p_tau,
+                    y_tau=y_tau,
+                    brier_by_cause=brier_by_cause,
+                    cause_ids=ch,
+                )
+                out_rows.extend(r_chunk)
+                calib_rows.extend(c_chunk)
+                done += int(n_done)
+                if show_progress:
+                    sys.stdout.write(
+                        "\r" + _progress_line(done, int(K), prefix=prefix))
+                    sys.stdout.flush()
+        else:
+            with ThreadPoolExecutor(max_workers=workers) as ex:
+                futs = [
+                    ex.submit(
+                        _eval_chunk,
+                        tau=float(tau),
+                        p_tau=p_tau,
+                        y_tau=y_tau,
+                        brier_by_cause=brier_by_cause,
+                        cause_ids=ch,
+                    )
+                    for ch in chunks
+                    if int(ch.size) > 0
+                ]
+                for fut in as_completed(futs):
+                    r_chunk, c_chunk, n_done = fut.result()
+                    out_rows.extend(r_chunk)
+                    calib_rows.extend(c_chunk)
+                    done += int(n_done)
+                    if show_progress:
+                        sys.stdout.write(
+                            "\r" + _progress_line(done, int(K), prefix=prefix))
+                        sys.stdout.flush()
+
+        if show_progress:
+            dt = time.time() - t0
+            sys.stdout.write(
+                "\r" + _progress_line(int(K), int(K), prefix=prefix) + f" ({dt:.1f}s)\n")
+            sys.stdout.flush()
+
+
+def summarize_over_diseases(
+    rows: List[Dict[str, Any]],
+    *,
+    model_name: str,
+    eval_horizons: Sequence[float],
+    metrics: Sequence[str] = ("cause_brier", "cause_ici", "cause_auc"),
+) -> List[Dict[str, Any]]:
+    """Summarize mean/median of each metric over diseases (per horizon)."""
+    out: List[Dict[str, Any]] = []
+    # Bucket values by (metric_name, horizon).
+    bucket: Dict[Tuple[str, float], List[float]] = {}
+    for r in rows:
+        if r.get("model_name") != model_name:
+            continue
+        m = str(r.get("metric_name"))
+        if m not in set(metrics):
+            continue
+        h = _safe_float(r.get("horizon"))
+        v = _safe_float(r.get("value"))
+        if not np.isfinite(h) or not np.isfinite(v):
+            continue
+        bucket.setdefault((m, float(h)), []).append(float(v))
+
+    for tau in eval_horizons:
+        ht = float(tau)
+        for m in metrics:
+            vals = bucket.get((str(m), ht), [])
+            if vals:
+                arr = np.asarray(vals, dtype=float)
+                mean_v = float(np.mean(arr))
+                med_v = float(np.median(arr))
+                n_valid = int(arr.size)
+            else:
+                mean_v = float("nan")
+                med_v = float("nan")
+                n_valid = 0
+            out.append(
+                {
+                    "model_name": str(model_name),
+                    "metric_name": str(m),
+                    "horizon": ht,
+                    "mean": mean_v,
+                    "median": med_v,
+                    "n_valid": n_valid,
+                }
+            )
+    return out
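+
+# Illustrative summary row (values made up):
+#   {"model_name": "m0", "metric_name": "cause_auc", "horizon": 5.0,
+#    "mean": 0.71, "median": 0.70, "n_valid": 1240}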
 
 
 def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
@@ -1361,13 +1580,12 @@ def main() -> int:
                     help="Anti-leakage offset (years)")
     ap.add_argument("--eval_horizons", type=float, nargs="*",
                     default=DEFAULT_EVAL_HORIZONS)
-    ap.add_argument("--top_k_causes", type=int, default=50)
     ap.add_argument("--batch_size", type=int, default=128)
     ap.add_argument("--num_workers", type=int, default=0)
     ap.add_argument("--seed", type=int, default=123)
     ap.add_argument("--device", type=str,
                     default="cuda" if torch.cuda.is_available() else "cpu")
-    ap.add_argument("--out_csv", type=str, default="eval_results.csv")
+    ap.add_argument("--out_csv", type=str, default="eval_summary.csv")
     ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
 
     # Integrity checks
@@ -1383,6 +1601,21 @@ def main() -> int:
     )
     ap.add_argument("--bootstrap_n", type=int, default=2000)
 
+    # Speed/UX
+    ap.add_argument(
+        "--metric_workers",
+        type=int,
+        default=0,
+        help="Threads for per-disease metrics (0=auto, 1=disable parallelism)",
+    )
+    ap.add_argument(
+        "--progress",
+        type=str,
+        default="auto",
+        choices=["auto", "bar", "none"],
+        help="Progress visualization during per-disease evaluation",
+    )
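+    # Usage sketch (illustrative; --split and the remaining flags are assumed
+    # to be defined elsewhere in this parser):
+    #   python evaluate_models.py --split test --metric_workers 8 --progress bar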
+
     # Export settings for user-facing experiments
     ap.add_argument("--export_dir", type=str, default="eval_exports")
     ap.add_argument("--death_cause_id", type=int,
@@ -1463,7 +1696,8 @@ def main() -> int:
     # Metadata for focus causes (within tau_max).
     top_causes_meta: List[Dict[str, Any]] = []
     for cid in focus_causes:
-        n_case = int(counts[int(cid)]) if int(cid) < int(counts.shape[0]) else 0
+        n_case = int(counts[int(cid)]) if int(
+            cid) < int(counts.shape[0]) else 0
         top_causes_meta.append(
             {
                 "cause_id": int(cid),
@@ -1483,7 +1717,7 @@ def main() -> int:
         hg_rows,
     )
 
-    rows: List[Dict[str, Any]] = []
+    summary_rows: List[Dict[str, Any]] = []
     calib_rows: List[Dict[str, Any]] = []
 
     # Experiment exports (accumulated across models)
@@ -1503,7 +1737,6 @@ def main() -> int:
         tag = _make_eval_tag(args.split, float(args.offset_years))
 
         # Remember list offsets so we can write per-model slices to the model's run_dir.
-        rows_start = len(rows)
         calib_start = len(calib_rows)
 
         cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
@@ -1544,7 +1777,6 @@ def main() -> int:
         head.load_state_dict(ckpt["head_state_dict"], strict=True)
 
         (
-            cause_cif,
            cif_full,
             survival,
             y_cause_within_tau,
@@ -1558,7 +1790,7 @@ def main() -> int:
             args.device,
             args.offset_years,
             args.eval_horizons,
-            top_cause_ids,
+            n_disease=int(dataset.n_disease),
         )
 
         # CIF integrity checks before metrics.
@@ -1575,26 +1807,42 @@ def main() -> int:
             "integrity_notes": integrity_notes,
         }
 
+        # Per-disease metrics for ALL diseases (written into the model's run_dir).
+        model_rows: List[Dict[str, Any]] = []
         evaluate_one_model(
             model_name=spec.name,
-            cause_cif=cause_cif,
+            cif_full=cif_full,
             y_cause_within_tau=y_cause_within_tau,
             eval_horizons=args.eval_horizons,
-            top_cause_ids=top_cause_ids,
-            out_rows=rows,
+            out_rows=model_rows,
             calib_rows=calib_rows,
+            calib_cause_ids=top_cause_ids.tolist(),
             auc_ci_method=str(args.auc_ci_method),
             bootstrap_n=int(args.bootstrap_n),
+            metric_workers=int(args.metric_workers),
+            progress=str(args.progress),
         )
 
+        # Summary over diseases (mean/median per horizon).
+        model_summary_rows = summarize_over_diseases(
+            model_rows,
+            model_name=spec.name,
+            eval_horizons=args.eval_horizons,
+        )
+        summary_rows.extend(model_summary_rows)
+
+        # Convenience slices for user-facing experiments (focus causes only).
+        cause_cif_focus = cif_full[:, top_cause_ids, :]
+        y_within_focus = y_cause_within_tau[:, top_cause_ids, :]
+
         # ============================================================
         # Experiment 1: Risk stratification bins + summary
         # ============================================================
         for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
             for h_i, tau in enumerate(args.eval_horizons):
                 for j, cause_id in enumerate(top_cause_ids.tolist()):
-                    p = cause_cif[:, j, h_i]
-                    y = y_cause_within_tau[:, j, h_i]
+                    p = cause_cif_focus[:, j, h_i]
+                    y = y_within_focus[:, j, h_i]
                     if sex_mask is not None:
                         p = p[sex_mask]
                         y = y[sex_mask]
@@ -1648,8 +1896,8 @@ def main() -> int:
         for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
             for h_i, tau in enumerate(args.eval_horizons):
                 for j, cause_id in enumerate(top_cause_ids.tolist()):
-                    p = cause_cif[:, j, h_i]
-                    y = y_cause_within_tau[:, j, h_i]
+                    p = cause_cif_focus[:, j, h_i]
+                    y = y_within_focus[:, j, h_i]
                     if sex_mask is not None:
                         p = p[sex_mask]
                         y = y[sex_mask]
@@ -1690,15 +1938,15 @@ def main() -> int:
         # Per-horizon metrics for grouping
         # Build a dict for quick access: (cause_id, horizon) -> (brier, ici)
         per_h: Dict[Tuple[int, float], Dict[str, float]] = {}
-        for rr in rows[rows_start:]:
-            if rr.get("model_name") != spec.name:
-                continue
+        for rr in model_rows:
             if rr.get("metric_name") not in {"cause_brier", "cause_ici"}:
                 continue
             try:
                 cid = int(rr.get("cause"))
             except Exception:
                 continue
+            if cid not in set(int(x) for x in top_cause_ids.tolist()):
+                continue
             h = _safe_float(rr.get("horizon"))
             if not np.isfinite(h):
                 continue
@@ -1738,8 +1986,8 @@ def main() -> int:
                     group_vals[g]["ici"].append(ici_h)
 
                 # pooled reliability bins from raw p/y
-                p = cause_cif[:, j, h_i]
-                y = y_cause_within_tau[:, j, h_i]
+                p = cause_cif_focus[:, j, h_i]
+                y = y_within_focus[:, j, h_i]
                 if sex_mask is not None:
                     p = p[sex_mask]
                     y = y[sex_mask]
@@ -1809,7 +2057,7 @@ def main() -> int:
 
         # Optionally write top-cause counts into the main results CSV as metric rows.
         for tc in top_causes_meta:
-            rows.append(
+            model_rows.append(
                 {
                     "model_name": spec.name,
                     "metric_name": "topcause_n_case_within_tau",
@@ -1820,7 +2068,7 @@ def main() -> int:
                     "ci_high": "",
                 }
             )
-            rows.append(
+            model_rows.append(
                 {
                     "model_name": spec.name,
                     "metric_name": "topcause_n_control_within_tau",
@@ -1831,7 +2079,7 @@ def main() -> int:
                     "ci_high": "",
                 }
             )
-            rows.append(
+            model_rows.append(
                 {
                     "model_name": spec.name,
                     "metric_name": "topcause_n_total_eval",
@@ -1844,13 +2092,18 @@ def main() -> int:
             )
 
         # Write per-model results into the model's run directory.
-        model_rows = rows[rows_start:]
         model_calib_rows = calib_rows[calib_start:]
         model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
+        model_summary_csv = os.path.join(run_dir, f"eval_summary_{tag}.csv")
        model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
         model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")
 
         write_results_csv(model_out_csv, model_rows)
+        write_simple_csv(
+            model_summary_csv,
+            ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
+            model_summary_rows,
+        )
         write_calibration_bins_csv(model_calib_csv, model_calib_rows)
 
         model_meta = {
@@ -1860,12 +2113,13 @@ def main() -> int:
             "split": args.split,
             "offset_years": args.offset_years,
             "eval_horizons": [float(x) for x in args.eval_horizons],
-            "top_k_causes": int(args.top_k_causes),
+            "n_disease": int(dataset.n_disease),
             "top_cause_ids": top_cause_ids.tolist(),
             "top_causes": top_causes_meta,
             "integrity": {spec.name: integrity_meta.get(spec.name, {})},
             "paths": {
                 "results_csv": model_out_csv,
+                "summary_csv": model_summary_csv,
                 "calibration_bins_csv": model_calib_csv,
             },
         }
@@ -1874,7 +2128,12 @@ def main() -> int:
 
         print(f"Wrote per-model results to {model_out_csv}")
 
-    write_results_csv(args.out_csv, rows)
+    # Write the global summary (aggregated over diseases) for all models.
+    write_simple_csv(
+        args.out_csv,
+        ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
+        summary_rows,
+    )
 
     # Write calibration curve points to a separate CSV.
     out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
@@ -2012,7 +2271,7 @@ def main() -> int:
         "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
         "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
         "## Files\n\n"
-        "- focus_causes.csv: The deterministically selected focus causes (Death + top-K). Intended plot: bar of event support + label table.\n"
+        "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
         "- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n"
         "- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n"
         "- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n"
@@ -2027,7 +2286,7 @@ def main() -> int:
         "offset_years": args.offset_years,
         "eval_horizons": [float(x) for x in args.eval_horizons],
         "tau_max": float(tau_max),
-        "top_k_causes": int(args.top_k_causes),
+        "n_disease": int(dataset_for_top.n_disease),
         "top_cause_ids": top_cause_ids.tolist(),
         "top_causes": top_causes_meta,
         "integrity": integrity_meta,
@@ -2045,7 +2304,7 @@ def main() -> int:
     with open(args.out_meta_json, "w") as f:
         json.dump(meta, f, indent=2)
 
-    print(f"Wrote {args.out_csv} with {len(rows)} rows")
+    print(f"Wrote {args.out_csv} with {len(summary_rows)} rows")
     print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
     print(f"Wrote {args.out_meta_json}")
     return 0