From b1647d1b747399759dcaf6424d63f3ad685fef25 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Fri, 16 Jan 2026 16:57:35 +0800 Subject: [PATCH] Refactor tqdm import handling and improve context sampling in utils.py --- evaluation_age_time_dependent.py | 47 +++++++++++++++++++++----------- utils.py | 40 +++++++++++++++------------ 2 files changed, 53 insertions(+), 34 deletions(-) diff --git a/evaluation_age_time_dependent.py b/evaluation_age_time_dependent.py index d3f8c3e..f788b2a 100644 --- a/evaluation_age_time_dependent.py +++ b/evaluation_age_time_dependent.py @@ -8,9 +8,12 @@ import numpy as np import pandas as pd import torch -from tqdm import tqdm - +try: + from tqdm import tqdm +except Exception: # pragma: no cover + def tqdm(x, **kwargs): + return x from utils import ( multi_hot_ever_within_horizon, multi_hot_selected_causes_within_horizon, @@ -281,7 +284,7 @@ class EvalAgeConfig: cause_ids: Optional[Sequence[int]] = None -@torch.no_grad() +@torch.inference_mode() def evaluate_time_dependent_age_bins( model: torch.nn.Module, head: torch.nn.Module, @@ -338,12 +341,12 @@ def evaluate_time_dependent_age_bins( list(cfg.cause_ids), dtype=torch.long, device=device) n_causes_eval = int(cause_ids.numel()) - # Storage: (mc, h, bin) -> list of arrays - y_true: List[List[List[List[np.ndarray]]]] = [ + # Storage: (mc, h, bin) -> list of CPU tensors (avoid .numpy() in inner loops) + y_true: List[List[List[List[torch.Tensor]]]] = [ [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))] for _ in range(int(cfg.n_mc)) ] - y_pred: List[List[List[List[np.ndarray]]]] = [ + y_pred: List[List[List[List[torch.Tensor]]]] = [ [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))] for _ in range(int(cfg.n_mc)) ] @@ -365,7 +368,13 @@ def evaluate_time_dependent_age_bins( B = int(event_seq.size(0)) b = torch.arange(B, device=device) + # Hoist backbone forward pass: inputs are identical across (tau, age_bin) + # within this batch, so this is safe and numerically identical. + h = model(event_seq, time_seq, sexes, + cont_feats, cate_feats) # (B,L,D) + for tau_idx, tau_y in enumerate(horizons_years): + tau_tensor = torch.tensor(float(tau_y), device=device) for bin_idx, (a_lo, a_hi) in enumerate(age_bins): # Diversify RNG stream across MC/tau/bin/batch to reduce correlation. seed = ( @@ -386,14 +395,13 @@ def evaluate_time_dependent_age_bins( if not keep.any(): continue - # Strict bin-specific prediction: recompute representations and logits per (tau, bin). - h = model(event_seq, time_seq, sexes, - cont_feats, cate_feats) # (B,L,D) + # Bin-specific prediction: context indices differ per (tau, bin) + # but the backbone features do not. c = h[b, t_ctx] logits = head(c) cifs = criterion.calculate_cifs( - logits, taus=torch.tensor(float(tau_y), device=device) + logits, taus=tau_tensor ) if cifs.ndim != 2: raise ValueError( @@ -421,11 +429,15 @@ def evaluate_time_dependent_age_bins( ) preds = cifs.index_select(dim=1, index=cause_ids) + # Reduce CPU/NumPy conversion overhead: keep as CPU torch tensors + # and convert to NumPy once during aggregation. y_true[mc_idx][tau_idx][bin_idx].append( - y[keep].detach().to(torch.bool).cpu().numpy() + y[keep].detach().to(dtype=torch.bool, + device="cpu", non_blocking=True) ) y_pred[mc_idx][tau_idx][bin_idx].append( - preds[keep].detach().to(torch.float32).cpu().numpy() + preds[keep].detach().to(dtype=torch.float32, + device="cpu", non_blocking=True) ) rows_by_bin: List[Dict[str, float | int]] = [] @@ -460,13 +472,16 @@ def evaluate_time_dependent_age_bins( ) continue - yb = np.concatenate(y_true[mc_idx][h_idx][bin_idx], axis=0) - pb = np.concatenate(y_pred[mc_idx][h_idx][bin_idx], axis=0) - if yb.shape != pb.shape: + yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0) + pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0) + if tuple(yb_t.shape) != tuple(pb_t.shape): raise ValueError( - f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}" + f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}" ) + yb = yb_t.numpy() + pb = pb_t.numpy() + n_samples = int(yb.shape[0]) for cause_k in range(n_causes_eval): diff --git a/utils.py b/utils.py index 0aded3b..3631574 100644 --- a/utils.py +++ b/utils.py @@ -53,27 +53,31 @@ def sample_context_in_fixed_age_bin( eligible = valid & in_bin & ( (time_seq + tau_days) <= followup_end_time.unsqueeze(1)) - keep = torch.zeros((B,), dtype=torch.bool, device=device) - t_ctx = torch.zeros((B,), dtype=torch.long, device=device) + # Vectorized, uniform sampling over eligible indices per sample. + # Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform + # choice among eligible indices by symmetry (ties have probability ~0). + keep = eligible.any(dim=1) - gen = torch.Generator(device="cpu") - gen.manual_seed(int(seed)) + # Prefer a per-call generator on the target device for reproducibility without + # touching global RNG state. If unavailable, fall back to seeding the global + # CUDA RNG for this call. + gen = None + if device.type == "cuda": + try: + gen = torch.Generator(device=device) + gen.manual_seed(int(seed)) + except Exception: + gen = None + torch.cuda.manual_seed(int(seed)) + else: + gen = torch.Generator() + gen.manual_seed(int(seed)) - for i in range(B): - m = eligible[i] - if not m.any(): - continue - - idxs = m.nonzero(as_tuple=False).view(-1).cpu() - chosen_idx_pos = int( - torch.randint(low=0, high=int(idxs.numel()), - size=(1,), generator=gen).item() - ) - chosen_t = int(idxs[chosen_idx_pos].item()) - - keep[i] = True - t_ctx[i] = chosen_t + r = torch.rand((B, eligible.size(1)), device=device, generator=gen) + r = r.masked_fill(~eligible, -1.0) + t_ctx = r.argmax(dim=1).to(torch.long) + # When keep=False, t_ctx is arbitrary (argmax over all -1 yields 0). return keep, t_ctx