Refactor tqdm import handling and improve context sampling in utils.py
This commit is contained in:
40
utils.py
40
utils.py
@@ -53,27 +53,31 @@ def sample_context_in_fixed_age_bin(
|
||||
eligible = valid & in_bin & (
|
||||
(time_seq + tau_days) <= followup_end_time.unsqueeze(1))
|
||||
|
||||
keep = torch.zeros((B,), dtype=torch.bool, device=device)
|
||||
t_ctx = torch.zeros((B,), dtype=torch.long, device=device)
|
||||
# Vectorized, uniform sampling over eligible indices per sample.
|
||||
# Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform
|
||||
# choice among eligible indices by symmetry (ties have probability ~0).
|
||||
keep = eligible.any(dim=1)
|
||||
|
||||
gen = torch.Generator(device="cpu")
|
||||
gen.manual_seed(int(seed))
|
||||
# Prefer a per-call generator on the target device for reproducibility without
|
||||
# touching global RNG state. If unavailable, fall back to seeding the global
|
||||
# CUDA RNG for this call.
|
||||
gen = None
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
gen = torch.Generator(device=device)
|
||||
gen.manual_seed(int(seed))
|
||||
except Exception:
|
||||
gen = None
|
||||
torch.cuda.manual_seed(int(seed))
|
||||
else:
|
||||
gen = torch.Generator()
|
||||
gen.manual_seed(int(seed))
|
||||
|
||||
for i in range(B):
|
||||
m = eligible[i]
|
||||
if not m.any():
|
||||
continue
|
||||
|
||||
idxs = m.nonzero(as_tuple=False).view(-1).cpu()
|
||||
chosen_idx_pos = int(
|
||||
torch.randint(low=0, high=int(idxs.numel()),
|
||||
size=(1,), generator=gen).item()
|
||||
)
|
||||
chosen_t = int(idxs[chosen_idx_pos].item())
|
||||
|
||||
keep[i] = True
|
||||
t_ctx[i] = chosen_t
|
||||
r = torch.rand((B, eligible.size(1)), device=device, generator=gen)
|
||||
r = r.masked_fill(~eligible, -1.0)
|
||||
t_ctx = r.argmax(dim=1).to(torch.long)
|
||||
|
||||
# When keep=False, t_ctx is arbitrary (argmax over all -1 yields 0).
|
||||
return keep, t_ctx
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user