Refactor tqdm import handling and improve context sampling in utils.py

This commit is contained in:
2026-01-16 16:57:35 +08:00
parent e47a7ce4d6
commit b1647d1b74
2 changed files with 53 additions and 34 deletions

View File

@@ -8,9 +8,12 @@ import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
try:
from tqdm import tqdm
except Exception: # pragma: no cover
def tqdm(x, **kwargs):
return x
from utils import (
multi_hot_ever_within_horizon,
multi_hot_selected_causes_within_horizon,
@@ -281,7 +284,7 @@ class EvalAgeConfig:
cause_ids: Optional[Sequence[int]] = None
@torch.no_grad()
@torch.inference_mode()
def evaluate_time_dependent_age_bins(
model: torch.nn.Module,
head: torch.nn.Module,
@@ -338,12 +341,12 @@ def evaluate_time_dependent_age_bins(
list(cfg.cause_ids), dtype=torch.long, device=device)
n_causes_eval = int(cause_ids.numel())
# Storage: (mc, h, bin) -> list of arrays
y_true: List[List[List[List[np.ndarray]]]] = [
# Storage: (mc, h, bin) -> list of CPU tensors (avoid .numpy() in inner loops)
y_true: List[List[List[List[torch.Tensor]]]] = [
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
for _ in range(int(cfg.n_mc))
]
y_pred: List[List[List[List[np.ndarray]]]] = [
y_pred: List[List[List[List[torch.Tensor]]]] = [
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
for _ in range(int(cfg.n_mc))
]
@@ -365,7 +368,13 @@ def evaluate_time_dependent_age_bins(
B = int(event_seq.size(0))
b = torch.arange(B, device=device)
# Hoist backbone forward pass: inputs are identical across (tau, age_bin)
# within this batch, so this is safe and numerically identical.
h = model(event_seq, time_seq, sexes,
cont_feats, cate_feats) # (B,L,D)
for tau_idx, tau_y in enumerate(horizons_years):
tau_tensor = torch.tensor(float(tau_y), device=device)
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
seed = (
@@ -386,14 +395,13 @@ def evaluate_time_dependent_age_bins(
if not keep.any():
continue
# Strict bin-specific prediction: recompute representations and logits per (tau, bin).
h = model(event_seq, time_seq, sexes,
cont_feats, cate_feats) # (B,L,D)
# Bin-specific prediction: context indices differ per (tau, bin)
# but the backbone features do not.
c = h[b, t_ctx]
logits = head(c)
cifs = criterion.calculate_cifs(
logits, taus=torch.tensor(float(tau_y), device=device)
logits, taus=tau_tensor
)
if cifs.ndim != 2:
raise ValueError(
@@ -421,11 +429,15 @@ def evaluate_time_dependent_age_bins(
)
preds = cifs.index_select(dim=1, index=cause_ids)
# Reduce CPU/NumPy conversion overhead: keep as CPU torch tensors
# and convert to NumPy once during aggregation.
y_true[mc_idx][tau_idx][bin_idx].append(
y[keep].detach().to(torch.bool).cpu().numpy()
y[keep].detach().to(dtype=torch.bool,
device="cpu", non_blocking=True)
)
y_pred[mc_idx][tau_idx][bin_idx].append(
preds[keep].detach().to(torch.float32).cpu().numpy()
preds[keep].detach().to(dtype=torch.float32,
device="cpu", non_blocking=True)
)
rows_by_bin: List[Dict[str, float | int]] = []
@@ -460,13 +472,16 @@ def evaluate_time_dependent_age_bins(
)
continue
yb = np.concatenate(y_true[mc_idx][h_idx][bin_idx], axis=0)
pb = np.concatenate(y_pred[mc_idx][h_idx][bin_idx], axis=0)
if yb.shape != pb.shape:
yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0)
pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0)
if tuple(yb_t.shape) != tuple(pb_t.shape):
raise ValueError(
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
)
yb = yb_t.numpy()
pb = pb_t.numpy()
n_samples = int(yb.shape[0])
for cause_k in range(n_causes_eval):

View File

@@ -53,27 +53,31 @@ def sample_context_in_fixed_age_bin(
eligible = valid & in_bin & (
(time_seq + tau_days) <= followup_end_time.unsqueeze(1))
keep = torch.zeros((B,), dtype=torch.bool, device=device)
t_ctx = torch.zeros((B,), dtype=torch.long, device=device)
# Vectorized, uniform sampling over eligible indices per sample.
# Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform
# choice among eligible indices by symmetry (ties have probability ~0).
keep = eligible.any(dim=1)
gen = torch.Generator(device="cpu")
# Prefer a per-call generator on the target device for reproducibility without
# touching global RNG state. If unavailable, fall back to seeding the global
# CUDA RNG for this call.
gen = None
if device.type == "cuda":
try:
gen = torch.Generator(device=device)
gen.manual_seed(int(seed))
except Exception:
gen = None
torch.cuda.manual_seed(int(seed))
else:
gen = torch.Generator()
gen.manual_seed(int(seed))
for i in range(B):
m = eligible[i]
if not m.any():
continue
idxs = m.nonzero(as_tuple=False).view(-1).cpu()
chosen_idx_pos = int(
torch.randint(low=0, high=int(idxs.numel()),
size=(1,), generator=gen).item()
)
chosen_t = int(idxs[chosen_idx_pos].item())
keep[i] = True
t_ctx[i] = chosen_t
r = torch.rand((B, eligible.size(1)), device=device, generator=gen)
r = r.masked_fill(~eligible, -1.0)
t_ctx = r.argmax(dim=1).to(torch.long)
# When keep=False, t_ctx is arbitrary (argmax over all -1 yields 0).
return keep, t_ctx