Refactor tqdm import handling and improve context sampling in utils.py

This commit is contained in:
2026-01-16 16:57:35 +08:00
parent e47a7ce4d6
commit b1647d1b74
2 changed files with 53 additions and 34 deletions

View File

@@ -8,9 +8,12 @@ import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
try:
    from tqdm import tqdm
except ImportError:  # pragma: no cover
    # tqdm is an optional dependency. Only a missing or unimportable package
    # selects the fallback; any other error raised during import propagates
    # instead of being silently replaced by the no-op below.
    def tqdm(x, **kwargs):
        """Return ``x`` unchanged as a no-op stand-in for ``tqdm.tqdm``.

        Args:
            x: The iterable that would have been wrapped in a progress bar.
            **kwargs: Progress-bar options; accepted and ignored.

        Returns:
            ``x`` itself, so ``for item in tqdm(x, ...)`` iterates as usual.
        """
        return x
from utils import (
multi_hot_ever_within_horizon,
multi_hot_selected_causes_within_horizon,
@@ -281,7 +284,7 @@ class EvalAgeConfig:
cause_ids: Optional[Sequence[int]] = None
@torch.no_grad()
@torch.inference_mode()
def evaluate_time_dependent_age_bins(
model: torch.nn.Module,
head: torch.nn.Module,
@@ -338,12 +341,12 @@ def evaluate_time_dependent_age_bins(
list(cfg.cause_ids), dtype=torch.long, device=device)
n_causes_eval = int(cause_ids.numel())
# Storage: (mc, h, bin) -> list of arrays
y_true: List[List[List[List[np.ndarray]]]] = [
# Storage: (mc, h, bin) -> list of CPU tensors (avoid .numpy() in inner loops)
y_true: List[List[List[List[torch.Tensor]]]] = [
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
for _ in range(int(cfg.n_mc))
]
y_pred: List[List[List[List[np.ndarray]]]] = [
y_pred: List[List[List[List[torch.Tensor]]]] = [
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
for _ in range(int(cfg.n_mc))
]
@@ -365,7 +368,13 @@ def evaluate_time_dependent_age_bins(
B = int(event_seq.size(0))
b = torch.arange(B, device=device)
# Hoist backbone forward pass: inputs are identical across (tau, age_bin)
# within this batch, so this is safe and numerically identical.
h = model(event_seq, time_seq, sexes,
cont_feats, cate_feats) # (B,L,D)
for tau_idx, tau_y in enumerate(horizons_years):
tau_tensor = torch.tensor(float(tau_y), device=device)
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
seed = (
@@ -386,14 +395,13 @@ def evaluate_time_dependent_age_bins(
if not keep.any():
continue
# Strict bin-specific prediction: recompute representations and logits per (tau, bin).
h = model(event_seq, time_seq, sexes,
cont_feats, cate_feats) # (B,L,D)
# Bin-specific prediction: context indices differ per (tau, bin)
# but the backbone features do not.
c = h[b, t_ctx]
logits = head(c)
cifs = criterion.calculate_cifs(
logits, taus=torch.tensor(float(tau_y), device=device)
logits, taus=tau_tensor
)
if cifs.ndim != 2:
raise ValueError(
@@ -421,11 +429,15 @@ def evaluate_time_dependent_age_bins(
)
preds = cifs.index_select(dim=1, index=cause_ids)
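# Reduce CPU/NumPy conversion overhead: keep the results as CPU torch tensors
# and convert to NumPy once during aggregation.
# NOTE(review): the copies below pass non_blocking=True; a device-to-host copy
# issued that way may need a synchronization before the CPU data is read - confirm.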
y_true[mc_idx][tau_idx][bin_idx].append(
y[keep].detach().to(torch.bool).cpu().numpy()
y[keep].detach().to(dtype=torch.bool,
device="cpu", non_blocking=True)
)
y_pred[mc_idx][tau_idx][bin_idx].append(
preds[keep].detach().to(torch.float32).cpu().numpy()
preds[keep].detach().to(dtype=torch.float32,
device="cpu", non_blocking=True)
)
rows_by_bin: List[Dict[str, float | int]] = []
@@ -460,13 +472,16 @@ def evaluate_time_dependent_age_bins(
)
continue
yb = np.concatenate(y_true[mc_idx][h_idx][bin_idx], axis=0)
pb = np.concatenate(y_pred[mc_idx][h_idx][bin_idx], axis=0)
if yb.shape != pb.shape:
yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0)
pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0)
if tuple(yb_t.shape) != tuple(pb_t.shape):
raise ValueError(
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
)
yb = yb_t.numpy()
pb = pb_t.numpy()
n_samples = int(yb.shape[0])
for cause_k in range(n_causes_eval):