Refactor next-event evaluation to use next-token scores and implement clean control AUC metrics
utils.py
@@ -273,6 +273,8 @@ class EvalRecord:
     cutoff_pos: int # baseline position (inclusive)
     next_event_cause: Optional[int]
     next_event_dt_years: Optional[float]
+    # (U,) unique causes ever observed (clean-control filtering)
+    lifetime_causes: np.ndarray
     future_causes: np.ndarray # (E,) in [0..K-1]
     future_dt_years: np.ndarray # (E,) strictly > 0
 
@@ -320,7 +322,19 @@ def build_event_driven_records(
     doa_days = float(times_ins[int(doa_pos[0])])
 
     is_disease = codes_ins >= N_TECH_TOKENS
     disease_times = times_ins[is_disease]
 
+    # Lifetime (ever) disease history for Clean Control filtering.
+    if np.any(is_disease):
+        lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
+            np.int64, copy=False
+        )
+        lifetime_causes = np.unique(lifetime_causes)
+    else:
+        lifetime_causes = np.zeros((0,), dtype=np.int64)
+
+    disease_pos_all = np.flatnonzero(is_disease)
+    disease_times_all = times_ins[disease_pos_all] if disease_pos_all.size > 0 else np.zeros(
+        (0,), dtype=np.float64)
     for b in range(len(age_bins_days) - 1):
         lo = age_bins_days[b]
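Note: a minimal sketch of how lifetime_causes feeds clean-control filtering downstream. The helper below is illustrative only and is not part of this commit; it assumes records is a list of EvalRecord.

    import numpy as np

    def clean_control_mask(records, cause: int) -> np.ndarray:
        # A record may serve as a control for `cause` only if the patient
        # never experienced that cause anywhere in their lifetime history.
        return np.array(
            [not np.isin(cause, r.lifetime_causes) for r in records],
            dtype=bool,
        )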
@@ -331,29 +345,22 @@ def build_event_driven_records(
         if not (doa_days <= hi):
             continue
 
-        # 2) at least one disease event within bin, and baseline must satisfy t0>=DOA
-        in_bin = (disease_times >= lo) & (
-            disease_times < hi) & (disease_times >= doa_days)
-        cand_times = disease_times[in_bin]
-        if cand_times.size == 0:
+        # 2) at least one disease event within bin, and baseline must satisfy t0>=DOA.
+        # Random Single-Point Sampling: choose exactly one valid event *index* per (patient, age_bin).
+        if disease_pos_all.size == 0:
             continue
 
-        t0_days = float(rng.choice(cand_times))
+        in_bin = (
+            (disease_times_all >= lo)
+            & (disease_times_all < hi)
+            & (disease_times_all >= doa_days)
+        )
+        cand_pos = disease_pos_all[in_bin]
+        if cand_pos.size == 0:
+            continue
 
-        # Baseline position (inclusive) in the *post-DOA-inserted* sequence.
-        pos = np.flatnonzero(is_disease & np.isclose(
-            times_ins, t0_days, rtol=0.0, atol=eps))
-        if pos.size == 0:
-            disease_pos = np.flatnonzero(is_disease)
-            if disease_pos.size == 0:
-                continue
-            disease_times_full = times_ins[disease_pos]
-            closest_idx = int(
-                np.argmin(np.abs(disease_times_full - t0_days)))
-            cutoff_pos = int(disease_pos[closest_idx])
-            t0_days = float(disease_times_full[closest_idx])
-        else:
-            cutoff_pos = int(pos[0])
+        cutoff_pos = int(rng.choice(cand_pos))
+        t0_days = float(times_ins[cutoff_pos])
 
         # Future disease events strictly after t0
         future_mask = (times_ins > (t0_days + eps)) & is_disease
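Note: the rewrite samples one event position per (patient, age_bin) instead of sampling a time and then re-locating its position with np.isclose, which removes the float-matching fallback path entirely. A minimal standalone sketch of the new scheme (all values illustrative):

    import numpy as np

    rng = np.random.default_rng(0)

    # One patient's sequence: event times in days; positions 3, 7, 12 are disease tokens.
    times_ins = np.array([0.0, 100.0, 900.0, 4015.0, 4100.0, 5000.0, 6000.0,
                          6200.5, 7000.0, 8000.0, 8500.0, 9000.0, 9132.0])
    disease_pos_all = np.array([3, 7, 12])
    disease_times_all = times_ins[disease_pos_all]
    lo, hi, doa_days = 3650.0, 7300.0, 3700.0  # one age bin, date of assessment

    in_bin = (
        (disease_times_all >= lo)
        & (disease_times_all < hi)
        & (disease_times_all >= doa_days)
    )
    cand_pos = disease_pos_all[in_bin]
    if cand_pos.size:
        cutoff_pos = int(rng.choice(cand_pos))   # exactly one baseline per (patient, age_bin)
        t0_days = float(times_ins[cutoff_pos])   # baseline time read back from the position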
@@ -366,7 +373,8 @@ def build_event_driven_records(
         else:
             future_times_days = times_ins[future_pos]
             future_tokens = codes_ins[future_pos]
-            future_causes = (future_tokens - N_TECH_TOKENS).astype(np.int64)
+            future_causes = (
+                future_tokens - N_TECH_TOKENS).astype(np.int64)
             future_dt_years_arr = (
                 (future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)
 
@@ -383,6 +391,7 @@ def build_event_driven_records(
             cutoff_pos=int(cutoff_pos),
             next_event_cause=next_cause,
             next_event_dt_years=next_dt_years,
+            lifetime_causes=lifetime_causes,
             future_causes=future_causes,
             future_dt_years=future_dt_years_arr,
         )
@@ -575,3 +584,191 @@ def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
     part_scores = np.take_along_axis(scores, part, axis=1)
     order = np.argsort(-part_scores, axis=1, kind="mergesort")
     return np.take_along_axis(part, order, axis=1)
+
+
+# -------------------------
+# Statistical evaluation (DeLong)
+# -------------------------
+
+def compute_midrank(x: np.ndarray) -> np.ndarray:
+    """Compute midranks of a 1D array (1-based ranks, tie-aware)."""
+    x = np.asarray(x, dtype=np.float64)
+    if x.ndim != 1:
+        raise ValueError("compute_midrank expects a 1D array")
+
+    order = np.argsort(x, kind="mergesort")
+    x_sorted = x[order]
+    n = int(x_sorted.size)
+
+    midranks = np.empty((n,), dtype=np.float64)
+    i = 0
+    while i < n:
+        j = i
+        while j < n and x_sorted[j] == x_sorted[i]:
+            j += 1
+        # ranks are 1..n; average over ties
+        mid = 0.5 * ((i + 1) + j)
+        midranks[i:j] = mid
+        i = j
+
+    out = np.empty((n,), dtype=np.float64)
+    out[order] = midranks
+    return out
+
+
+def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
+    """Fast DeLong method for AUC covariance.
+
+    Args:
+        predictions_sorted_transposed: shape (n_classifiers, n_examples), where the first
+            label_1_count examples are positives.
+        label_1_count: number of positive examples.
+    Returns:
+        (aucs, delong_cov)
+    """
+    preds = np.asarray(predictions_sorted_transposed, dtype=np.float64)
+    if preds.ndim != 2:
+        raise ValueError("predictions_sorted_transposed must be 2D")
+
+    m = int(label_1_count)
+    n = int(preds.shape[1] - m)
+    if m <= 0 or n <= 0:
+        raise ValueError("DeLong requires at least 1 positive and 1 negative")
+
+    k = int(preds.shape[0])
+    tx = np.empty((k, m), dtype=np.float64)
+    ty = np.empty((k, n), dtype=np.float64)
+    tz = np.empty((k, m + n), dtype=np.float64)
+
+    for r in range(k):
+        tx[r] = compute_midrank(preds[r, :m])
+        ty[r] = compute_midrank(preds[r, m:])
+        tz[r] = compute_midrank(preds[r, :])
+
+    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)
+
+    v01 = (tz[:, :m] - tx) / float(n)
+    v10 = 1.0 - (tz[:, m:] - ty) / float(m)
+
+    # np.cov treats each row as a variable (rowvar=True is the default).
+    sx = np.cov(v01, rowvar=True, bias=False)
+    sy = np.cov(v10, rowvar=True, bias=False)
+    delong_cov = sx / float(m) + sy / float(n)
+    return aucs, delong_cov
+
+
+def compute_ground_truth_statistics(ground_truth: np.ndarray) -> Tuple[np.ndarray, int]:
+    """Return ordering that places positives first and label_1_count."""
+    y = np.asarray(ground_truth, dtype=np.int32)
+    if y.ndim != 1:
+        raise ValueError("ground_truth must be 1D")
+    label_1_count = int(y.sum())
+    order = np.argsort(-y, kind="mergesort")
+    return order, label_1_count
+
+
+def get_auc_delong_var(healthy_scores: np.ndarray, diseased_scores: np.ndarray) -> Tuple[float, float]:
+    """Compute AUC and its DeLong variance.
+
+    Args:
+        healthy_scores: scores for controls (label=0)
+        diseased_scores: scores for cases (label=1)
+    Returns:
+        (auc, auc_variance)
+    """
+    h = np.asarray(healthy_scores, dtype=np.float64).reshape(-1)
+    d = np.asarray(diseased_scores, dtype=np.float64).reshape(-1)
+    n0 = int(h.size)
+    n1 = int(d.size)
+    if n0 == 0 or n1 == 0:
+        return float("nan"), float("nan")
+
+    # Arrange positives first as required by fastDeLong.
+    scores = np.concatenate([d, h], axis=0)
+    gt = np.concatenate([
+        np.ones((n1,), dtype=np.int32),
+        np.zeros((n0,), dtype=np.int32),
+    ])
+    order, label_1_count = compute_ground_truth_statistics(gt)
+    preds_sorted = scores[order][None, :]
+    aucs, cov = fastDeLong(preds_sorted, label_1_count)
+    auc = float(aucs[0])
+    cov = np.asarray(cov)
+    var = float(cov[0, 0]) if cov.ndim == 2 else float(cov)
+    return auc, var
+
+
+# -------------------------
+# Next-token inference helper
+# -------------------------
+
+def predict_next_token_logits(
+    model: torch.nn.Module,
+    head: torch.nn.Module,
+    loader: DataLoader,
+    device: torch.device,
+    show_progress: bool = False,
+    progress_desc: str = "Inference (next-token)",
+    return_probs: bool = True,
+) -> np.ndarray:
+    """Predict per-cause next-token scores at baseline positions.
+
+    Returns:
+        np.ndarray of shape (N, K) where K is number of diseases (causes).
+
+    Notes:
+        - For loss types with time/bin dimensions (e.g., discrete-time CIF), this uses the
+          *first* time/bin (index 0) and drops the complement channel when present.
+        - If return_probs=True, applies softmax over causes for probability-like scores.
+    """
+    model.eval()
+    head.eval()
+
+    all_out: List[np.ndarray] = []
+    with torch.no_grad():
+        for batch in _progress(
+            loader,
+            enabled=show_progress,
+            desc=progress_desc,
+            total=len(loader) if hasattr(loader, "__len__") else None,
+        ):
+            event_seq, time_seq, cont, cate, sex, baseline_pos = batch
+            event_seq = event_seq.to(device, non_blocking=True)
+            time_seq = time_seq.to(device, non_blocking=True)
+            cont = cont.to(device, non_blocking=True)
+            cate = cate.to(device, non_blocking=True)
+            sex = sex.to(device, non_blocking=True)
+            baseline_pos = baseline_pos.to(device, non_blocking=True)
+
+            h = model(event_seq, time_seq, sex, cont, cate)
+            b_idx = torch.arange(h.size(0), device=device)
+            c = h[b_idx, baseline_pos]
+            logits = head(c)
+
+            # logits can be (B, K) or (B, K, T) or (B, K+1, T)
+            if logits.ndim == 2:
+                cause_logits = logits
+            elif logits.ndim == 3:
+                # Use the first time/bin.
+                cause_logits = logits[..., 0]
+            else:
+                raise ValueError(
+                    f"Unsupported logits shape for next-token inference: {tuple(logits.shape)}"
+                )
+
+            # If a complement/survival channel exists (discrete-time CIF), drop it.
+            if hasattr(model, "n_disease"):
+                n_disease = int(getattr(model, "n_disease"))
+                if cause_logits.size(1) == n_disease + 1:
+                    cause_logits = cause_logits[:, :n_disease]
+                elif cause_logits.size(1) > n_disease:
+                    cause_logits = cause_logits[:, :n_disease]
+
+            if return_probs:
+                scores = torch.softmax(cause_logits, dim=1)
+            else:
+                scores = cause_logits
+
+            all_out.append(scores.detach().cpu().numpy())
+
+    return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))
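Note: a sketch of how the new pieces compose into a per-cause clean-control AUC with a DeLong 95% confidence interval. The wiring below (clean_control_auc, records, scores) is illustrative and not part of this commit; only predict_next_token_logits and get_auc_delong_var are defined above.

    import numpy as np

    # scores: (N, K) matrix from predict_next_token_logits(model, head, loader, device)
    # records: list of EvalRecord aligned with the N rows of scores.
    def clean_control_auc(scores: np.ndarray, records, cause: int):
        case = np.array([np.isin(cause, r.future_causes) for r in records], dtype=bool)
        ever = np.array([np.isin(cause, r.lifetime_causes) for r in records], dtype=bool)
        cases = scores[case, cause]
        controls = scores[~case & ~ever, cause]  # clean controls: never had the cause
        auc, var = get_auc_delong_var(controls, cases)
        half = 1.96 * np.sqrt(var)  # normal-approximation 95% CI
        return auc, (auc - half, auc + half)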