From 4686b56336d439cd6279c01a300bd4c45369197d Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Sat, 17 Jan 2026 23:34:19 +0800
Subject: [PATCH] Refactor next-event evaluation to use next-token scores and
 implement clean control AUC metrics

---
 evaluate_next_event.py | 212 ++++++++++++------------------------
 utils.py               | 241 +++++++++++++++++++++++++++++++++++++----
 2 files changed, 286 insertions(+), 167 deletions(-)

diff --git a/evaluate_next_event.py b/evaluate_next_event.py
index 76bef5a..b5b0e61 100644
--- a/evaluate_next_event.py
+++ b/evaluate_next_event.py
@@ -23,20 +23,18 @@ from utils import (
     load_checkpoint_into,
     load_train_config,
     parse_float_list,
-    predict_cifs,
-    roc_auc_ovr,
+    predict_next_token_logits,
+    get_auc_delong_var,
     seed_everything,
-    topk_indices,
     DAYS_PER_YEAR,
 )


 def parse_args() -> argparse.Namespace:
     p = argparse.ArgumentParser(
-        description="Evaluate next-event prediction using short-window CIF"
+        description="Evaluate next-event prediction using next-token scores"
     )
     p.add_argument("--run_dir", type=str, required=True)
-    p.add_argument("--tau_short", type=float, required=True, help="years")
     p.add_argument(
         "--age_bins",
         type=str,
@@ -73,140 +71,69 @@ def _format_age_bin_label(lo: float, hi: float) -> str:
     return f"[{lo}, {hi})"


-def _compute_next_event_metrics(
-    *,
-    scores: np.ndarray,
-    y_next: np.ndarray,
-    tau_short: float,
-    min_pos: int,
-) -> tuple[list[dict], pd.DataFrame]:
-    """Compute next-event *primary* metrics on a given subset.
-
-    Implements 评估方案.md (Next-event):
-    - score_k = CIF_k(tau_short)
-    - Hit@K / MRR are computed on records with an observed next-event.
-
-    Returns (metrics_rows, diag_df). diag_df is a diagnostic per-cause AUC table
-    based on whether the cause occurs within (t0, t0+tau_short] (display-only).
-    """
-    n_records_total = int(y_next.size)
-    eligible = y_next >= 0
-    n_eligible = int(eligible.sum())
-    coverage = float(
-        n_eligible / n_records_total) if n_records_total > 0 else 0.0
-
-    metrics_rows: List[dict] = []
-    metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
-    metrics_rows.append(
-        {"metric": "n_next_event_eligible", "value": n_eligible})
-    metrics_rows.append({"metric": "coverage", "value": coverage})
-    metrics_rows.append(
-        {"metric": "tau_short_years", "value": float(tau_short)})
-
-    K = int(scores.shape[1])
-    # Diagnostic: build per-cause AUC using within-window labels.
-    # This is NOT a primary metric (no IPCW / censoring adjustment).
-    diag_df = pd.DataFrame(
-        {
-            "cause_id": np.arange(K, dtype=np.int64),
-            "n_pos": np.zeros((K,), dtype=np.int64),
-            "n_neg": np.zeros((K,), dtype=np.int64),
-            "auc": np.full((K,), np.nan, dtype=np.float64),
-            "included": np.zeros((K,), dtype=bool),
-        }
-    )
-    if n_records_total == 0:
-        metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
-        metrics_rows.append({"metric": "mrr", "value": float("nan")})
-        for k in [1, 3, 5, 10, 20]:
-            metrics_rows.append(
-                {"metric": f"hitrate_at_{k}", "value": float("nan")})
-        return metrics_rows, diag_df
-
-    # If no eligible records, keep coverage but leave accuracy-like metrics as NaN.
-    if n_eligible == 0:
-        metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
-        metrics_rows.append({"metric": "mrr", "value": float("nan")})
-        for k in [1, 3, 5, 10, 20]:
-            metrics_rows.append(
-                {"metric": f"hitrate_at_{k}", "value": float("nan")})
-        return metrics_rows, diag_df
-
-    scores_e = scores[eligible]
-    y_e = y_next[eligible]
-
-    pred = scores_e.argmax(axis=1)
-    acc = float((pred == y_e).mean())
-    metrics_rows.append({"metric": "hitrate_at_1", "value": acc})
-
-    # MRR
-    order = np.argsort(-scores_e, axis=1, kind="mergesort")
-    ranks = np.empty(y_e.shape[0], dtype=np.int32)
-    for i in range(y_e.shape[0]):
-        ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
-    mrr = float((1.0 / ranks).mean())
-    metrics_rows.append({"metric": "mrr", "value": mrr})
-
-    # HitRate@K
-    for k in [1, 3, 5, 10, 20]:
-        topk = topk_indices(scores_e, k)
-        hit = (topk == y_e[:, None]).any(axis=1)
-        metrics_rows.append({"metric": f"hitrate_at_{k}",
-                             "value": float(hit.mean())})
-
-    # Diagnostic per-cause AUC is computed outside (needs future events), so keep placeholder here.
-    _ = min_pos
-    return metrics_rows, diag_df
-
-
-def _compute_within_window_auc(
+def _compute_next_event_auc_clean_control(
     *,
     scores: np.ndarray,
     records: list,
-    tau_short: float,
-    min_pos: int,
 ) -> pd.DataFrame:
-    """Diagnostic-only per-cause AUC.
+    """Delphi-2M next-event AUC (clean control) per cause.

-    Label definition (event-driven, approximate; no IPCW):
-    y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau_short].
+    Definitions per cause k:
+    - Case: next_event_cause == k
+    - Control (clean): next_event_cause != k AND k not in record.lifetime_causes
+    Both the AUC and its DeLong variance are reported.
     """
     n_records = int(len(records))
     if n_records == 0:
         return pd.DataFrame(
-            columns=["cause_id", "n_pos", "n_neg", "auc", "included"],
+            columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"],
         )

     K = int(scores.shape[1])
-    y = np.zeros((n_records, K), dtype=np.int8)
-    tau = float(tau_short)
-
-    # Build labels from future events.
-    for i, r in enumerate(records):
-        if r.future_causes.size == 0:
-            continue
-        m = r.future_dt_years <= tau
-        if not np.any(m):
-            continue
-        y[i, r.future_causes[m]] = 1
-
-    n_pos = y.sum(axis=0).astype(np.int64)
-    n_neg = (int(n_records) - n_pos).astype(np.int64)
+    y_next = np.array(
+        [(-1 if r.next_event_cause is None else int(r.next_event_cause))
+         for r in records],
+        dtype=np.int64,
+    )

     auc = np.full((K,), np.nan, dtype=np.float64)
-    candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
-    for k in candidates:
-        auc[k] = roc_auc_ovr(y[:, k].astype(np.int32),
-                             scores[:, k].astype(np.float64))
+    var = np.full((K,), np.nan, dtype=np.float64)
+    n_case = np.zeros((K,), dtype=np.int64)
+    n_control = np.zeros((K,), dtype=np.int64)
+
+    for k in range(K):
+        case_mask = y_next == k
+        if not np.any(case_mask):
+            continue
+
+        # Clean controls: not next-event k AND never had k in their lifetime history.
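+        # Note (behavior as implemented): records with no observed next event
+        # (y_next == -1) also enter the control pool, provided cause k never
+        # appears in their lifetime history.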
+        control_mask = y_next != k
+        if np.any(control_mask):
+            clean = np.fromiter(
+                ((k not in rec.lifetime_causes) for rec in records),
+                dtype=bool,
+                count=n_records,
+            )
+            control_mask = control_mask & clean
+
+        cs = scores[case_mask, k]
+        hs = scores[control_mask, k]
+        n_case[k] = int(cs.size)
+        n_control[k] = int(hs.size)
+        if cs.size == 0 or hs.size == 0:
+            continue
+
+        a, v = get_auc_delong_var(hs, cs)
+        auc[k] = float(a)
+        var[k] = float(v)

-    included = (n_pos >= int(min_pos)) & (n_neg > 0)
     return pd.DataFrame(
         {
             "cause_id": np.arange(K, dtype=np.int64),
-            "n_pos": n_pos,
-            "n_neg": n_neg,
+            "n_case": n_case,
+            "n_control": n_control,
             "auc": auc,
-            "included": included,
+            "auc_variance": var,
         }
     )

@@ -247,20 +174,15 @@ def main() -> None:
         **dl_kwargs,
     )

-    tau = float(args.tau_short)
-    scores = predict_cifs(
+    scores = predict_next_token_logits(
         model,
         head,
-        criterion,
         loader,
-        [tau],
         device=device,
         show_progress=show_progress,
-        progress_desc="Inference (next-event)",
+        progress_desc="Inference (next-token)",
+        return_probs=True,
     )
-    # scores shape: (N,K,1) for multi-taus; squeeze last
-    if scores.ndim == 3:
-        scores = scores[:, :, 0]

     y_next = np.array(
         [(-1 if r.next_event_cause is None else int(r.next_event_cause))
@@ -284,24 +206,24 @@ def main() -> None:
         label = _format_age_bin_label(lo, hi)
         m = bin_idx == b
         m_scores = scores[m]
-        m_y = y_next[m]
         m_records = [r for r, keep in zip(records, m) if bool(keep)]

-        m_rows, m_pc = _compute_next_event_metrics(
-            scores=m_scores,
-            y_next=m_y,
-            tau_short=tau,
-            min_pos=int(args.min_pos),
-        )
-        for row in m_rows:
-            per_bin_metric_rows.append({"age_bin": label, **row})
-
-        m_auc = _compute_within_window_auc(
+        # Coverage metric for transparency (not Delphi-2M AUC itself).
+        m_y = y_next[m]
+        n_total = int(m_y.size)
+        n_eligible = int((m_y >= 0).sum())
+        coverage = float(n_eligible / n_total) if n_total > 0 else 0.0
+        per_bin_metric_rows.append(
+            {"age_bin": label, "metric": "n_records_total", "value": n_total})
+        per_bin_metric_rows.append(
+            {"age_bin": label, "metric": "n_next_event_eligible", "value": n_eligible})
+        per_bin_metric_rows.append(
+            {"age_bin": label, "metric": "coverage", "value": coverage})
+
+        m_auc = _compute_next_event_auc_clean_control(
             scores=m_scores,
             records=m_records,
-            tau_short=tau,
-            min_pos=int(args.min_pos),
         )
         m_auc.insert(0, "age_bin", label)
-        m_auc.insert(1, "tau_short_years", float(tau))
         per_bin_auc_parts.append(m_auc)

     out_metrics_bins = os.path.join(
@@ -313,11 +235,11 @@ def main() -> None:
         pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
            out_auc_bins, index=False)
     else:
-        pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos",
-                              "n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False)
+        pd.DataFrame(columns=["age_bin", "cause_id", "n_case", "n_control",
+                              "auc", "auc_variance"]).to_csv(out_auc_bins, index=False)

-    print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.")
-    print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).")
+    print("PRIMARY METRICS: Per-cause AUC is reported per age bin using Delphi-2M clean controls.")
+    print("EVAL METHOD: DeLong AUC variance is reported (per cause).")
     print(f"Wrote {out_metrics_bins}")
     print(f"Wrote {out_auc_bins}")

diff --git a/utils.py b/utils.py
index 593072b..2ed23c9 100644
--- a/utils.py
+++ b/utils.py
@@ -273,6 +273,8 @@ class EvalRecord:
     cutoff_pos: int  # baseline position (inclusive)
     next_event_cause: Optional[int]
     next_event_dt_years: Optional[float]
+    # (U,) unique causes ever observed (clean-control filtering)
+    lifetime_causes: np.ndarray
     future_causes: np.ndarray  # (E,) in [0..K-1]
     future_dt_years: np.ndarray  # (E,) strictly > 0

@@ -320,7 +322,19 @@ def build_event_driven_records(
             doa_days = float(times_ins[int(doa_pos[0])])

         is_disease = codes_ins >= N_TECH_TOKENS
-        disease_times = times_ins[is_disease]
+
+        # Lifetime (ever) disease history for Clean Control filtering.
+        if np.any(is_disease):
+            lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
+                np.int64, copy=False
+            )
+            lifetime_causes = np.unique(lifetime_causes)
+        else:
+            lifetime_causes = np.zeros((0,), dtype=np.int64)
+
+        disease_pos_all = np.flatnonzero(is_disease)
+        disease_times_all = times_ins[disease_pos_all] if disease_pos_all.size > 0 else np.zeros(
+            (0,), dtype=np.float64)

         for b in range(len(age_bins_days) - 1):
             lo = age_bins_days[b]
@@ -331,29 +345,22 @@ def build_event_driven_records(
             if not (doa_days <= hi):
                 continue

-            # 2) at least one disease event within bin, and baseline must satisfy t0>=DOA
-            in_bin = (disease_times >= lo) & (
-                disease_times < hi) & (disease_times >= doa_days)
-            cand_times = disease_times[in_bin]
-            if cand_times.size == 0:
+            # 2) at least one disease event within bin, and baseline must satisfy t0>=DOA.
+            # Random Single-Point Sampling: choose exactly one valid event *index* per (patient, age_bin).
+            if disease_pos_all.size == 0:
                 continue
-            t0_days = float(rng.choice(cand_times))
+            in_bin = (
+                (disease_times_all >= lo)
+                & (disease_times_all < hi)
+                & (disease_times_all >= doa_days)
+            )
+            cand_pos = disease_pos_all[in_bin]
+            if cand_pos.size == 0:
+                continue

-            # Baseline position (inclusive) in the *post-DOA-inserted* sequence.
-            pos = np.flatnonzero(is_disease & np.isclose(
-                times_ins, t0_days, rtol=0.0, atol=eps))
-            if pos.size == 0:
-                disease_pos = np.flatnonzero(is_disease)
-                if disease_pos.size == 0:
-                    continue
-                disease_times_full = times_ins[disease_pos]
-                closest_idx = int(
-                    np.argmin(np.abs(disease_times_full - t0_days)))
-                cutoff_pos = int(disease_pos[closest_idx])
-                t0_days = float(disease_times_full[closest_idx])
-            else:
-                cutoff_pos = int(pos[0])
+            cutoff_pos = int(rng.choice(cand_pos))
+            t0_days = float(times_ins[cutoff_pos])

             # Future disease events strictly after t0
             future_mask = (times_ins > (t0_days + eps)) & is_disease
@@ -366,7 +373,8 @@ def build_event_driven_records(
             else:
                 future_times_days = times_ins[future_pos]
                 future_tokens = codes_ins[future_pos]
-                future_causes = (future_tokens - N_TECH_TOKENS).astype(np.int64)
+                future_causes = (
+                    future_tokens - N_TECH_TOKENS).astype(np.int64)
                 future_dt_years_arr = (
                     (future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)

@@ -383,6 +391,7 @@ def build_event_driven_records(
                 cutoff_pos=int(cutoff_pos),
                 next_event_cause=next_cause,
                 next_event_dt_years=next_dt_years,
+                lifetime_causes=lifetime_causes,
                 future_causes=future_causes,
                 future_dt_years=future_dt_years_arr,
             )

@@ -575,3 +584,191 @@ def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
     part_scores = np.take_along_axis(scores, part, axis=1)
     order = np.argsort(-part_scores, axis=1, kind="mergesort")
     return np.take_along_axis(part, order, axis=1)
+
+
+# -------------------------
+# Statistical evaluation (DeLong)
+# -------------------------
+
+def compute_midrank(x: np.ndarray) -> np.ndarray:
+    """Compute midranks of a 1D array (1-based ranks, tie-aware)."""
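+    # Worked example (illustrative): for x = [0.2, 0.8, 0.2], the two tied
+    # values share the average rank (1 + 2) / 2 = 1.5, so the result is
+    # [1.5, 3.0, 1.5] in the original order.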
+    x = np.asarray(x, dtype=np.float64)
+    if x.ndim != 1:
+        raise ValueError("compute_midrank expects a 1D array")
+
+    order = np.argsort(x, kind="mergesort")
+    x_sorted = x[order]
+    n = int(x_sorted.size)
+
+    midranks = np.empty((n,), dtype=np.float64)
+    i = 0
+    while i < n:
+        j = i
+        while j < n and x_sorted[j] == x_sorted[i]:
+            j += 1
+        # ranks are 1..n; average over ties
+        mid = 0.5 * ((i + 1) + j)
+        midranks[i:j] = mid
+        i = j
+
+    out = np.empty((n,), dtype=np.float64)
+    out[order] = midranks
+    return out
+
+
+def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
+    """Fast DeLong method for AUC covariance.
+
+    Args:
+        predictions_sorted_transposed: shape (n_classifiers, n_examples), where
+            the first label_1_count examples are positives.
+        label_1_count: number of positive examples.
+    Returns:
+        (aucs, delong_cov)
+    """
+    preds = np.asarray(predictions_sorted_transposed, dtype=np.float64)
+    if preds.ndim != 2:
+        raise ValueError("predictions_sorted_transposed must be 2D")
+
+    m = int(label_1_count)
+    n = int(preds.shape[1] - m)
+    if m <= 0 or n <= 0:
+        raise ValueError("DeLong requires at least 1 positive and 1 negative")
+
+    k = int(preds.shape[0])
+    tx = np.empty((k, m), dtype=np.float64)
+    ty = np.empty((k, n), dtype=np.float64)
+    tz = np.empty((k, m + n), dtype=np.float64)
+
+    for r in range(k):
+        tx[r] = compute_midrank(preds[r, :m])
+        ty[r] = compute_midrank(preds[r, m:])
+        tz[r] = compute_midrank(preds[r, :])
+
+    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)
+
+    v01 = (tz[:, :m] - tx) / float(n)
+    v10 = 1.0 - (tz[:, m:] - ty) / float(m)
+
+    # np.cov treats rows as variables when rowvar=True (the default).
+    sx = np.cov(v01, rowvar=True, bias=False)
+    sy = np.cov(v10, rowvar=True, bias=False)
+    delong_cov = sx / float(m) + sy / float(n)
+    return aucs, delong_cov
+
+
+def compute_ground_truth_statistics(ground_truth: np.ndarray) -> Tuple[np.ndarray, int]:
+    """Return an ordering that places positives first, plus label_1_count."""
+    y = np.asarray(ground_truth, dtype=np.int32)
+    if y.ndim != 1:
+        raise ValueError("ground_truth must be 1D")
+    label_1_count = int(y.sum())
+    order = np.argsort(-y, kind="mergesort")
+    return order, label_1_count
+
+
+def get_auc_delong_var(healthy_scores: np.ndarray, diseased_scores: np.ndarray) -> Tuple[float, float]:
+    """Compute AUC and its DeLong variance.
+
+    Args:
+        healthy_scores: scores for controls (label=0)
+        diseased_scores: scores for cases (label=1)
+    Returns:
+        (auc, auc_variance)
+    """
+    h = np.asarray(healthy_scores, dtype=np.float64).reshape(-1)
+    d = np.asarray(diseased_scores, dtype=np.float64).reshape(-1)
+    n0 = int(h.size)
+    n1 = int(d.size)
+    if n0 == 0 or n1 == 0:
+        return float("nan"), float("nan")
+
+    # Arrange positives first as required by fastDeLong.
+    scores = np.concatenate([d, h], axis=0)
+    gt = np.concatenate([
+        np.ones((n1,), dtype=np.int32),
+        np.zeros((n0,), dtype=np.int32),
+    ])
+    order, label_1_count = compute_ground_truth_statistics(gt)
+    preds_sorted = scores[order][None, :]
+    aucs, cov = fastDeLong(preds_sorted, label_1_count)
+    auc = float(aucs[0])
+    cov = np.asarray(cov)
+    var = float(cov[0, 0]) if cov.ndim == 2 else float(cov)
+    return auc, var
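+
+# Example (illustrative): auc, var = get_auc_delong_var(control_scores, case_scores)
+# yields the Mann-Whitney AUC and its DeLong variance; a normal-approximation
+# 95% CI is auc +/- 1.96 * sqrt(var).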
+
+
+# -------------------------
+# Next-token inference helper
+# -------------------------
+
+def predict_next_token_logits(
+    model: torch.nn.Module,
+    head: torch.nn.Module,
+    loader: DataLoader,
+    device: torch.device,
+    show_progress: bool = False,
+    progress_desc: str = "Inference (next-token)",
+    return_probs: bool = True,
+) -> np.ndarray:
+    """Predict per-cause next-token scores at baseline positions.
+
+    Returns:
+        np.ndarray of shape (N, K), where K is the number of diseases (causes).
+
+    Notes:
+    - For loss types with time/bin dimensions (e.g., discrete-time CIF), this
+      uses the *first* time/bin (index 0) and drops the complement channel
+      when present.
+    - If return_probs=True, applies softmax over causes for probability-like
+      scores.
+    """
+    model.eval()
+    head.eval()
+
+    all_out: List[np.ndarray] = []
+    with torch.no_grad():
+        for batch in _progress(
+            loader,
+            enabled=show_progress,
+            desc=progress_desc,
+            total=len(loader) if hasattr(loader, "__len__") else None,
+        ):
+            event_seq, time_seq, cont, cate, sex, baseline_pos = batch
+            event_seq = event_seq.to(device, non_blocking=True)
+            time_seq = time_seq.to(device, non_blocking=True)
+            cont = cont.to(device, non_blocking=True)
+            cate = cate.to(device, non_blocking=True)
+            sex = sex.to(device, non_blocking=True)
+            baseline_pos = baseline_pos.to(device, non_blocking=True)
+
+            h = model(event_seq, time_seq, sex, cont, cate)
+            b_idx = torch.arange(h.size(0), device=device)
+            c = h[b_idx, baseline_pos]
+            logits = head(c)
+
+            # logits can be (B, K), (B, K, T), or (B, K+1, T).
+            if logits.ndim == 2:
+                cause_logits = logits
+            elif logits.ndim == 3:
+                # Use the first time/bin.
+                cause_logits = logits[..., 0]
+            else:
+                raise ValueError(
+                    f"Unsupported logits shape for next-token inference: {tuple(logits.shape)}"
+                )

+            # If extra channels exist beyond the per-cause logits (e.g. a
+            # complement/survival channel in discrete-time CIF heads), drop them.
+            if hasattr(model, "n_disease"):
+                n_disease = int(getattr(model, "n_disease"))
+                if cause_logits.size(1) > n_disease:
+                    cause_logits = cause_logits[:, :n_disease]
+
+            if return_probs:
+                scores = torch.softmax(cause_logits, dim=1)
+            else:
+                scores = cause_logits
+
+            all_out.append(scores.detach().cpu().numpy())
+
+    return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))
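-- 
A minimal self-check sketch for the new DeLong helper (illustrative, not part
of the patch; assumes utils.py is importable and uses synthetic Gaussian
scores). It cross-checks the point estimate against a brute-force Mann-Whitney
AUC and derives a normal-approximation 95% confidence interval from the
DeLong variance:

    import numpy as np
    from utils import get_auc_delong_var

    rng = np.random.default_rng(0)
    controls = rng.normal(0.0, 1.0, size=200)  # healthy_scores
    cases = rng.normal(1.0, 1.0, size=50)      # diseased_scores

    auc, var = get_auc_delong_var(controls, cases)

    # The DeLong point estimate equals the Mann-Whitney AUC:
    # P(case score > control score) + 0.5 * P(tie).
    diff = cases[:, None] - controls[None, :]
    auc_mw = float((diff > 0).mean() + 0.5 * (diff == 0).mean())
    assert np.isclose(auc, auc_mw)

    # Normal-approximation 95% CI from the DeLong variance.
    half = 1.96 * np.sqrt(var)
    print(f"AUC={auc:.3f} (95% CI {auc - half:.3f}..{auc + half:.3f})")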