Refactor tqdm import handling and improve context sampling in utils.py

2026-01-16 16:57:35 +08:00
parent e47a7ce4d6
commit b1647d1b74
2 changed files with 53 additions and 34 deletions

utils.py

@@ -53,27 +53,31 @@ def sample_context_in_fixed_age_bin(
     eligible = valid & in_bin & (
         (time_seq + tau_days) <= followup_end_time.unsqueeze(1))
-    keep = torch.zeros((B,), dtype=torch.bool, device=device)
-    t_ctx = torch.zeros((B,), dtype=torch.long, device=device)
+    # Vectorized, uniform sampling over eligible indices per sample.
+    # Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform
+    # choice among eligible indices by symmetry (ties have probability ~0).
+    keep = eligible.any(dim=1)
-    gen = torch.Generator(device="cpu")
-    gen.manual_seed(int(seed))
+    # Prefer a per-call generator on the target device for reproducibility without
+    # touching global RNG state. If unavailable, fall back to seeding the global
+    # CUDA RNG for this call.
+    gen = None
+    if device.type == "cuda":
+        try:
+            gen = torch.Generator(device=device)
+            gen.manual_seed(int(seed))
+        except Exception:
+            gen = None
+            torch.cuda.manual_seed(int(seed))
+    else:
+        gen = torch.Generator()
+        gen.manual_seed(int(seed))
-    for i in range(B):
-        m = eligible[i]
-        if not m.any():
-            continue
-        idxs = m.nonzero(as_tuple=False).view(-1).cpu()
-        chosen_idx_pos = int(
-            torch.randint(low=0, high=int(idxs.numel()),
-                          size=(1,), generator=gen).item()
-        )
-        chosen_t = int(idxs[chosen_idx_pos].item())
-        keep[i] = True
-        t_ctx[i] = chosen_t
+    r = torch.rand((B, eligible.size(1)), device=device, generator=gen)
+    r = r.masked_fill(~eligible, -1.0)
+    t_ctx = r.argmax(dim=1).to(torch.long)
+    # When keep=False, t_ctx is arbitrary (argmax over all -1 yields 0).
     return keep, t_ctx
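
Note (not part of the commit): the core of the change is the masked-argmax trick. Drawing i.i.d. Uniform(0,1) values for every position, forcing ineligible positions to -1, and taking the row-wise argmax selects one eligible index uniformly at random. The snippet below is a standalone sanity check of that equivalence; the helper name sample_eligible_argmax and the toy mask are made up for illustration and are not repo code.

import torch

def sample_eligible_argmax(eligible, gen=None):
    # eligible: (B, T) boolean mask; returns (keep, t_ctx) with the same
    # semantics as the vectorized code in the diff above.
    r = torch.rand(eligible.shape, device=eligible.device, generator=gen)
    r = r.masked_fill(~eligible, -1.0)      # ineligible positions can never win the argmax
    keep = eligible.any(dim=1)
    t_ctx = r.argmax(dim=1).to(torch.long)  # arbitrary (0) for rows with no eligible position
    return keep, t_ctx

gen = torch.Generator().manual_seed(0)
eligible = torch.tensor([[False, True, True, False, True]]).expand(30000, -1)
keep, t_ctx = sample_eligible_argmax(eligible, gen)
print(torch.bincount(t_ctx, minlength=5) / 30000.0)
# Expected roughly [0, 1/3, 1/3, 0, 1/3]: each eligible column is equally likely.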
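
Note (not part of the commit): the generator handling follows the same idea in isolation. A per-call torch.Generator makes the draw reproducible without advancing the global RNG stream, with a fallback to torch.cuda.manual_seed when a device generator cannot be created. A rough sketch of that pattern; the helper seeded_rand is hypothetical, not repo code.

import torch

def seeded_rand(shape, seed, device=torch.device("cpu")):
    # Prefer a generator on the target device; fall back to seeding the
    # global CUDA RNG for this call, mirroring the fallback in the diff.
    gen = None
    try:
        gen = torch.Generator(device=device)
        gen.manual_seed(int(seed))
    except Exception:
        gen = None
        if device.type == "cuda":
            torch.cuda.manual_seed(int(seed))
    return torch.rand(shape, device=device, generator=gen)

torch.manual_seed(123)
a = seeded_rand((4,), seed=7)
b = seeded_rand((4,), seed=7)
assert torch.equal(a, b)  # same seed, same draw, regardless of call order
# Because each draw used its own generator, the default CPU RNG state was
# never advanced by seeded_rand.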