Refactor tqdm import handling and improve context sampling in utils.py
This commit is contained in:
@@ -8,9 +8,12 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
|
||||||
|
def tqdm(x, **kwargs):
|
||||||
|
return x
|
||||||
from utils import (
|
from utils import (
|
||||||
multi_hot_ever_within_horizon,
|
multi_hot_ever_within_horizon,
|
||||||
multi_hot_selected_causes_within_horizon,
|
multi_hot_selected_causes_within_horizon,
|
||||||
@@ -281,7 +284,7 @@ class EvalAgeConfig:
|
|||||||
cause_ids: Optional[Sequence[int]] = None
|
cause_ids: Optional[Sequence[int]] = None
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.inference_mode()
|
||||||
def evaluate_time_dependent_age_bins(
|
def evaluate_time_dependent_age_bins(
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
head: torch.nn.Module,
|
head: torch.nn.Module,
|
||||||
@@ -338,12 +341,12 @@ def evaluate_time_dependent_age_bins(
|
|||||||
list(cfg.cause_ids), dtype=torch.long, device=device)
|
list(cfg.cause_ids), dtype=torch.long, device=device)
|
||||||
n_causes_eval = int(cause_ids.numel())
|
n_causes_eval = int(cause_ids.numel())
|
||||||
|
|
||||||
# Storage: (mc, h, bin) -> list of arrays
|
# Storage: (mc, horizon, bin) -> list of CPU tensors (avoid .numpy() in inner loops)
|
||||||
y_true: List[List[List[List[np.ndarray]]]] = [
|
y_true: List[List[List[List[torch.Tensor]]]] = [
|
||||||
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
|
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
|
||||||
for _ in range(int(cfg.n_mc))
|
for _ in range(int(cfg.n_mc))
|
||||||
]
|
]
|
||||||
y_pred: List[List[List[List[np.ndarray]]]] = [
|
y_pred: List[List[List[List[torch.Tensor]]]] = [
|
||||||
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
|
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
|
||||||
for _ in range(int(cfg.n_mc))
|
for _ in range(int(cfg.n_mc))
|
||||||
]
|
]
|
||||||
@@ -365,7 +368,13 @@ def evaluate_time_dependent_age_bins(
|
|||||||
B = int(event_seq.size(0))
|
B = int(event_seq.size(0))
|
||||||
b = torch.arange(B, device=device)
|
b = torch.arange(B, device=device)
|
||||||
|
|
||||||
|
# Hoist backbone forward pass: inputs are identical across (tau, age_bin)
|
||||||
|
# within this batch, so this is safe and numerically identical provided the forward pass is deterministic (e.g. model in eval mode, no active dropout).
|
||||||
|
h = model(event_seq, time_seq, sexes,
|
||||||
|
cont_feats, cate_feats) # (B,L,D)
|
||||||
|
|
||||||
for tau_idx, tau_y in enumerate(horizons_years):
|
for tau_idx, tau_y in enumerate(horizons_years):
|
||||||
|
tau_tensor = torch.tensor(float(tau_y), device=device)
|
||||||
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
|
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
|
||||||
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
|
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
|
||||||
seed = (
|
seed = (
|
||||||
@@ -386,14 +395,13 @@ def evaluate_time_dependent_age_bins(
|
|||||||
if not keep.any():
|
if not keep.any():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Strict bin-specific prediction: recompute representations and logits per (tau, bin).
|
# Bin-specific prediction: context indices differ per (tau, bin)
|
||||||
h = model(event_seq, time_seq, sexes,
|
# but the backbone features do not.
|
||||||
cont_feats, cate_feats) # (B,L,D)
|
|
||||||
c = h[b, t_ctx]
|
c = h[b, t_ctx]
|
||||||
logits = head(c)
|
logits = head(c)
|
||||||
|
|
||||||
cifs = criterion.calculate_cifs(
|
cifs = criterion.calculate_cifs(
|
||||||
logits, taus=torch.tensor(float(tau_y), device=device)
|
logits, taus=tau_tensor
|
||||||
)
|
)
|
||||||
if cifs.ndim != 2:
|
if cifs.ndim != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -421,11 +429,15 @@ def evaluate_time_dependent_age_bins(
|
|||||||
)
|
)
|
||||||
preds = cifs.index_select(dim=1, index=cause_ids)
|
preds = cifs.index_select(dim=1, index=cause_ids)
|
||||||
|
|
||||||
|
# Reduce CPU/NumPy conversion overhead: keep as CPU torch tensors
|
||||||
|
# and convert to NumPy once during aggregation.
|
||||||
y_true[mc_idx][tau_idx][bin_idx].append(
|
y_true[mc_idx][tau_idx][bin_idx].append(
|
||||||
y[keep].detach().to(torch.bool).cpu().numpy()
|
y[keep].detach().to(dtype=torch.bool,
|
||||||
|
device="cpu", non_blocking=True)
|
||||||
)
|
)
|
||||||
y_pred[mc_idx][tau_idx][bin_idx].append(
|
y_pred[mc_idx][tau_idx][bin_idx].append(
|
||||||
preds[keep].detach().to(torch.float32).cpu().numpy()
|
preds[keep].detach().to(dtype=torch.float32,
|
||||||
|
device="cpu", non_blocking=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
rows_by_bin: List[Dict[str, float | int]] = []
|
rows_by_bin: List[Dict[str, float | int]] = []
|
||||||
@@ -460,13 +472,16 @@ def evaluate_time_dependent_age_bins(
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yb = np.concatenate(y_true[mc_idx][h_idx][bin_idx], axis=0)
|
yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0)
|
||||||
pb = np.concatenate(y_pred[mc_idx][h_idx][bin_idx], axis=0)
|
pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0)
|
||||||
if yb.shape != pb.shape:
|
if tuple(yb_t.shape) != tuple(pb_t.shape):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
|
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yb = yb_t.numpy()
|
||||||
|
pb = pb_t.numpy()
|
||||||
|
|
||||||
n_samples = int(yb.shape[0])
|
n_samples = int(yb.shape[0])
|
||||||
|
|
||||||
for cause_k in range(n_causes_eval):
|
for cause_k in range(n_causes_eval):
|
||||||
|
|||||||
38
utils.py
38
utils.py
@@ -53,27 +53,31 @@ def sample_context_in_fixed_age_bin(
|
|||||||
eligible = valid & in_bin & (
|
eligible = valid & in_bin & (
|
||||||
(time_seq + tau_days) <= followup_end_time.unsqueeze(1))
|
(time_seq + tau_days) <= followup_end_time.unsqueeze(1))
|
||||||
|
|
||||||
keep = torch.zeros((B,), dtype=torch.bool, device=device)
|
# Vectorized, uniform sampling over eligible indices per sample.
|
||||||
t_ctx = torch.zeros((B,), dtype=torch.long, device=device)
|
# Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform
|
||||||
|
# choice among eligible indices by symmetry (ties have probability ~0).
|
||||||
|
keep = eligible.any(dim=1)
|
||||||
|
|
||||||
gen = torch.Generator(device="cpu")
|
# Prefer a per-call generator on the target device for reproducibility without
|
||||||
|
# touching global RNG state. If unavailable, fall back to seeding the global
|
||||||
|
# CUDA RNG for this call.
|
||||||
|
gen = None
|
||||||
|
if device.type == "cuda":
|
||||||
|
try:
|
||||||
|
gen = torch.Generator(device=device)
|
||||||
|
gen.manual_seed(int(seed))
|
||||||
|
except Exception:
|
||||||
|
gen = None
|
||||||
|
torch.cuda.manual_seed(int(seed))
|
||||||
|
else:
|
||||||
|
gen = torch.Generator()
|
||||||
gen.manual_seed(int(seed))
|
gen.manual_seed(int(seed))
|
||||||
|
|
||||||
for i in range(B):
|
r = torch.rand((B, eligible.size(1)), device=device, generator=gen)
|
||||||
m = eligible[i]
|
r = r.masked_fill(~eligible, -1.0)
|
||||||
if not m.any():
|
t_ctx = r.argmax(dim=1).to(torch.long)
|
||||||
continue
|
|
||||||
|
|
||||||
idxs = m.nonzero(as_tuple=False).view(-1).cpu()
|
|
||||||
chosen_idx_pos = int(
|
|
||||||
torch.randint(low=0, high=int(idxs.numel()),
|
|
||||||
size=(1,), generator=gen).item()
|
|
||||||
)
|
|
||||||
chosen_t = int(idxs[chosen_idx_pos].item())
|
|
||||||
|
|
||||||
keep[i] = True
|
|
||||||
t_ctx[i] = chosen_t
|
|
||||||
|
|
||||||
|
# Where keep[i] is False, t_ctx[i] is only a placeholder (argmax over an all -1 row returns 0); callers must mask with keep.
|
||||||
return keep, t_ctx
|
return keep, t_ctx
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user