diff --git a/evaluate_models.py b/evaluate_models.py index 810e77f..5f56445 100644 --- a/evaluate_models.py +++ b/evaluate_models.py @@ -14,7 +14,6 @@ import torch.nn.functional as F from torch.utils.data import DataLoader, random_split from dataset import HealthDataset, health_collate_fn -from losses import DiscreteTimeCIFNLLLoss from model import DelphiFork, SapDelphi, SimpleHead @@ -24,6 +23,7 @@ from model import DelphiFork, SapDelphi, SimpleHead DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")] DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0] DAYS_PER_YEAR = 365.25 +DEFAULT_DEATH_CAUSE_ID = 1256 # ============================================================ @@ -683,7 +683,7 @@ def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[ # guard if p.size == 0: - return {"bins": [], "ece": float("nan"), "ici": float("nan")} + return {"bins": [], "ici": float("nan")} edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1)) # make strictly increasing where possible @@ -691,7 +691,6 @@ def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[ edges[-1] = np.inf bins = [] - ece = 0.0 ici_accum = 0.0 n = p.shape[0] @@ -701,25 +700,296 @@ def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[ continue p_mean = float(np.mean(p[mask])) y_mean = float(np.mean(y[mask])) - frac = float(np.mean(mask)) bins.append({"bin": i, "p_mean": p_mean, "y_mean": y_mean, "n": int(mask.sum())}) - ece += frac * abs(p_mean - y_mean) ici_accum += abs(p_mean - y_mean) ici = ici_accum / max(len(bins), 1) - return {"bins": bins, "ece": float(ece), "ici": float(ici)} + return {"bins": bins, "ici": float(ici)} -def count_ever_after_context_anytime( +def _safe_float(x: Any, default: float = float("nan")) -> float: + try: + return float(x) + except Exception: + return float(default) + + +def _ensure_dir(path: str) -> None: + os.makedirs(path, exist_ok=True) + + +def load_cause_names(path: str = "labels.csv") -> Dict[int, str]: + """Load 0-based cause_id -> name mapping. + + labels.csv is assumed to be one label per line, in disease-id order. + """ + if not os.path.exists(path): + return {} + mapping: Dict[int, str] = {} + with open(path, "r", encoding="utf-8") as f: + for i, line in enumerate(f): + name = line.strip() + if name: + mapping[int(i)] = name + return mapping + + +def pick_focus_causes( + *, + counts_within_tau: Optional[np.ndarray], + n_disease: int, + death_cause_id: int = DEFAULT_DEATH_CAUSE_ID, + k: int = 5, +) -> List[int]: + """Pick focus causes for user-facing evaluation. + + Rule: + 1) Always include death_cause_id first. + 2) Then add K additional causes by descending event count if available. + If counts_within_tau is None, fall back to descending cause_id coverage proxy. + + Notes: + - counts_within_tau is expected to be shape (n_disease,). + - Deterministic: ties broken by smaller cause id. + """ + n_disease_i = int(n_disease) + if death_cause_id < 0 or death_cause_id >= n_disease_i: + print( + f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_disease_i}); " + "it will be omitted from focus causes." + ) + focus: List[int] = [] + else: + focus = [int(death_cause_id)] + + candidates = [i for i in range(n_disease_i) if i != int(death_cause_id)] + + if counts_within_tau is not None: + c = np.asarray(counts_within_tau).astype(float) + if c.shape[0] != n_disease_i: + print( + "WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering." + ) + counts_within_tau = None + else: + # Sort by (-count, cause_id) + order = sorted(candidates, key=lambda i: (-float(c[i]), int(i))) + order = [i for i in order if float(c[i]) > 0] + focus.extend([int(i) for i in order[: int(k)]]) + + if counts_within_tau is None: + # Fallback: deterministic coverage proxy (descending id, excluding death), then take K. + # (Real coverage requires data; this path is mostly for robustness.) + order = sorted(candidates, key=lambda i: (-int(i))) + focus.extend([int(i) for i in order[: int(k)]]) + + # De-dup while preserving order + seen = set() + out: List[int] = [] + for cid in focus: + if cid not in seen: + out.append(cid) + seen.add(cid) + return out + + +def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None: + _ensure_dir(os.path.dirname(os.path.abspath(path)) or ".") + with open(path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows: + w.writerow(r) + + +def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]: + """Return list of (sex_label, mask) slices including an 'all' slice. + + If sex is missing, returns only ('all', None). + """ + out: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)] + if sex is None: + return out + s = np.asarray(sex) + if s.ndim != 1: + return out + for val in [0, 1]: + m = (s == val) + if int(np.sum(m)) > 0: + out.append((str(val), m)) + return out + + +def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray: + edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1)) + edges = np.asarray(edges, dtype=float) + edges[0] = -np.inf + edges[-1] = np.inf + return edges + + +def compute_risk_stratification_bins( + p: np.ndarray, + y: np.ndarray, + *, + q_default: int = 10, +) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]: + """Compute quantile-based risk strata and a compact summary.""" + p = np.asarray(p, dtype=float) + y = np.asarray(y, dtype=float) + n = int(p.shape[0]) + if n == 0: + return 0, [], { + "y_overall": float("nan"), + "top_decile_y_rate": float("nan"), + "bottom_half_y_rate": float("nan"), + "lift_top10_vs_bottom50": float("nan"), + "slope_pred_vs_obs": float("nan"), + } + + # Choose quantiles robustly. + q = int(q_default) + if n < 200: + q = 5 + + edges = _quantile_edges(p, q) + y_overall = float(np.mean(y)) + bin_rows: List[Dict[str, Any]] = [] + p_means: List[float] = [] + y_rates: List[float] = [] + n_bins: List[int] = [] + + for i in range(q): + mask = (p > edges[i]) & (p <= edges[i + 1]) + nb = int(np.sum(mask)) + if nb == 0: + # Keep the row for consistent plotting; set NaNs. + bin_rows.append( + { + "q": int(i + 1), + "n_bin": 0, + "p_mean": float("nan"), + "y_rate": float("nan"), + "y_overall": y_overall, + "lift_vs_overall": float("nan"), + } + ) + continue + p_mean = float(np.mean(p[mask])) + y_rate = float(np.mean(y[mask])) + lift = float(y_rate / y_overall) if y_overall > 0 else float("nan") + bin_rows.append( + { + "q": int(i + 1), + "n_bin": nb, + "p_mean": p_mean, + "y_rate": y_rate, + "y_overall": y_overall, + "lift_vs_overall": lift, + } + ) + p_means.append(p_mean) + y_rates.append(y_rate) + n_bins.append(nb) + + # Summary + top_mask = (p > edges[q - 1]) & (p <= edges[q]) + bot_half_mask = (p > edges[0]) & (p <= edges[q // 2]) + top_y = float(np.mean(y[top_mask])) if int( + np.sum(top_mask)) > 0 else float("nan") + bot_y = float(np.mean(y[bot_half_mask])) if int( + np.sum(bot_half_mask)) > 0 else float("nan") + lift_top_vs_bottom = float(top_y / bot_y) if (np.isfinite(top_y) + and np.isfinite(bot_y) and bot_y > 0) else float("nan") + + slope = float("nan") + if len(p_means) >= 2: + # Weighted least squares slope of y_rate ~ p_mean. + x = np.asarray(p_means, dtype=float) + yy = np.asarray(y_rates, dtype=float) + w = np.asarray(n_bins, dtype=float) + xm = float(np.average(x, weights=w)) + ym = float(np.average(yy, weights=w)) + denom = float(np.sum(w * (x - xm) ** 2)) + if denom > 0: + slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom) + + summary = { + "y_overall": y_overall, + "top_decile_y_rate": top_y, + "bottom_half_y_rate": bot_y, + "lift_top10_vs_bottom50": lift_top_vs_bottom, + "slope_pred_vs_obs": slope, + } + return q, bin_rows, summary + + +def compute_capture_points( + p: np.ndarray, + y: np.ndarray, + k_pcts: Sequence[int], +) -> List[Dict[str, Any]]: + p = np.asarray(p, dtype=float) + y = np.asarray(y, dtype=float) + n = int(p.shape[0]) + if n == 0: + return [] + order = np.argsort(-p) + y_sorted = y[order] + events_total = float(np.sum(y_sorted)) + + rows: List[Dict[str, Any]] = [] + for k in k_pcts: + kf = float(k) + n_targeted = int(math.ceil(n * kf / 100.0)) + n_targeted = max(1, min(n_targeted, n)) + events_targeted = float(np.sum(y_sorted[:n_targeted])) + capture = float(events_targeted / + events_total) if events_total > 0 else float("nan") + precision = float(events_targeted / float(n_targeted)) + rows.append( + { + "k_pct": int(k), + "n_targeted": int(n_targeted), + "events_targeted": float(events_targeted), + "events_total": float(events_total), + "event_capture_rate": capture, + "precision_in_targeted": precision, + } + ) + return rows + + +def make_horizon_groups(horizons: Sequence[float]) -> Tuple[List[Dict[str, Any]], Dict[float, str], str]: + """Bucketize horizons into short/medium/long using the continuous-horizon rule.""" + uniq = sorted({float(h) for h in horizons}) + mapping: Dict[float, str] = {} + rows: List[Dict[str, Any]] = [] + # First 4 short, next 4 medium, rest long. + for i, h in enumerate(uniq): + if i < 4: + g, gr = "short", 1 + elif i < 8: + g, gr = "medium", 2 + else: + g, gr = "long", 3 + mapping[float(h)] = g + rows.append({"horizon": float(h), "group": g, "group_rank": int(gr)}) + method = "continuous_unique_horizons_first4_next4_rest" + return rows, mapping, method + + +def count_occurs_within_horizon( loader: DataLoader, offset_years: float, + tau_years: float, n_disease: int, device: str, ) -> Tuple[np.ndarray, int]: - """Count per-person ever-occurrence for each disease after the prediction context. + """Count per-person occurrence within tau after the prediction context. - Returns counts[k] = number of individuals with disease k at least once after context. + Returns counts[k] = number of individuals with disease k at least once in (t_ctx, t_ctx+tau]. """ counts = torch.zeros((n_disease,), dtype=torch.long, device=device) n_total_eval = 0 @@ -735,19 +1005,29 @@ def count_ever_after_context_anytime( n_total_eval += int(keep.sum().item()) event_seq = event_seq[keep] + time_seq = time_seq[keep] t_ctx = t_ctx[keep] B, L = event_seq.shape + b = torch.arange(B, device=device) + t0 = time_seq[b, t_ctx] + t1 = t0 + (float(tau_years) * DAYS_PER_YEAR) + idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1) - future = (idxs > t_ctx.unsqueeze(1)) & ( - event_seq >= 2) & (event_seq != 0) - if not future.any(): + in_window = ( + (idxs > t_ctx.unsqueeze(1)) + & (time_seq >= t0.unsqueeze(1)) + & (time_seq <= t1.unsqueeze(1)) + & (event_seq >= 2) + & (event_seq != 0) + ) + if not in_window.any(): continue - b_idx, t_idx = future.nonzero(as_tuple=True) + b_idx, t_idx = in_window.nonzero(as_tuple=True) disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) - # unique per (person, disease) to count per-person ever-occurrence + # unique per (person, disease) to count per-person within-window occurrence key = b_idx.to(torch.long) * int(n_disease) + disease_ids uniq = torch.unique(key) uniq_disease = uniq % int(n_disease) @@ -837,53 +1117,35 @@ def predict_cifs_for_model( offset_years: float, eval_horizons: Sequence[float], top_cause_ids: np.ndarray, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Run model and produce: +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Run model and produce cause-specific, time-dependent CIF outputs. Returns: - allcause_risk: (N,H) cause_cif: (N, topK, H) cif_full: (N, K, H) survival: (N, H) - sex: (N,) - y_allcause_tau: (N,H) - y_cause_ever_anytime: (N, topK) y_cause_within_tau: (N, topK, H) - y_cause_within_tau_max: (N, topK) - NOTE: - - y_cause_ever_anytime is Delphi2M-compatible case/control label. - - y_cause_within_tau_* corresponds to within-horizon labels (kept for legacy/secondary AUC). + NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk). """ backbone.eval() head.eval() # We will accumulate in CPU lists, then concat. - allcause_list: List[np.ndarray] = [] cause_cif_list: List[np.ndarray] = [] cif_full_list: List[np.ndarray] = [] survival_list: List[np.ndarray] = [] - sex_list: List[np.ndarray] = [] - y_all_list: List[np.ndarray] = [] - y_cause_ever_any_list: List[np.ndarray] = [] y_cause_within_list: List[np.ndarray] = [] - y_cause_within_tau_max_list: List[np.ndarray] = [] - - tau_max = float(max(eval_horizons)) + sex_list: List[np.ndarray] = [] top_cause_ids_t = torch.tensor( top_cause_ids, dtype=torch.long, device=device) - # Efficiency: pre-create horizons tensor once per model (on device) and vectorize comparisons. - eval_horizons_t = torch.tensor( - list(eval_horizons), device=device, dtype=torch.float32).view(1, -1) - for batch in loader: event_seq, time_seq, cont_feats, cate_feats, sexes = batch event_seq = event_seq.to(device) time_seq = time_seq.to(device) cont_feats = cont_feats.to(device) cate_feats = cate_feats.to(device) - sexes = sexes.to(device) keep, t_ctx, _ = select_context_indices( event_seq, time_seq, offset_years) @@ -895,7 +1157,7 @@ def predict_cifs_for_model( time_seq = time_seq[keep] cont_feats = cont_feats[keep] cate_feats = cate_feats[keep] - sexes_k = sexes[keep] + sexes_k = sexes[keep].to(device) t_ctx = t_ctx[keep] h = backbone(event_seq, time_seq, sexes_k, @@ -914,25 +1176,10 @@ def predict_cifs_for_model( else: raise ValueError(f"Unsupported loss_type: {loss_type}") - allcause = cif_full.sum(dim=1) # (B,H) cause_cif = cif_full.index_select( dim=1, index=top_cause_ids_t) # (B,topK,H) - # outcomes - dt_next, _cause_next = next_event_after_context( - event_seq, time_seq, t_ctx) - y_all = (dt_next.view(-1, 1) <= eval_horizons_t).to(torch.float32) - - # Delphi2M-compatible ever label (does not depend on horizon) - y_ever_any = multi_hot_ever_after_context_anytime( - event_seq=event_seq, - t_ctx=t_ctx, - n_disease=int(cif_full.size(1)), - ) - y_ever_any_top = y_ever_any.index_select( - dim=1, index=top_cause_ids_t).to(torch.float32) - - # Within-horizon labels for cause-specific CIF quality + legacy AUC + # Within-horizon labels for cause-specific CIF quality + discrimination. n_disease = int(cif_full.size(1)) y_within_top = torch.stack( [ @@ -948,41 +1195,25 @@ def predict_cifs_for_model( ], dim=2, ) # (B,topK,H) - y_within_tau_max_top = multi_hot_selected_causes_within_horizon( - event_seq=event_seq, - time_seq=time_seq, - t_ctx=t_ctx, - tau_years=tau_max, - cause_ids=top_cause_ids_t, - n_disease=n_disease, - ).to(torch.float32) - allcause_list.append(allcause.detach().cpu().numpy()) cause_cif_list.append(cause_cif.detach().cpu().numpy()) cif_full_list.append(cif_full.detach().cpu().numpy()) survival_list.append(survival.detach().cpu().numpy()) - sex_list.append(sexes_k.detach().cpu().numpy()) - y_all_list.append(y_all.detach().cpu().numpy()) - y_cause_ever_any_list.append(y_ever_any_top.detach().cpu().numpy()) y_cause_within_list.append(y_within_top.detach().cpu().numpy()) - y_cause_within_tau_max_list.append( - y_within_tau_max_top.detach().cpu().numpy()) + sex_list.append(sexes_k.detach().cpu().numpy()) - if not allcause_list: + if not cause_cif_list: raise RuntimeError( "No valid samples for evaluation (all batches filtered out by offset).") - allcause_risk = np.concatenate(allcause_list, axis=0) cause_cif = np.concatenate(cause_cif_list, axis=0) cif_full = np.concatenate(cif_full_list, axis=0) survival = np.concatenate(survival_list, axis=0) - sex = np.concatenate(sex_list, axis=0) - y_allcause = np.concatenate(y_all_list, axis=0) - y_cause_ever_any = np.concatenate(y_cause_ever_any_list, axis=0) y_cause_within = np.concatenate(y_cause_within_list, axis=0) - y_cause_within_tau_max = np.concatenate(y_cause_within_tau_max_list, axis=0) + sex = np.concatenate( + sex_list, axis=0) if sex_list else np.array([], dtype=int) - return allcause_risk, cause_cif, cif_full, survival, sex, y_allcause, y_cause_ever_any, y_cause_within, y_cause_within_tau_max + return cause_cif, cif_full, survival, y_cause_within, sex def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray: @@ -994,13 +1225,8 @@ def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray: def evaluate_one_model( model_name: str, - allcause_risk: np.ndarray, cause_cif: np.ndarray, - sex: np.ndarray, - y_allcause: np.ndarray, - y_cause_ever_anytime: np.ndarray, y_cause_within_tau: np.ndarray, - y_cause_within_tau_max: np.ndarray, eval_horizons: Sequence[float], top_cause_ids: np.ndarray, out_rows: List[Dict[str, Any]], @@ -1009,175 +1235,16 @@ def evaluate_one_model( bootstrap_n: int, n_calib_bins: int = 10, ) -> None: - H = len(eval_horizons) - - # Task B (all-cause): Brier + AUC + calibration per horizon + # Cause-specific, time-dependent metrics per horizon. for h_i, tau in enumerate(eval_horizons): - p = allcause_risk[:, h_i] - y = y_allcause[:, h_i] - - out_rows.append( - { - "model_name": model_name, - "metric_name": "allcause_brier", - "horizon": float(tau), - "cause": "", - "value": brier_score(p, y), - "ci_low": "", - "ci_high": "", - } - ) - - if auc_ci_method == "none": - auc, lo, hi = float("nan"), float("nan"), float("nan") - auc = float("nan") - elif auc_ci_method == "bootstrap": - auc, lo, hi = bootstrap_auc_ci( - p, y, n_bootstrap=bootstrap_n, alpha=0.95) - else: - auc, lo, hi = delong_ci(y, p, alpha=0.95) - out_rows.append( - { - "model_name": model_name, - "metric_name": "allcause_auc", - "horizon": float(tau), - "cause": "", - "value": auc, - "ci_low": lo, - "ci_high": hi, - } - ) - - cal = calibration_deciles(p, y, n_bins=n_calib_bins) - out_rows.append( - { - "model_name": model_name, - "metric_name": "allcause_ece", - "horizon": float(tau), - "cause": "", - "value": cal["ece"], - "ci_low": "", - "ci_high": "", - } - ) - out_rows.append( - { - "model_name": model_name, - "metric_name": "allcause_ici", - "horizon": float(tau), - "cause": "", - "value": cal["ici"], - "ci_low": "", - "ci_high": "", - } - ) - - # Write calibration bins into a separate CSV (always for all-cause). - for binfo in cal.get("bins", []): - calib_rows.append( - { - "model_name": model_name, - "task": "all_cause", - "horizon": float(tau), - "cause_id": -1, - "bin_index": int(binfo["bin"]), - "p_mean": float(binfo["p_mean"]), - "y_mean": float(binfo["y_mean"]), - "n_in_bin": int(binfo["n"]), - } - ) - - # Stratification by sex - for s_val in [0, 1]: - m = sex == s_val - if np.sum(m) < 10: - continue - p_s = p[m] - y_s = y[m] - if auc_ci_method == "none": - auc_s, lo_s, hi_s = float("nan"), float("nan"), float("nan") - elif auc_ci_method == "bootstrap": - auc_s, lo_s, hi_s = bootstrap_auc_ci( - p_s, y_s, n_bootstrap=bootstrap_n, alpha=0.95) - else: - auc_s, lo_s, hi_s = delong_ci(y_s, p_s, alpha=0.95) - out_rows.append( - { - "model_name": model_name, - "metric_name": f"allcause_auc_sex{s_val}", - "horizon": float(tau), - "cause": "", - "value": auc_s, - "ci_low": lo_s, - "ci_high": hi_s, - } - ) - - # Task A (Delphi2M-compatible discrimination): per-cause AUC with EVER labels - # case/control is defined by whether the disease appears ANYTIME after context. - tau_max = float(max(eval_horizons)) - p_tau_max = cause_cif[:, :, -1] # (N, topK) - - for j, cause_id in enumerate(top_cause_ids.tolist()): - yk = y_cause_ever_anytime[:, j] - pk = p_tau_max[:, j] - if auc_ci_method == "none": - auc, lo, hi = float("nan"), float("nan"), float("nan") - elif auc_ci_method == "bootstrap": - auc, lo, hi = bootstrap_auc_ci( - pk, yk, n_bootstrap=bootstrap_n, alpha=0.95) - else: - auc, lo, hi = delong_ci(yk, pk, alpha=0.95) - out_rows.append( - { - "model_name": model_name, - "metric_name": "cause_auc_ever", - "horizon": tau_max, - "cause": int(cause_id), - "value": auc, - "ci_low": lo, - "ci_high": hi, - } - ) - - # Keep the existing tau-window AUC as a separate metric (do not remove). - for j, cause_id in enumerate(top_cause_ids.tolist()): - yk = y_cause_within_tau_max[:, j] - pk = p_tau_max[:, j] - if auc_ci_method == "none": - auc, lo, hi = float("nan"), float("nan"), float("nan") - elif auc_ci_method == "bootstrap": - auc, lo, hi = bootstrap_auc_ci( - pk, yk, n_bootstrap=bootstrap_n, alpha=0.95) - else: - auc, lo, hi = delong_ci(yk, pk, alpha=0.95) - out_rows.append( - { - "model_name": model_name, - "metric_name": "cause_auc", - "horizon": tau_max, - "cause": int(cause_id), - "value": auc, - "ci_low": lo, - "ci_high": hi, - } - ) - - # Task B additions: cause-specific Brier + calibration curves at tau=3.84 and 10.0 - tau_targets = [3.84, 10.0] - horizon_to_idx = {float(t): i for i, t in enumerate( - [float(x) for x in eval_horizons])} - for tau in tau_targets: - if float(tau) not in horizon_to_idx: - continue - h_idx = horizon_to_idx[float(tau)] - p_tau = cause_cif[:, :, h_idx] # (N, topK) - y_tau = y_cause_within_tau[:, :, h_idx] # (N, topK) + p_tau = cause_cif[:, :, h_i] # (N, topK) + y_tau = y_cause_within_tau[:, :, h_i] # (N, topK) for j, cause_id in enumerate(top_cause_ids.tolist()): p = p_tau[:, j] y = y_tau[:, j] + # Primary: CIF-based Brier score + ICI (calibration). out_rows.append( { "model_name": model_name, @@ -1190,18 +1257,7 @@ def evaluate_one_model( } ) - cal = calibration_deciles(p, y) - out_rows.append( - { - "model_name": model_name, - "metric_name": "cause_ece", - "horizon": float(tau), - "cause": int(cause_id), - "value": cal["ece"], - "ci_low": "", - "ci_high": "", - } - ) + cal = calibration_deciles(p, y, n_bins=n_calib_bins) out_rows.append( { "model_name": model_name, @@ -1214,7 +1270,27 @@ def evaluate_one_model( } ) - # Write cause calibration bins into separate CSV only for tau targets. + # Secondary: discrimination via AUC at the same horizon. + if auc_ci_method == "none": + auc, lo, hi = float("nan"), float("nan"), float("nan") + elif auc_ci_method == "bootstrap": + auc, lo, hi = bootstrap_auc_ci( + p, y, n_bootstrap=bootstrap_n, alpha=0.95) + else: + auc, lo, hi = delong_ci(y, p, alpha=0.95) + out_rows.append( + { + "model_name": model_name, + "metric_name": "cause_auc", + "horizon": float(tau), + "cause": int(cause_id), + "value": auc, + "ci_low": lo, + "ci_high": hi, + } + ) + + # Calibration curve bins for this cause + horizon. for binfo in cal.get("bins", []): calib_rows.append( { @@ -1305,6 +1381,21 @@ def main() -> int: choices=["delong", "bootstrap", "none"], ) ap.add_argument("--bootstrap_n", type=int, default=2000) + + # Export settings for user-facing experiments + ap.add_argument("--export_dir", type=str, default="eval_exports") + ap.add_argument("--death_cause_id", type=int, + default=DEFAULT_DEATH_CAUSE_ID) + ap.add_argument("--focus_k", type=int, default=5, + help="Additional non-death causes to include") + ap.add_argument("--capture_k_pcts", type=int, + nargs="*", default=[1, 5, 10, 20]) + ap.add_argument( + "--capture_curve_max_pct", + type=int, + default=50, + help="If >0, also export a dense capture curve for k=1..max_pct", + ) args = ap.parse_args() set_deterministic(args.seed) @@ -1313,7 +1404,12 @@ def main() -> int: if not specs: raise ValueError("No models provided") - # Determine top-K causes from the evaluation split only (model-agnostic). + export_dir = str(args.export_dir) + _ensure_dir(export_dir) + cause_names = load_cause_names("labels.csv") + + # Determine top-K causes from the evaluation split only (model-agnostic), + # aligned to time-dependent risk: occurrence within tau_max after context. first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path) cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [ "bmi", "smoking", "alcohol"] @@ -1334,32 +1430,69 @@ def main() -> int: collate_fn=health_collate_fn, ) - counts, n_total_eval = count_ever_after_context_anytime( + tau_max = float(max(args.eval_horizons)) + counts, n_total_eval = count_occurs_within_horizon( loader=loader_top, offset_years=args.offset_years, + tau_years=tau_max, n_disease=dataset_for_top.n_disease, device=args.device, ) - order = np.argsort(-counts) - order = order[counts[order] > 0] - top_cause_ids = order[: args.top_k_causes] - # Record top-cause counts under Delphi2M-compatible EVER label. + focus_causes = pick_focus_causes( + counts_within_tau=counts, + n_disease=int(dataset_for_top.n_disease), + death_cause_id=int(args.death_cause_id), + k=int(args.focus_k), + ) + top_cause_ids = np.asarray(focus_causes, dtype=int) + + # Export the chosen focus causes. + focus_rows: List[Dict[str, Any]] = [] + for r, cid in enumerate(focus_causes, start=1): + row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)} + if cid in cause_names: + row["cause_name"] = cause_names[cid] + focus_rows.append(row) + focus_fieldnames = ["cause", "rank"] + \ + (["cause_name"] if any("cause_name" in r for r in focus_rows) else []) + write_simple_csv(os.path.join(export_dir, "focus_causes.csv"), + focus_fieldnames, focus_rows) + + # Metadata for focus causes (within tau_max). top_causes_meta: List[Dict[str, Any]] = [] - for k in top_cause_ids.tolist(): - n_case = int(counts[int(k)]) + for cid in focus_causes: + n_case = int(counts[int(cid)]) if int(cid) < int(counts.shape[0]) else 0 top_causes_meta.append( { - "cause_id": int(k), - "n_case_ever": n_case, - "n_control_ever": int(n_total_eval - n_case), + "cause_id": int(cid), + "tau_years": float(tau_max), + "n_case_within_tau": n_case, + "n_control_within_tau": int(n_total_eval - n_case), "n_total_eval": int(n_total_eval), } ) + # Horizon groups for Experiment 3 + hg_rows, horizon_to_group, hg_method = make_horizon_groups( + args.eval_horizons) + write_simple_csv( + os.path.join(export_dir, "horizon_groups.csv"), + ["horizon", "group", "group_rank"], + hg_rows, + ) + rows: List[Dict[str, Any]] = [] calib_rows: List[Dict[str, Any]] = [] + # Experiment exports (accumulated across models) + rs_bins_rows: List[Dict[str, Any]] = [] + rs_sum_rows: List[Dict[str, Any]] = [] + cap_points_rows: List[Dict[str, Any]] = [] + cap_curve_rows: List[Dict[str, Any]] = [] + cal_group_sum_rows: List[Dict[str, Any]] = [] + cal_group_bins_rows: List[Dict[str, Any]] = [] + # Track per-model integrity status for meta JSON. integrity_meta: Dict[str, Any] = {} @@ -1374,6 +1507,16 @@ def main() -> int: cfg = load_train_config_for_checkpoint(spec.checkpoint_path) + # Identifiers for consistent exports + model_id = str(spec.name) + model_type = str( + cfg.get("model_type", spec.model_type if hasattr(spec, "model_type") else "")) + loss_type_id = str( + cfg.get("loss_type", spec.loss_type if hasattr(spec, "loss_type") else "")) + age_encoder = str(cfg.get("age_encoder", "")) + cov_type = "full" if _parse_bool( + cfg.get("full_cov", False)) else "partial" + cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [ "bmi", "smoking", "alcohol"] dataset = HealthDataset( @@ -1400,15 +1543,11 @@ def main() -> int: head.load_state_dict(ckpt["head_state_dict"], strict=True) ( - allcause_risk, cause_cif, cif_full, survival, - sex, - y_allcause, - y_cause_ever_anytime, y_cause_within_tau, - y_cause_within_tau_max, + sex, ) = predict_cifs_for_model( backbone, head, @@ -1437,13 +1576,8 @@ def main() -> int: evaluate_one_model( model_name=spec.name, - allcause_risk=allcause_risk, cause_cif=cause_cif, - sex=sex, - y_allcause=y_allcause, - y_cause_ever_anytime=y_cause_ever_anytime, y_cause_within_tau=y_cause_within_tau, - y_cause_within_tau_max=y_cause_within_tau_max, eval_horizons=args.eval_horizons, top_cause_ids=top_cause_ids, out_rows=rows, @@ -1452,15 +1586,235 @@ def main() -> int: bootstrap_n=int(args.bootstrap_n), ) + # ============================================================ + # Experiment 1: Risk stratification bins + summary + # ============================================================ + for sex_label, sex_mask in _sex_slices(sex if sex.size else None): + for h_i, tau in enumerate(args.eval_horizons): + for j, cause_id in enumerate(top_cause_ids.tolist()): + p = cause_cif[:, j, h_i] + y = y_cause_within_tau[:, j, h_i] + if sex_mask is not None: + p = p[sex_mask] + y = y[sex_mask] + q_used, bin_rows, summary = compute_risk_stratification_bins( + p, y, q_default=10) + for br in bin_rows: + rs_bins_rows.append( + { + "model_id": model_id, + "model_type": model_type, + "loss_type": loss_type_id, + "age_encoder": age_encoder, + "cov_type": cov_type, + "cause": int(cause_id), + "horizon": float(tau), + "sex": sex_label, + "q": int(br["q"]), + "n_bin": int(br["n_bin"]), + "p_mean": _safe_float(br["p_mean"]), + "y_rate": _safe_float(br["y_rate"]), + "y_overall": _safe_float(br["y_overall"]), + "lift_vs_overall": _safe_float(br["lift_vs_overall"]), + "q_total": int(q_used), + } + ) + rs_sum_rows.append( + { + "model_id": model_id, + "model_type": model_type, + "loss_type": loss_type_id, + "age_encoder": age_encoder, + "cov_type": cov_type, + "cause": int(cause_id), + "horizon": float(tau), + "sex": sex_label, + "q_total": int(q_used), + "top_decile_y_rate": _safe_float(summary["top_decile_y_rate"]), + "bottom_half_y_rate": _safe_float(summary["bottom_half_y_rate"]), + "lift_top10_vs_bottom50": _safe_float(summary["lift_top10_vs_bottom50"]), + "slope_pred_vs_obs": _safe_float(summary["slope_pred_vs_obs"]), + } + ) + + # ============================================================ + # Experiment 2: High-risk capture points (+ optional curve) + # ============================================================ + k_pcts = [int(x) for x in args.capture_k_pcts] + curve_max = int(args.capture_curve_max_pct) + curve_grid = list(range(1, curve_max + 1) + ) if curve_max and curve_max > 0 else [] + for sex_label, sex_mask in _sex_slices(sex if sex.size else None): + for h_i, tau in enumerate(args.eval_horizons): + for j, cause_id in enumerate(top_cause_ids.tolist()): + p = cause_cif[:, j, h_i] + y = y_cause_within_tau[:, j, h_i] + if sex_mask is not None: + p = p[sex_mask] + y = y[sex_mask] + + for r in compute_capture_points(p, y, k_pcts): + cap_points_rows.append( + { + "model_id": model_id, + "model_type": model_type, + "loss_type": loss_type_id, + "age_encoder": age_encoder, + "cov_type": cov_type, + "cause": int(cause_id), + "horizon": float(tau), + "sex": sex_label, + **r, + } + ) + if curve_grid: + for r in compute_capture_points(p, y, curve_grid): + cap_curve_rows.append( + { + "model_id": model_id, + "model_type": model_type, + "loss_type": loss_type_id, + "age_encoder": age_encoder, + "cov_type": cov_type, + "cause": int(cause_id), + "horizon": float(tau), + "sex": sex_label, + **r, + } + ) + + # ============================================================ + # Experiment 3: Short/Medium/Long horizon-group calibration + # ============================================================ + # Per-horizon metrics for grouping + # Build a dict for quick access: (cause_id, horizon) -> (brier, ici) + per_h: Dict[Tuple[int, float], Dict[str, float]] = {} + for rr in rows[rows_start:]: + if rr.get("model_name") != spec.name: + continue + if rr.get("metric_name") not in {"cause_brier", "cause_ici"}: + continue + try: + cid = int(rr.get("cause")) + except Exception: + continue + h = _safe_float(rr.get("horizon")) + if not np.isfinite(h): + continue + key = (cid, float(h)) + d = per_h.get(key, {}) + d[str(rr.get("metric_name"))] = _safe_float(rr.get("value")) + per_h[key] = d + + # Compute group summaries and pooled bins using the same quantile bins as exp1 (per slice). + for sex_label, sex_mask in _sex_slices(sex if sex.size else None): + for j, cause_id in enumerate(top_cause_ids.tolist()): + # Decide Q per slice for pooled reliability curve + n_slice = int(np.sum(sex_mask)) if sex_mask is not None else int( + sex.shape[0]) + q_pool = 10 if n_slice >= 200 else 5 + + # Collect per-horizon brier/ici values + group_vals: Dict[str, Dict[str, List[float]]] = {"short": {"brier": [], "ici": [ + ]}, "medium": {"brier": [], "ici": []}, "long": {"brier": [], "ici": []}} + group_n_total: Dict[str, int] = { + "short": 0, "medium": 0, "long": 0} + + # Pooled bins: group -> q -> accumulators + pooled: Dict[str, Dict[int, Dict[str, float]]] = { + "short": {}, "medium": {}, "long": {}} + + for h_i, tau in enumerate(args.eval_horizons): + g = horizon_to_group.get(float(tau), "long") + + # brier/ici per horizon (already computed at full-sample level) + d = per_h.get((int(cause_id), float(tau)), {}) + brier_h = _safe_float(d.get("cause_brier")) + ici_h = _safe_float(d.get("cause_ici")) + if np.isfinite(brier_h): + group_vals[g]["brier"].append(brier_h) + if np.isfinite(ici_h): + group_vals[g]["ici"].append(ici_h) + + # pooled reliability bins from raw p/y + p = cause_cif[:, j, h_i] + y = y_cause_within_tau[:, j, h_i] + if sex_mask is not None: + p = p[sex_mask] + y = y[sex_mask] + if p.size == 0: + continue + edges = _quantile_edges(p, q_pool) + for qi in range(q_pool): + m = (p > edges[qi]) & (p <= edges[qi + 1]) + nb = int(np.sum(m)) + if nb == 0: + continue + pm = float(np.mean(p[m])) + yr = float(np.mean(y[m])) + acc = pooled[g].get( + qi + 1, {"n": 0.0, "p_sum": 0.0, "y_sum": 0.0}) + acc["n"] += float(nb) + acc["p_sum"] += float(nb) * pm + acc["y_sum"] += float(nb) * yr + pooled[g][qi + 1] = acc + group_n_total[g] = max(group_n_total[g], int(p.size)) + + for g in ["short", "medium", "long"]: + bvals = group_vals[g]["brier"] + ivals = group_vals[g]["ici"] + cal_group_sum_rows.append( + { + "model_id": model_id, + "model_type": model_type, + "loss_type": loss_type_id, + "age_encoder": age_encoder, + "cov_type": cov_type, + "cause": int(cause_id), + "sex": sex_label, + "horizon_group": g, + "brier_mean": float(np.mean(bvals)) if bvals else float("nan"), + "brier_median": float(np.median(bvals)) if bvals else float("nan"), + "ici_mean": float(np.mean(ivals)) if ivals else float("nan"), + "ici_median": float(np.median(ivals)) if ivals else float("nan"), + "n_total": int(group_n_total[g]), + "horizon_grouping_method": hg_method, + } + ) + + for qi in range(1, q_pool + 1): + acc = pooled[g].get(qi) + if not acc or float(acc.get("n", 0.0)) <= 0: + continue + n_bin = float(acc["n"]) + cal_group_bins_rows.append( + { + "model_id": model_id, + "model_type": model_type, + "loss_type": loss_type_id, + "age_encoder": age_encoder, + "cov_type": cov_type, + "cause": int(cause_id), + "sex": sex_label, + "horizon_group": g, + "q": int(qi), + "n_bin": int(n_bin), + "p_mean": float(acc["p_sum"] / n_bin), + "y_rate": float(acc["y_sum"] / n_bin), + "q_total": int(q_pool), + "horizon_grouping_method": hg_method, + } + ) + # Optionally write top-cause counts into the main results CSV as metric rows. for tc in top_causes_meta: rows.append( { "model_name": spec.name, - "metric_name": "topcause_n_case_ever", - "horizon": "", + "metric_name": "topcause_n_case_within_tau", + "horizon": float(tc["tau_years"]), "cause": int(tc["cause_id"]), - "value": int(tc["n_case_ever"]), + "value": int(tc["n_case_within_tau"]), "ci_low": "", "ci_high": "", } @@ -1468,10 +1822,10 @@ def main() -> int: rows.append( { "model_name": spec.name, - "metric_name": "topcause_n_control_ever", - "horizon": "", + "metric_name": "topcause_n_control_within_tau", + "horizon": float(tc["tau_years"]), "cause": int(tc["cause_id"]), - "value": int(tc["n_control_ever"]), + "value": int(tc["n_control_within_tau"]), "ci_low": "", "ci_high": "", } @@ -1480,7 +1834,7 @@ def main() -> int: { "model_name": spec.name, "metric_name": "topcause_n_total_eval", - "horizon": "", + "horizon": float(tc["tau_years"]), "cause": int(tc["cause_id"]), "value": int(tc["n_total_eval"]), "ci_low": "", @@ -1526,19 +1880,165 @@ def main() -> int: calib_csv_path = os.path.join(out_dir, "calibration_bins.csv") write_calibration_bins_csv(calib_csv_path, calib_rows) + # Write experiment exports + write_simple_csv( + os.path.join(export_dir, "risk_stratification_bins.csv"), + [ + "model_id", + "model_type", + "loss_type", + "age_encoder", + "cov_type", + "cause", + "horizon", + "sex", + "q", + "n_bin", + "p_mean", + "y_rate", + "y_overall", + "lift_vs_overall", + "q_total", + ], + rs_bins_rows, + ) + write_simple_csv( + os.path.join(export_dir, "risk_stratification_summary.csv"), + [ + "model_id", + "model_type", + "loss_type", + "age_encoder", + "cov_type", + "cause", + "horizon", + "sex", + "q_total", + "top_decile_y_rate", + "bottom_half_y_rate", + "lift_top10_vs_bottom50", + "slope_pred_vs_obs", + ], + rs_sum_rows, + ) + write_simple_csv( + os.path.join(export_dir, "lift_capture_points.csv"), + [ + "model_id", + "model_type", + "loss_type", + "age_encoder", + "cov_type", + "cause", + "horizon", + "sex", + "k_pct", + "n_targeted", + "events_targeted", + "events_total", + "event_capture_rate", + "precision_in_targeted", + ], + cap_points_rows, + ) + if cap_curve_rows: + write_simple_csv( + os.path.join(export_dir, "lift_capture_curve.csv"), + [ + "model_id", + "model_type", + "loss_type", + "age_encoder", + "cov_type", + "cause", + "horizon", + "sex", + "k_pct", + "n_targeted", + "events_targeted", + "events_total", + "event_capture_rate", + "precision_in_targeted", + ], + cap_curve_rows, + ) + write_simple_csv( + os.path.join(export_dir, "calibration_groups_summary.csv"), + [ + "model_id", + "model_type", + "loss_type", + "age_encoder", + "cov_type", + "cause", + "sex", + "horizon_group", + "brier_mean", + "brier_median", + "ici_mean", + "ici_median", + "n_total", + "horizon_grouping_method", + ], + cal_group_sum_rows, + ) + write_simple_csv( + os.path.join(export_dir, "calibration_groups_bins.csv"), + [ + "model_id", + "model_type", + "loss_type", + "age_encoder", + "cov_type", + "cause", + "sex", + "horizon_group", + "q", + "n_bin", + "p_mean", + "y_rate", + "q_total", + "horizon_grouping_method", + ], + cal_group_bins_rows, + ) + + # Manifest markdown (stable, user-facing) + manifest_path = os.path.join(export_dir, "eval_exports_manifest.md") + with open(manifest_path, "w", encoding="utf-8") as f: + f.write( + "# Evaluation Exports Manifest\n\n" + "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). " + "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n" + "## Files\n\n" + "- focus_causes.csv: The deterministically selected focus causes (Death + top-K). Intended plot: bar of event support + label table.\n" + "- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n" + "- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n" + "- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n" + "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n" + "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n" + "- calibration_groups_summary.csv: Short/medium/long aggregated Brier/ICI (mean/median). Intended plot: grouped bar chart by horizon bucket.\n" + "- calibration_groups_bins.csv: Pooled reliability points per horizon bucket (weighted by bin size). Intended plot: 3-panel reliability curves per model.\n" + ) + meta = { "split": args.split, "offset_years": args.offset_years, "eval_horizons": [float(x) for x in args.eval_horizons], + "tau_max": float(tau_max), "top_k_causes": int(args.top_k_causes), "top_cause_ids": top_cause_ids.tolist(), "top_causes": top_causes_meta, "integrity": integrity_meta, "notes": { - "task_a_label": "Delphi2M-compatible: disease occurs ANYTIME after context (ever in remaining sequence)", - "task_a_legacy_label": "Secondary: disease occurs within tau_max after context", - "task_b_label": "all-cause event within horizon (equivalent to next disease event within horizon)", + "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])", + "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)", + "secondary_metrics": "cause_auc (discrimination) with optional CI", + "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported", "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.", + "exports_dir": export_dir, + "focus_causes": focus_causes, + "horizon_grouping_method": hg_method, }, } with open(args.out_meta_json, "w") as f: