Add evaluation scripts for age-bin time-dependent metrics and remove obsolete evaluation_time_dependent.py
This commit is contained in:
73
utils.py
73
utils.py
@@ -4,6 +4,79 @@ from typing import Tuple
|
||||
DAYS_PER_YEAR = 365.25
|
||||
|
||||
|
||||
def sample_context_in_fixed_age_bin(
|
||||
event_seq: torch.Tensor,
|
||||
time_seq: torch.Tensor,
|
||||
tau_years: float,
|
||||
age_bin: Tuple[float, float],
|
||||
seed: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Sample one context token per individual within a fixed age bin.
|
||||
|
||||
Delphi-2M semantics for a specific (tau, age_bin):
|
||||
- Token times are interpreted as age in *days* (converted to years).
|
||||
- Follow-up end time is the last valid token time per individual.
|
||||
- A token index j is eligible iff:
|
||||
(token is valid)
|
||||
AND (age_years in [age_low, age_high))
|
||||
AND (time_seq[i, j] + tau_days <= followup_end_time[i])
|
||||
- For each individual, randomly select exactly one eligible token in this bin.
|
||||
|
||||
Args:
|
||||
event_seq: (B, L) token ids, 0 is padding.
|
||||
time_seq: (B, L) token times in days.
|
||||
tau_years: horizon length in years.
|
||||
age_bin: (low, high) bounds in years, interpreted as [low, high).
|
||||
seed: RNG seed for deterministic sampling.
|
||||
|
||||
Returns:
|
||||
keep: (B,) bool, True if a context was sampled for this bin.
|
||||
t_ctx: (B,) long, sampled context index (undefined when keep=False; set to 0).
|
||||
"""
|
||||
low, high = float(age_bin[0]), float(age_bin[1])
|
||||
if not (high > low):
|
||||
raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}")
|
||||
|
||||
device = event_seq.device
|
||||
B, _ = event_seq.shape
|
||||
|
||||
valid = event_seq != 0
|
||||
lengths = valid.sum(dim=1)
|
||||
last_idx = torch.clamp(lengths - 1, min=0)
|
||||
b = torch.arange(B, device=device)
|
||||
followup_end_time = time_seq[b, last_idx] # (B,)
|
||||
|
||||
tau_days = float(tau_years) * DAYS_PER_YEAR
|
||||
age_years = time_seq / DAYS_PER_YEAR
|
||||
|
||||
in_bin = (age_years >= low) & (age_years < high)
|
||||
eligible = valid & in_bin & (
|
||||
(time_seq + tau_days) <= followup_end_time.unsqueeze(1))
|
||||
|
||||
keep = torch.zeros((B,), dtype=torch.bool, device=device)
|
||||
t_ctx = torch.zeros((B,), dtype=torch.long, device=device)
|
||||
|
||||
gen = torch.Generator(device="cpu")
|
||||
gen.manual_seed(int(seed))
|
||||
|
||||
for i in range(B):
|
||||
m = eligible[i]
|
||||
if not m.any():
|
||||
continue
|
||||
|
||||
idxs = m.nonzero(as_tuple=False).view(-1).cpu()
|
||||
chosen_idx_pos = int(
|
||||
torch.randint(low=0, high=int(idxs.numel()),
|
||||
size=(1,), generator=gen).item()
|
||||
)
|
||||
chosen_t = int(idxs[chosen_idx_pos].item())
|
||||
|
||||
keep[i] = True
|
||||
t_ctx[i] = chosen_t
|
||||
|
||||
return keep, t_ctx
|
||||
|
||||
|
||||
def select_context_indices(
|
||||
event_seq: torch.Tensor,
|
||||
time_seq: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user