Refactor next-event evaluation to use next-token scores and implement clean-control AUC metrics

2026-01-17 23:34:19 +08:00
parent e56068e668
commit 4686b56336
2 changed files with 286 additions and 167 deletions


@@ -23,20 +23,18 @@ from utils import (
     load_checkpoint_into,
     load_train_config,
     parse_float_list,
-    predict_cifs,
-    roc_auc_ovr,
+    predict_next_token_logits,
+    get_auc_delong_var,
     seed_everything,
-    topk_indices,
     DAYS_PER_YEAR,
 )
 
 
 def parse_args() -> argparse.Namespace:
     p = argparse.ArgumentParser(
-        description="Evaluate next-event prediction using short-window CIF"
+        description="Evaluate next-event prediction using next-token scores"
     )
     p.add_argument("--run_dir", type=str, required=True)
-    p.add_argument("--tau_short", type=float, required=True, help="years")
     p.add_argument(
         "--age_bins",
         type=str,
@@ -73,140 +71,69 @@ def _format_age_bin_label(lo: float, hi: float) -> str:
     return f"[{lo}, {hi})"
 
 
-def _compute_next_event_metrics(
-    *,
-    scores: np.ndarray,
-    y_next: np.ndarray,
-    tau_short: float,
-    min_pos: int,
-) -> tuple[list[dict], pd.DataFrame]:
-    """Compute next-event *primary* metrics on a given subset.
-    Implements 评估方案.md (the evaluation plan), "Next-event" section:
-    - score_k = CIF_k(tau_short)
-    - Hit@K / MRR are computed on records with an observed next event.
-    Returns (metrics_rows, diag_df). diag_df is a diagnostic per-cause AUC table
-    based on whether the cause occurs within (t0, t0+tau_short] (display-only).
-    """
-    n_records_total = int(y_next.size)
-    eligible = y_next >= 0
-    n_eligible = int(eligible.sum())
-    coverage = float(
-        n_eligible / n_records_total) if n_records_total > 0 else 0.0
-    metrics_rows: List[dict] = []
-    metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
-    metrics_rows.append(
-        {"metric": "n_next_event_eligible", "value": n_eligible})
-    metrics_rows.append({"metric": "coverage", "value": coverage})
-    metrics_rows.append(
-        {"metric": "tau_short_years", "value": float(tau_short)})
-    K = int(scores.shape[1])
-    # Diagnostic: build per-cause AUC using within-window labels.
-    # This is NOT a primary metric (no IPCW / censoring adjustment).
-    diag_df = pd.DataFrame(
-        {
-            "cause_id": np.arange(K, dtype=np.int64),
-            "n_pos": np.zeros((K,), dtype=np.int64),
-            "n_neg": np.zeros((K,), dtype=np.int64),
-            "auc": np.full((K,), np.nan, dtype=np.float64),
-            "included": np.zeros((K,), dtype=bool),
-        }
-    )
-    if n_records_total == 0:
-        metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
-        metrics_rows.append({"metric": "mrr", "value": float("nan")})
-        for k in [1, 3, 5, 10, 20]:
-            metrics_rows.append(
-                {"metric": f"hitrate_at_{k}", "value": float("nan")})
-        return metrics_rows, diag_df
-    # If no eligible records, keep coverage but leave accuracy-like metrics as NaN.
-    if n_eligible == 0:
-        metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
-        metrics_rows.append({"metric": "mrr", "value": float("nan")})
-        for k in [1, 3, 5, 10, 20]:
-            metrics_rows.append(
-                {"metric": f"hitrate_at_{k}", "value": float("nan")})
-        return metrics_rows, diag_df
-    scores_e = scores[eligible]
-    y_e = y_next[eligible]
-    pred = scores_e.argmax(axis=1)
-    acc = float((pred == y_e).mean())
-    metrics_rows.append({"metric": "hitrate_at_1", "value": acc})
-    # MRR
-    order = np.argsort(-scores_e, axis=1, kind="mergesort")
-    ranks = np.empty(y_e.shape[0], dtype=np.int32)
-    for i in range(y_e.shape[0]):
-        ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
-    mrr = float((1.0 / ranks).mean())
-    metrics_rows.append({"metric": "mrr", "value": mrr})
-    # HitRate@K
-    for k in [1, 3, 5, 10, 20]:
-        topk = topk_indices(scores_e, k)
-        hit = (topk == y_e[:, None]).any(axis=1)
-        metrics_rows.append({"metric": f"hitrate_at_{k}",
-                             "value": float(hit.mean())})
-    # Diagnostic per-cause AUC is computed outside (needs future events), so keep placeholder here.
-    _ = min_pos
-    return metrics_rows, diag_df
-
-
-def _compute_within_window_auc(
+def _compute_next_event_auc_clean_control(
     *,
     scores: np.ndarray,
     records: list,
-    tau_short: float,
-    min_pos: int,
 ) -> pd.DataFrame:
-    """Diagnostic-only per-cause AUC.
-    Label definition (event-driven, approximate; no IPCW):
-    y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau_short].
+    """Delphi-2M next-event AUC (clean control) per cause.
+    Definitions per cause k:
+    - Case: next_event_cause == k
+    - Control (clean): next_event_cause != k AND k not in record.lifetime_causes
+    AUC is computed with DeLong variance.
     """
     n_records = int(len(records))
     if n_records == 0:
         return pd.DataFrame(
-            columns=["cause_id", "n_pos", "n_neg", "auc", "included"],
+            columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"],
         )
     K = int(scores.shape[1])
-    y = np.zeros((n_records, K), dtype=np.int8)
-    tau = float(tau_short)
-    # Build labels from future events.
-    for i, r in enumerate(records):
-        if r.future_causes.size == 0:
-            continue
-        m = r.future_dt_years <= tau
-        if not np.any(m):
-            continue
-        y[i, r.future_causes[m]] = 1
-    n_pos = y.sum(axis=0).astype(np.int64)
-    n_neg = (int(n_records) - n_pos).astype(np.int64)
+    y_next = np.array(
+        [(-1 if r.next_event_cause is None else int(r.next_event_cause))
+         for r in records],
+        dtype=np.int64,
+    )
     auc = np.full((K,), np.nan, dtype=np.float64)
-    candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
-    for k in candidates:
-        auc[k] = roc_auc_ovr(y[:, k].astype(np.int32),
-                             scores[:, k].astype(np.float64))
-    included = (n_pos >= int(min_pos)) & (n_neg > 0)
+    var = np.full((K,), np.nan, dtype=np.float64)
+    n_case = np.zeros((K,), dtype=np.int64)
+    n_control = np.zeros((K,), dtype=np.int64)
+    for k in range(K):
+        case_mask = y_next == k
+        if not np.any(case_mask):
+            continue
+        # Clean controls: not next-event k AND never had k in their lifetime history.
+        control_mask = y_next != k
+        if np.any(control_mask):
+            clean = np.fromiter(
+                ((k not in rec.lifetime_causes) for rec in records),
+                dtype=bool,
+                count=n_records,
+            )
+            control_mask = control_mask & clean
+        cs = scores[case_mask, k]
+        hs = scores[control_mask, k]
+        n_case[k] = int(cs.size)
+        n_control[k] = int(hs.size)
+        if cs.size == 0 or hs.size == 0:
+            continue
+        a, v = get_auc_delong_var(hs, cs)
+        auc[k] = float(a)
+        var[k] = float(v)
     return pd.DataFrame(
         {
            "cause_id": np.arange(K, dtype=np.int64),
-            "n_pos": n_pos,
-            "n_neg": n_neg,
+            "n_case": n_case,
+            "n_control": n_control,
            "auc": auc,
-            "included": included,
+            "auc_variance": var,
         }
     )
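
A minimal standalone sketch of the clean-control rule above, assuming the two functions added in this commit are importable (the SimpleNamespace records are hypothetical stand-ins for the real EvalRecord): any record whose lifetime history already contains cause k is dropped from k's control pool, so prevalent patients cannot contaminate the controls.

    import numpy as np
    from types import SimpleNamespace

    records = [
        SimpleNamespace(next_event_cause=0, lifetime_causes=np.array([0])),     # case for k=0
        SimpleNamespace(next_event_cause=0, lifetime_causes=np.array([0])),     # case for k=0
        SimpleNamespace(next_event_cause=1, lifetime_causes=np.array([0, 1])),  # excluded from k=0 controls (prior cause 0)
        SimpleNamespace(next_event_cause=1, lifetime_causes=np.array([1])),     # clean control for k=0
        SimpleNamespace(next_event_cause=1, lifetime_causes=np.array([1])),     # clean control for k=0
    ]
    scores = np.array([[0.9, 0.1],
                       [0.8, 0.2],
                       [0.6, 0.4],
                       [0.3, 0.7],
                       [0.2, 0.8]])
    df = _compute_next_event_auc_clean_control(scores=scores, records=records)
    # For cause_id 0: n_case == 2, n_control == 2 (the third record is filtered out).
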
@@ -247,20 +174,15 @@ def main() -> None:
         **dl_kwargs,
     )
-    tau = float(args.tau_short)
-    scores = predict_cifs(
+    scores = predict_next_token_logits(
         model,
         head,
-        criterion,
         loader,
-        [tau],
         device=device,
         show_progress=show_progress,
-        progress_desc="Inference (next-event)",
-        return_probs=True,
+        progress_desc="Inference (next-token)",
     )
-    # scores shape: (N,K,1) for multi-taus; squeeze last
-    if scores.ndim == 3:
-        scores = scores[:, :, 0]
     y_next = np.array(
         [(-1 if r.next_event_cause is None else int(r.next_event_cause))
@@ -284,24 +206,24 @@ def main() -> None:
         label = _format_age_bin_label(lo, hi)
         m = bin_idx == b
         m_scores = scores[m]
-        m_y = y_next[m]
         m_records = [r for r, keep in zip(records, m) if bool(keep)]
-        m_rows, m_pc = _compute_next_event_metrics(
-            scores=m_scores,
-            y_next=m_y,
-            tau_short=tau,
-            min_pos=int(args.min_pos),
-        )
-        for row in m_rows:
-            per_bin_metric_rows.append({"age_bin": label, **row})
-        m_auc = _compute_within_window_auc(
+        # Coverage metric for transparency (not Delphi-2M AUC itself).
+        m_y = y_next[m]
+        n_total = int(m_y.size)
+        n_eligible = int((m_y >= 0).sum())
+        coverage = float(n_eligible / n_total) if n_total > 0 else 0.0
+        per_bin_metric_rows.append(
+            {"age_bin": label, "metric": "n_records_total", "value": n_total})
+        per_bin_metric_rows.append(
+            {"age_bin": label, "metric": "n_next_event_eligible", "value": n_eligible})
+        per_bin_metric_rows.append(
+            {"age_bin": label, "metric": "coverage", "value": coverage})
+        m_auc = _compute_next_event_auc_clean_control(
             scores=m_scores,
             records=m_records,
-            tau_short=tau,
-            min_pos=int(args.min_pos),
         )
         m_auc.insert(0, "age_bin", label)
-        m_auc.insert(1, "tau_short_years", float(tau))
         per_bin_auc_parts.append(m_auc)
     out_metrics_bins = os.path.join(
@@ -313,11 +235,11 @@ def main() -> None:
         pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
             out_auc_bins, index=False)
     else:
-        pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos",
-                              "n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False)
-    print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.")
-    print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).")
+        pd.DataFrame(columns=["age_bin", "cause_id", "n_case", "n_control",
+                              "auc", "auc_variance"]).to_csv(out_auc_bins, index=False)
+    print("PRIMARY METRICS: Per-cause AUC is reported per age bin using Delphi-2M clean controls.")
+    print("EVAL METHOD: DeLong AUC variance is reported (per cause).")
     print(f"Wrote {out_metrics_bins}")
     print(f"Wrote {out_auc_bins}")

utils.py

@@ -273,6 +273,8 @@ class EvalRecord:
     cutoff_pos: int  # baseline position (inclusive)
     next_event_cause: Optional[int]
     next_event_dt_years: Optional[float]
+    # (U,) unique causes ever observed (clean-control filtering)
+    lifetime_causes: np.ndarray
     future_causes: np.ndarray  # (E,) in [0..K-1]
     future_dt_years: np.ndarray  # (E,) strictly > 0
@@ -320,7 +322,19 @@ def build_event_driven_records(
     doa_days = float(times_ins[int(doa_pos[0])])
     is_disease = codes_ins >= N_TECH_TOKENS
-    disease_times = times_ins[is_disease]
+    # Lifetime (ever) disease history for Clean Control filtering.
+    if np.any(is_disease):
+        lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
+            np.int64, copy=False
+        )
+        lifetime_causes = np.unique(lifetime_causes)
+    else:
+        lifetime_causes = np.zeros((0,), dtype=np.int64)
+    disease_pos_all = np.flatnonzero(is_disease)
+    disease_times_all = times_ins[disease_pos_all] if disease_pos_all.size > 0 else np.zeros(
+        (0,), dtype=np.float64)
+
     for b in range(len(age_bins_days) - 1):
         lo = age_bins_days[b]
@@ -331,29 +345,22 @@ def build_event_driven_records(
         if not (doa_days <= hi):
             continue
-        # 2) at least one disease event within bin, and baseline must satisfy t0>=DOA
-        in_bin = (disease_times >= lo) & (
-            disease_times < hi) & (disease_times >= doa_days)
-        cand_times = disease_times[in_bin]
-        if cand_times.size == 0:
+        # 2) at least one disease event within bin, and baseline must satisfy t0>=DOA.
+        # Random Single-Point Sampling: choose exactly one valid event *index* per (patient, age_bin).
+        if disease_pos_all.size == 0:
             continue
-        t0_days = float(rng.choice(cand_times))
-        # Baseline position (inclusive) in the *post-DOA-inserted* sequence.
-        pos = np.flatnonzero(is_disease & np.isclose(
-            times_ins, t0_days, rtol=0.0, atol=eps))
-        if pos.size == 0:
-            disease_pos = np.flatnonzero(is_disease)
-            if disease_pos.size == 0:
-                continue
-            disease_times_full = times_ins[disease_pos]
-            closest_idx = int(
-                np.argmin(np.abs(disease_times_full - t0_days)))
-            cutoff_pos = int(disease_pos[closest_idx])
-            t0_days = float(disease_times_full[closest_idx])
-        else:
-            cutoff_pos = int(pos[0])
+        in_bin = (
+            (disease_times_all >= lo)
+            & (disease_times_all < hi)
+            & (disease_times_all >= doa_days)
+        )
+        cand_pos = disease_pos_all[in_bin]
+        if cand_pos.size == 0:
+            continue
+        cutoff_pos = int(rng.choice(cand_pos))
+        t0_days = float(times_ins[cutoff_pos])
         # Future disease events strictly after t0
         future_mask = (times_ins > (t0_days + eps)) & is_disease
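
A standalone sketch of the new sampling path (toy arrays, not the commit's data): sampling the event *index* makes t0_days exact by construction, which is why the old np.isclose position-matching fallback could be deleted.

    import numpy as np

    rng = np.random.default_rng(0)
    times_ins = np.array([100.0, 250.0, 400.0, 900.0])  # event ages in days (hypothetical)
    is_disease = np.array([True, False, True, True])

    disease_pos_all = np.flatnonzero(is_disease)         # [0, 2, 3]
    disease_times_all = times_ins[disease_pos_all]

    lo, hi, doa_days = 0.0, 500.0, 50.0                  # one age bin plus the DOA guard
    in_bin = (
        (disease_times_all >= lo)
        & (disease_times_all < hi)
        & (disease_times_all >= doa_days)
    )
    cand_pos = disease_pos_all[in_bin]                   # candidate positions, not times
    cutoff_pos = int(rng.choice(cand_pos))               # one baseline index per (patient, age_bin)
    t0_days = float(times_ins[cutoff_pos])               # exact; no re-matching needed
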
@@ -366,7 +373,8 @@ def build_event_driven_records(
         else:
             future_times_days = times_ins[future_pos]
             future_tokens = codes_ins[future_pos]
-            future_causes = (future_tokens - N_TECH_TOKENS).astype(np.int64)
+            future_causes = (
+                future_tokens - N_TECH_TOKENS).astype(np.int64)
             future_dt_years_arr = (
                 (future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)
@@ -383,6 +391,7 @@ def build_event_driven_records(
             cutoff_pos=int(cutoff_pos),
             next_event_cause=next_cause,
             next_event_dt_years=next_dt_years,
+            lifetime_causes=lifetime_causes,
             future_causes=future_causes,
             future_dt_years=future_dt_years_arr,
         )
@@ -575,3 +584,191 @@ def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
     part_scores = np.take_along_axis(scores, part, axis=1)
     order = np.argsort(-part_scores, axis=1, kind="mergesort")
     return np.take_along_axis(part, order, axis=1)
+
+
+# -------------------------
+# Statistical evaluation (DeLong)
+# -------------------------
+def compute_midrank(x: np.ndarray) -> np.ndarray:
+    """Compute midranks of a 1D array (1-based ranks, tie-aware)."""
+    x = np.asarray(x, dtype=np.float64)
+    if x.ndim != 1:
+        raise ValueError("compute_midrank expects a 1D array")
+    order = np.argsort(x, kind="mergesort")
+    x_sorted = x[order]
+    n = int(x_sorted.size)
+    midranks = np.empty((n,), dtype=np.float64)
+    i = 0
+    while i < n:
+        j = i
+        while j < n and x_sorted[j] == x_sorted[i]:
+            j += 1
+        # Ranks are 1..n; average over ties.
+        mid = 0.5 * ((i + 1) + j)
+        midranks[i:j] = mid
+        i = j
+    out = np.empty((n,), dtype=np.float64)
+    out[order] = midranks
+    return out
+
+
+def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
+    """Fast DeLong method for AUC covariance.
+
+    Args:
+        predictions_sorted_transposed: shape (n_classifiers, n_examples), where the
+            first label_1_count examples are positives.
+        label_1_count: number of positive examples.
+
+    Returns:
+        (aucs, delong_cov)
+    """
+    preds = np.asarray(predictions_sorted_transposed, dtype=np.float64)
+    if preds.ndim != 2:
+        raise ValueError("predictions_sorted_transposed must be 2D")
+    m = int(label_1_count)
+    n = int(preds.shape[1] - m)
+    if m <= 0 or n <= 0:
+        raise ValueError("DeLong requires at least 1 positive and 1 negative")
+    k = int(preds.shape[0])
+    tx = np.empty((k, m), dtype=np.float64)
+    ty = np.empty((k, n), dtype=np.float64)
+    tz = np.empty((k, m + n), dtype=np.float64)
+    for r in range(k):
+        tx[r] = compute_midrank(preds[r, :m])
+        ty[r] = compute_midrank(preds[r, m:])
+        tz[r] = compute_midrank(preds[r, :])
+    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)
+    v01 = (tz[:, :m] - tx) / float(n)
+    v10 = 1.0 - (tz[:, m:] - ty) / float(m)
+    # np.cov treats rows as variables when rowvar=True.
+    sx = np.cov(v01, rowvar=True, bias=False)
+    sy = np.cov(v10, rowvar=True, bias=False)
+    delong_cov = sx / float(m) + sy / float(n)
+    return aucs, delong_cov
+
+
+def compute_ground_truth_statistics(ground_truth: np.ndarray) -> Tuple[np.ndarray, int]:
+    """Return an ordering that places positives first, plus label_1_count."""
+    y = np.asarray(ground_truth, dtype=np.int32)
+    if y.ndim != 1:
+        raise ValueError("ground_truth must be 1D")
+    label_1_count = int(y.sum())
+    order = np.argsort(-y, kind="mergesort")
+    return order, label_1_count
+
+
+def get_auc_delong_var(healthy_scores: np.ndarray, diseased_scores: np.ndarray) -> Tuple[float, float]:
+    """Compute AUC and its DeLong variance.
+
+    Args:
+        healthy_scores: scores for controls (label=0)
+        diseased_scores: scores for cases (label=1)
+
+    Returns:
+        (auc, auc_variance)
+    """
+    h = np.asarray(healthy_scores, dtype=np.float64).reshape(-1)
+    d = np.asarray(diseased_scores, dtype=np.float64).reshape(-1)
+    n0 = int(h.size)
+    n1 = int(d.size)
+    if n0 == 0 or n1 == 0:
+        return float("nan"), float("nan")
+    # Arrange positives first, as required by fastDeLong.
+    scores = np.concatenate([d, h], axis=0)
+    gt = np.concatenate([
+        np.ones((n1,), dtype=np.int32),
+        np.zeros((n0,), dtype=np.int32),
+    ])
+    order, label_1_count = compute_ground_truth_statistics(gt)
+    preds_sorted = scores[order][None, :]
+    aucs, cov = fastDeLong(preds_sorted, label_1_count)
+    auc = float(aucs[0])
+    cov = np.asarray(cov)
+    var = float(cov[0, 0]) if cov.ndim == 2 else float(cov)
+    return auc, var
+
+
+# -------------------------
+# Next-token inference helper
+# -------------------------
+def predict_next_token_logits(
+    model: torch.nn.Module,
+    head: torch.nn.Module,
+    loader: DataLoader,
+    device: torch.device,
+    show_progress: bool = False,
+    progress_desc: str = "Inference (next-token)",
+    return_probs: bool = True,
+) -> np.ndarray:
+    """Predict per-cause next-token scores at baseline positions.
+
+    Returns:
+        np.ndarray of shape (N, K), where K is the number of diseases (causes).
+
+    Notes:
+        - For loss types with time/bin dimensions (e.g., discrete-time CIF), this uses
+          the *first* time/bin (index 0) and drops the complement channel when present.
+        - If return_probs=True, applies softmax over causes for probability-like scores.
+    """
+    model.eval()
+    head.eval()
+    all_out: List[np.ndarray] = []
+    with torch.no_grad():
+        for batch in _progress(
+            loader,
+            enabled=show_progress,
+            desc=progress_desc,
+            total=len(loader) if hasattr(loader, "__len__") else None,
+        ):
+            event_seq, time_seq, cont, cate, sex, baseline_pos = batch
+            event_seq = event_seq.to(device, non_blocking=True)
+            time_seq = time_seq.to(device, non_blocking=True)
+            cont = cont.to(device, non_blocking=True)
+            cate = cate.to(device, non_blocking=True)
+            sex = sex.to(device, non_blocking=True)
+            baseline_pos = baseline_pos.to(device, non_blocking=True)
+            h = model(event_seq, time_seq, sex, cont, cate)
+            b_idx = torch.arange(h.size(0), device=device)
+            c = h[b_idx, baseline_pos]
+            logits = head(c)
+            # logits can be (B, K), (B, K, T), or (B, K+1, T).
+            if logits.ndim == 2:
+                cause_logits = logits
+            elif logits.ndim == 3:
+                # Use the first time/bin.
+                cause_logits = logits[..., 0]
+            else:
+                raise ValueError(
+                    f"Unsupported logits shape for next-token inference: {tuple(logits.shape)}"
+                )
+            # If a complement/survival channel exists (discrete-time CIF), drop it.
+            if hasattr(model, "n_disease"):
+                n_disease = int(getattr(model, "n_disease"))
+                if cause_logits.size(1) == n_disease + 1:
+                    cause_logits = cause_logits[:, :n_disease]
+                elif cause_logits.size(1) > n_disease:
+                    cause_logits = cause_logits[:, :n_disease]
+            if return_probs:
+                scores = torch.softmax(cause_logits, dim=1)
+            else:
+                scores = cause_logits
+            all_out.append(scores.detach().cpu().numpy())
+    return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))
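
A usage sketch for the new DeLong helper (synthetic scores; the 95% interval applies the usual normal approximation to the DeLong variance):

    import numpy as np

    rng = np.random.default_rng(0)
    controls = rng.normal(0.0, 1.0, size=200)  # healthy_scores
    cases = rng.normal(1.0, 1.0, size=50)      # diseased_scores

    auc, var = get_auc_delong_var(controls, cases)
    se = float(np.sqrt(var))
    print(f"AUC={auc:.3f}, 95% CI=({auc - 1.96 * se:.3f}, {auc + 1.96 * se:.3f})")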