Refactor tqdm import handling and improve context sampling in utils.py
@@ -8,9 +8,12 @@ import numpy as np
 import pandas as pd
 import torch
-from tqdm import tqdm
+
+try:
+    from tqdm import tqdm
+except Exception:  # pragma: no cover
+
+    def tqdm(x, **kwargs):
+        return x
 from utils import (
     multi_hot_ever_within_horizon,
     multi_hot_selected_causes_within_horizon,
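The guarded import keeps the module usable when tqdm is not installed: the fallback shim accepts and ignores tqdm's keyword arguments and returns the iterable unchanged. A minimal standalone sketch of the same pattern (the loop is a hypothetical example):

```python
# Optional-dependency pattern: fall back to a no-op progress wrapper.
try:
    from tqdm import tqdm
except Exception:  # pragma: no cover
    def tqdm(x, **kwargs):
        # Accept and ignore tqdm kwargs (desc=, total=, ...); iterate as-is.
        return x

# Works identically with or without tqdm installed:
for batch in tqdm(range(3), desc="batches"):
    pass
```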
@@ -281,7 +284,7 @@ class EvalAgeConfig:
     cause_ids: Optional[Sequence[int]] = None
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def evaluate_time_dependent_age_bins(
     model: torch.nn.Module,
     head: torch.nn.Module,
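`torch.inference_mode()` is a stricter, slightly faster variant of `torch.no_grad()`: it additionally skips autograd's view and version-counter tracking, at the cost that tensors created under it cannot later be used in autograd. A small illustration of the difference (stock PyTorch only):

```python
import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    a = x * 2          # grad disabled, but `a` is an ordinary tensor

with torch.inference_mode():
    b = x * 2          # `b` is an inference tensor

print(a.requires_grad, b.requires_grad)  # False False
# Inference tensors cannot re-enter autograd later:
# (x + b).sum().backward() would raise a RuntimeError.
```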
@@ -338,12 +341,12 @@ def evaluate_time_dependent_age_bins(
         list(cfg.cause_ids), dtype=torch.long, device=device)
     n_causes_eval = int(cause_ids.numel())
 
-    # Storage: (mc, h, bin) -> list of arrays
-    y_true: List[List[List[List[np.ndarray]]]] = [
+    # Storage: (mc, h, bin) -> list of CPU tensors (avoid .numpy() in inner loops)
+    y_true: List[List[List[List[torch.Tensor]]]] = [
         [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
         for _ in range(int(cfg.n_mc))
     ]
-    y_pred: List[List[List[List[np.ndarray]]]] = [
+    y_pred: List[List[List[List[torch.Tensor]]]] = [
         [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
         for _ in range(int(cfg.n_mc))
     ]
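The nested comprehensions in the unchanged context lines matter: each `[]` bucket must be a distinct object so appends stay independent. A quick illustration of the aliasing bug that `*` replication would introduce (standalone sketch, not from the commit):

```python
# Distinct buckets per (horizon, bin): safe to append independently.
good = [[[] for _ in range(3)] for _ in range(2)]
good[0][0].append("x")
print(good[1][0])  # [] -- untouched

# Replication aliases one inner list everywhere.
bad = [[[]] * 3] * 2
bad[0][0].append("x")
print(bad[1][2])  # ['x'] -- every bucket shares the same list
```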
@@ -365,7 +368,13 @@ def evaluate_time_dependent_age_bins(
         B = int(event_seq.size(0))
         b = torch.arange(B, device=device)
 
+        # Hoist backbone forward pass: inputs are identical across (tau, age_bin)
+        # within this batch, so this is safe and numerically identical.
+        h = model(event_seq, time_seq, sexes,
+                  cont_feats, cate_feats)  # (B,L,D)
+
         for tau_idx, tau_y in enumerate(horizons_years):
+            tau_tensor = torch.tensor(float(tau_y), device=device)
             for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                 # Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
                 seed = (
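The hoist is valid because nothing inside the `(tau, bin)` loops feeds back into the backbone inputs; only the context index and horizon change per iteration. The general shape of the optimization, as a sketch with hypothetical stand-ins for `model` and `head`:

```python
import torch

backbone = torch.nn.Linear(8, 16)   # stands in for model(...)
head = torch.nn.Linear(16, 4)       # stands in for head(...)
x = torch.randn(32, 10, 8)          # (B, L, D_in), hypothetical batch

# Compute loop-invariant features once...
h = backbone(x)                     # (B, L, 16)

for tau in (1.0, 5.0):
    for t_ctx in (2, 7):            # per-(tau, bin) context positions
        c = h[:, t_ctx]             # ...and only index into them per iteration
        logits = head(c)            # (B, 4)
```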
@@ -386,14 +395,13 @@ def evaluate_time_dependent_age_bins(
                 if not keep.any():
                     continue
 
-                # Strict bin-specific prediction: recompute representations and logits per (tau, bin).
-                h = model(event_seq, time_seq, sexes,
-                          cont_feats, cate_feats)  # (B,L,D)
+                # Bin-specific prediction: context indices differ per (tau, bin)
+                # but the backbone features do not.
                 c = h[b, t_ctx]
                 logits = head(c)
 
                 cifs = criterion.calculate_cifs(
-                    logits, taus=torch.tensor(float(tau_y), device=device)
+                    logits, taus=tau_tensor
                 )
                 if cifs.ndim != 2:
                     raise ValueError(
@@ -421,11 +429,15 @@ def evaluate_time_dependent_age_bins(
                     )
                 preds = cifs.index_select(dim=1, index=cause_ids)
 
+                # Reduce CPU/NumPy conversion overhead: keep as CPU torch tensors
+                # and convert to NumPy once during aggregation.
                 y_true[mc_idx][tau_idx][bin_idx].append(
-                    y[keep].detach().to(torch.bool).cpu().numpy()
+                    y[keep].detach().to(dtype=torch.bool,
+                                        device="cpu", non_blocking=True)
                 )
                 y_pred[mc_idx][tau_idx][bin_idx].append(
-                    preds[keep].detach().to(torch.float32).cpu().numpy()
+                    preds[keep].detach().to(dtype=torch.float32,
+                                            device="cpu", non_blocking=True)
                 )
 
     rows_by_bin: List[Dict[str, float | int]] = []
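One caveat worth flagging with `non_blocking=True` on device-to-host copies: the copy is only truly asynchronous when pinned memory is involved, and an asynchronous copy must be synchronized before the destination bytes are read. A hedged sketch of the safe pattern (CUDA-only path, hypothetical tensor):

```python
import torch

if torch.cuda.is_available():
    preds_gpu = torch.rand(1024, 8, device="cuda")  # hypothetical batch output

    # Queue the device-to-host copy without blocking the host thread.
    preds_cpu = preds_gpu.detach().to(
        dtype=torch.float32, device="cpu", non_blocking=True
    )

    # Before reading the CPU tensor (e.g. via .numpy()), ensure the copy landed.
    torch.cuda.synchronize()
    arr = preds_cpu.numpy()
```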
@@ -460,13 +472,16 @@ def evaluate_time_dependent_age_bins(
                 )
                 continue
 
-            yb = np.concatenate(y_true[mc_idx][h_idx][bin_idx], axis=0)
-            pb = np.concatenate(y_pred[mc_idx][h_idx][bin_idx], axis=0)
-            if yb.shape != pb.shape:
+            yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0)
+            pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0)
+            if tuple(yb_t.shape) != tuple(pb_t.shape):
                 raise ValueError(
-                    f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
+                    f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
                 )
+
+            yb = yb_t.numpy()
+            pb = pb_t.numpy()
 
             n_samples = int(yb.shape[0])
 
             for cause_k in range(n_causes_eval):
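Deferring the conversion is cheap because `.numpy()` on a CPU tensor is zero-copy: the returned array is a view over the tensor's buffer, so the only real work in aggregation is the single `torch.cat`. A small demonstration:

```python
import torch

chunks = [torch.zeros(2, 3), torch.ones(2, 3)]  # per-batch CPU tensors
t = torch.cat(chunks, dim=0)                    # one allocation + copy
a = t.numpy()                                   # zero-copy view, shares memory

t[0, 0] = 7.0
print(a[0, 0])  # 7.0 -- same underlying buffer
```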
utils.py
@@ -53,27 +53,31 @@ def sample_context_in_fixed_age_bin(
     eligible = valid & in_bin & (
         (time_seq + tau_days) <= followup_end_time.unsqueeze(1))
 
-    keep = torch.zeros((B,), dtype=torch.bool, device=device)
-    t_ctx = torch.zeros((B,), dtype=torch.long, device=device)
+    # Vectorized, uniform sampling over eligible indices per sample.
+    # Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform
+    # choice among eligible indices by symmetry (ties have probability ~0).
+    keep = eligible.any(dim=1)
 
-    gen = torch.Generator(device="cpu")
-    gen.manual_seed(int(seed))
+    # Prefer a per-call generator on the target device for reproducibility without
+    # touching global RNG state. If unavailable, fall back to seeding the global
+    # CUDA RNG for this call.
+    gen = None
+    if device.type == "cuda":
+        try:
+            gen = torch.Generator(device=device)
+            gen.manual_seed(int(seed))
+        except Exception:
+            gen = None
+            torch.cuda.manual_seed(int(seed))
+    else:
+        gen = torch.Generator()
+        gen.manual_seed(int(seed))
 
-    for i in range(B):
-        m = eligible[i]
-        if not m.any():
-            continue
-
-        idxs = m.nonzero(as_tuple=False).view(-1).cpu()
-        chosen_idx_pos = int(
-            torch.randint(low=0, high=int(idxs.numel()),
-                          size=(1,), generator=gen).item()
-        )
-        chosen_t = int(idxs[chosen_idx_pos].item())
-
-        keep[i] = True
-        t_ctx[i] = chosen_t
+    r = torch.rand((B, eligible.size(1)), device=device, generator=gen)
+    r = r.masked_fill(~eligible, -1.0)
+    t_ctx = r.argmax(dim=1).to(torch.long)
+
+    # When keep=False, t_ctx is arbitrary (argmax over all -1 yields 0).
 
     return keep, t_ctx
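The argmax-of-uniform trick replaces the per-sample Python loop with one batched draw: among the positions where `eligible` is true, each is equally likely to hold the row maximum of i.i.d. Uniform(0,1) values, so the argmax is a uniform pick; ineligible positions are filled with -1.0 so they can never win. A standalone sketch with a quick empirical check (the mask is a hypothetical example):

```python
import torch

gen = torch.Generator().manual_seed(0)
eligible = torch.tensor([[False, True, True, False, True]])  # one row, 3 eligible

counts = torch.zeros(5)
for _ in range(30_000):
    r = torch.rand(eligible.shape, generator=gen)
    r = r.masked_fill(~eligible, -1.0)   # masked positions can never be the max
    counts[r.argmax(dim=1)] += 1

print(counts / counts.sum())  # ~[0, 1/3, 1/3, 0, 1/3]
```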