Add evaluation scripts for age-bin time-dependent metrics and remove obsolete evaluation_time_dependent.py

This commit is contained in:
2026-01-16 16:13:31 +08:00
parent 502ddd153b
commit 90dffc3211
4 changed files with 597 additions and 349 deletions

View File

@@ -4,6 +4,79 @@ from typing import Tuple
# Days per (Julian) year; used to convert between token times in days and
# ages / horizons expressed in years.
DAYS_PER_YEAR = 365.25


def sample_context_in_fixed_age_bin(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    tau_years: float,
    age_bin: Tuple[float, float],
    seed: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample one context token per individual within a fixed age bin.

    Delphi-2M semantics for a specific (tau, age_bin):
      - Token times are interpreted as age in *days* (converted to years).
      - Follow-up end time is the time of the last valid (non-padding)
        token of each individual.
      - A token index j is eligible iff:
            (token is valid)
            AND (age_years in [age_low, age_high))
            AND (time_seq[i, j] + tau_days <= followup_end_time[i])
      - For each individual with at least one eligible token, exactly one
        eligible token is drawn uniformly at random.

    The last valid token is located by position, so padding does not have
    to be trailing; for right-padded input this is the token at index
    ``num_valid - 1``.

    Args:
        event_seq: (B, L) token ids, 0 is padding.
        time_seq: (B, L) token times in days.
        tau_years: horizon length in years.
        age_bin: (low, high) bounds in years, interpreted as [low, high).
        seed: RNG seed for deterministic sampling.

    Returns:
        keep: (B,) bool, True if a context was sampled for this bin.
        t_ctx: (B,) long, sampled context index (undefined when keep=False;
            set to 0).
        Both tensors live on ``event_seq.device``.

    Raises:
        ValueError: if ``age_bin`` does not satisfy high > low, or if
            ``event_seq`` is not 2-D or its shape differs from ``time_seq``.
    """
    low, high = float(age_bin[0]), float(age_bin[1])
    # Written as `not (high > low)` so that NaN bounds are rejected as well.
    if not (high > low):
        raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}")
    if event_seq.dim() != 2 or event_seq.shape != time_seq.shape:
        raise ValueError(
            "event_seq and time_seq must both have shape (B, L); got "
            f"{tuple(event_seq.shape)} and {tuple(time_seq.shape)}"
        )

    device = event_seq.device
    B, L = event_seq.shape
    if L == 0:
        # No tokens at all: nothing can be eligible for any individual.
        return (
            torch.zeros((B,), dtype=torch.bool, device=device),
            torch.zeros((B,), dtype=torch.long, device=device),
        )

    valid = event_seq != 0

    # Position of the last valid token per row. Rows that are entirely
    # padding fall back to index 0; they can never be eligible because
    # `valid` is all False for them.
    positions = torch.arange(L, device=device)
    last_idx = (valid.long() * positions).max(dim=1).values
    followup_end_time = time_seq.gather(1, last_idx.unsqueeze(1))  # (B, 1)

    tau_days = float(tau_years) * DAYS_PER_YEAR
    age_years = time_seq / DAYS_PER_YEAR
    in_bin = (age_years >= low) & (age_years < high)
    has_followup = (time_seq + tau_days) <= followup_end_time
    eligible = valid & in_bin & has_followup

    # Move the mask to the CPU once instead of syncing with the device on
    # every row; the per-row draws below are made with a CPU generator.
    eligible_cpu = eligible.cpu()
    keep_cpu = eligible_cpu.any(dim=1)
    t_ctx_cpu = torch.zeros((B,), dtype=torch.long)

    gen = torch.Generator(device="cpu")
    gen.manual_seed(int(seed))

    # Rows are visited in ascending order and one draw is made per kept row,
    # so the sampled indices are reproducible for a given seed.
    for i in keep_cpu.nonzero(as_tuple=False).view(-1).tolist():
        candidates = eligible_cpu[i].nonzero(as_tuple=False).view(-1)
        pick = int(
            torch.randint(
                low=0,
                high=int(candidates.numel()),
                size=(1,),
                generator=gen,
            ).item()
        )
        t_ctx_cpu[i] = candidates[pick]

    return keep_cpu.to(device), t_ctx_cpu.to(device)
def select_context_indices(
event_seq: torch.Tensor,
time_seq: torch.Tensor,