Refactor tqdm import handling and improve context sampling in utils.py
This commit is contained in:
@@ -8,9 +8,12 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except Exception: # pragma: no cover
|
||||
|
||||
def tqdm(x, **kwargs):
|
||||
return x
|
||||
from utils import (
|
||||
multi_hot_ever_within_horizon,
|
||||
multi_hot_selected_causes_within_horizon,
|
||||
@@ -281,7 +284,7 @@ class EvalAgeConfig:
|
||||
cause_ids: Optional[Sequence[int]] = None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def evaluate_time_dependent_age_bins(
|
||||
model: torch.nn.Module,
|
||||
head: torch.nn.Module,
|
||||
@@ -338,12 +341,12 @@ def evaluate_time_dependent_age_bins(
|
||||
list(cfg.cause_ids), dtype=torch.long, device=device)
|
||||
n_causes_eval = int(cause_ids.numel())
|
||||
|
||||
# Storage: (mc, h, bin) -> list of arrays
|
||||
y_true: List[List[List[List[np.ndarray]]]] = [
|
||||
# Storage: (mc, h, bin) -> list of CPU tensors (avoid .numpy() in inner loops)
|
||||
y_true: List[List[List[List[torch.Tensor]]]] = [
|
||||
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
|
||||
for _ in range(int(cfg.n_mc))
|
||||
]
|
||||
y_pred: List[List[List[List[np.ndarray]]]] = [
|
||||
y_pred: List[List[List[List[torch.Tensor]]]] = [
|
||||
[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
|
||||
for _ in range(int(cfg.n_mc))
|
||||
]
|
||||
@@ -365,7 +368,13 @@ def evaluate_time_dependent_age_bins(
|
||||
B = int(event_seq.size(0))
|
||||
b = torch.arange(B, device=device)
|
||||
|
||||
# Hoist backbone forward pass: inputs are identical across (tau, age_bin)
|
||||
# within this batch, so this is safe and numerically identical.
|
||||
h = model(event_seq, time_seq, sexes,
|
||||
cont_feats, cate_feats) # (B,L,D)
|
||||
|
||||
for tau_idx, tau_y in enumerate(horizons_years):
|
||||
tau_tensor = torch.tensor(float(tau_y), device=device)
|
||||
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
|
||||
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
|
||||
seed = (
|
||||
@@ -386,14 +395,13 @@ def evaluate_time_dependent_age_bins(
|
||||
if not keep.any():
|
||||
continue
|
||||
|
||||
# Strict bin-specific prediction: recompute representations and logits per (tau, bin).
|
||||
h = model(event_seq, time_seq, sexes,
|
||||
cont_feats, cate_feats) # (B,L,D)
|
||||
# Bin-specific prediction: context indices differ per (tau, bin)
|
||||
# but the backbone features do not.
|
||||
c = h[b, t_ctx]
|
||||
logits = head(c)
|
||||
|
||||
cifs = criterion.calculate_cifs(
|
||||
logits, taus=torch.tensor(float(tau_y), device=device)
|
||||
logits, taus=tau_tensor
|
||||
)
|
||||
if cifs.ndim != 2:
|
||||
raise ValueError(
|
||||
@@ -421,11 +429,15 @@ def evaluate_time_dependent_age_bins(
|
||||
)
|
||||
preds = cifs.index_select(dim=1, index=cause_ids)
|
||||
|
||||
# Reduce CPU/NumPy conversion overhead: keep as CPU torch tensors
|
||||
# and convert to NumPy once during aggregation.
|
||||
y_true[mc_idx][tau_idx][bin_idx].append(
|
||||
y[keep].detach().to(torch.bool).cpu().numpy()
|
||||
y[keep].detach().to(dtype=torch.bool,
|
||||
device="cpu", non_blocking=True)
|
||||
)
|
||||
y_pred[mc_idx][tau_idx][bin_idx].append(
|
||||
preds[keep].detach().to(torch.float32).cpu().numpy()
|
||||
preds[keep].detach().to(dtype=torch.float32,
|
||||
device="cpu", non_blocking=True)
|
||||
)
|
||||
|
||||
rows_by_bin: List[Dict[str, float | int]] = []
|
||||
@@ -460,13 +472,16 @@ def evaluate_time_dependent_age_bins(
|
||||
)
|
||||
continue
|
||||
|
||||
yb = np.concatenate(y_true[mc_idx][h_idx][bin_idx], axis=0)
|
||||
pb = np.concatenate(y_pred[mc_idx][h_idx][bin_idx], axis=0)
|
||||
if yb.shape != pb.shape:
|
||||
yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0)
|
||||
pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0)
|
||||
if tuple(yb_t.shape) != tuple(pb_t.shape):
|
||||
raise ValueError(
|
||||
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
|
||||
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
|
||||
)
|
||||
|
||||
yb = yb_t.numpy()
|
||||
pb = pb_t.numpy()
|
||||
|
||||
n_samples = int(yb.shape[0])
|
||||
|
||||
for cause_k in range(n_causes_eval):
|
||||
|
||||
Reference in New Issue
Block a user