Remove evaluation_age_time_dependent.py and utils.py files
- Deleted the entire evaluation_age_time_dependent.py file which contained functions for evaluating age-dependent metrics, including various statistical calculations and data aggregation methods. - Removed utils.py file that provided utility functions for sampling context in fixed age bins and multi-hot encoding for disease occurrences.
This commit is contained in:
207
utils.py
207
utils.py
@@ -1,207 +0,0 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
DAYS_PER_YEAR = 365.25
|
||||
|
||||
|
||||
def sample_context_in_fixed_age_bin(
|
||||
event_seq: torch.Tensor,
|
||||
time_seq: torch.Tensor,
|
||||
tau_years: float,
|
||||
age_bin: Tuple[float, float],
|
||||
seed: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Sample one context token per individual within a fixed age bin.
|
||||
|
||||
Delphi-2M semantics for a specific (tau, age_bin):
|
||||
- Token times are interpreted as age in *days* (converted to years).
|
||||
- Follow-up end time is the last valid token time per individual.
|
||||
- A token index j is eligible iff:
|
||||
(token is valid)
|
||||
AND (age_years in [age_low, age_high))
|
||||
AND (time_seq[i, j] + tau_days <= followup_end_time[i])
|
||||
- For each individual, randomly select exactly one eligible token in this bin.
|
||||
|
||||
Args:
|
||||
event_seq: (B, L) token ids, 0 is padding.
|
||||
time_seq: (B, L) token times in days.
|
||||
tau_years: horizon length in years.
|
||||
age_bin: (low, high) bounds in years, interpreted as [low, high).
|
||||
seed: RNG seed for deterministic sampling.
|
||||
|
||||
Returns:
|
||||
keep: (B,) bool, True if a context was sampled for this bin.
|
||||
t_ctx: (B,) long, sampled context index (undefined when keep=False; set to 0).
|
||||
"""
|
||||
low, high = float(age_bin[0]), float(age_bin[1])
|
||||
if not (high > low):
|
||||
raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}")
|
||||
|
||||
device = event_seq.device
|
||||
B, _ = event_seq.shape
|
||||
|
||||
valid = event_seq != 0
|
||||
lengths = valid.sum(dim=1)
|
||||
last_idx = torch.clamp(lengths - 1, min=0)
|
||||
b = torch.arange(B, device=device)
|
||||
followup_end_time = time_seq[b, last_idx] # (B,)
|
||||
|
||||
tau_days = float(tau_years) * DAYS_PER_YEAR
|
||||
age_years = time_seq / DAYS_PER_YEAR
|
||||
|
||||
in_bin = (age_years >= low) & (age_years < high)
|
||||
eligible = valid & in_bin & (
|
||||
(time_seq + tau_days) <= followup_end_time.unsqueeze(1))
|
||||
|
||||
# Vectorized, uniform sampling over eligible indices per sample.
|
||||
# Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform
|
||||
# choice among eligible indices by symmetry (ties have probability ~0).
|
||||
keep = eligible.any(dim=1)
|
||||
|
||||
# Prefer a per-call generator on the target device for reproducibility without
|
||||
# touching global RNG state. If unavailable, fall back to seeding the global
|
||||
# CUDA RNG for this call.
|
||||
gen = None
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
gen = torch.Generator(device=device)
|
||||
gen.manual_seed(int(seed))
|
||||
except Exception:
|
||||
gen = None
|
||||
torch.cuda.manual_seed(int(seed))
|
||||
else:
|
||||
gen = torch.Generator()
|
||||
gen.manual_seed(int(seed))
|
||||
|
||||
r = torch.rand((B, eligible.size(1)), device=device, generator=gen)
|
||||
r = r.masked_fill(~eligible, -1.0)
|
||||
t_ctx = r.argmax(dim=1).to(torch.long)
|
||||
|
||||
# When keep=False, t_ctx is arbitrary (argmax over all -1 yields 0).
|
||||
return keep, t_ctx
|
||||
|
||||
|
||||
def select_context_indices(
|
||||
event_seq: torch.Tensor,
|
||||
time_seq: torch.Tensor,
|
||||
offset_years: float,
|
||||
tau_years: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Select per-sample prediction context index.
|
||||
|
||||
IMPORTANT SEMANTICS:
|
||||
- The last observed token time is treated as the FOLLOW-UP END time.
|
||||
- We pick the last valid token with time <= (followup_end_time - offset).
|
||||
- We do NOT interpret followup_end_time as an event time.
|
||||
|
||||
Returns:
|
||||
keep_mask: (B,) bool, which samples have a valid context
|
||||
t_ctx: (B,) long, index into sequence
|
||||
t_ctx_time: (B,) float, time (days) at context
|
||||
"""
|
||||
# valid tokens are event != 0 (padding is 0)
|
||||
valid = event_seq != 0
|
||||
lengths = valid.sum(dim=1)
|
||||
last_idx = torch.clamp(lengths - 1, min=0)
|
||||
|
||||
b = torch.arange(event_seq.size(0), device=event_seq.device)
|
||||
followup_end_time = time_seq[b, last_idx]
|
||||
t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
|
||||
|
||||
eligible = valid & (time_seq <= t_cut.unsqueeze(1))
|
||||
eligible_counts = eligible.sum(dim=1)
|
||||
keep = eligible_counts > 0
|
||||
|
||||
t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
|
||||
t_ctx_time = time_seq[b, t_ctx]
|
||||
|
||||
# Horizon-aligned eligibility: require enough follow-up time after the selected context.
|
||||
# All times are in days.
|
||||
keep = keep & (followup_end_time >= (
|
||||
t_ctx_time + (tau_years * DAYS_PER_YEAR)))
|
||||
|
||||
return keep, t_ctx, t_ctx_time
|
||||
|
||||
|
||||
def multi_hot_ever_within_horizon(
|
||||
event_seq: torch.Tensor,
|
||||
time_seq: torch.Tensor,
|
||||
t_ctx: torch.Tensor,
|
||||
tau_years: float,
|
||||
n_disease: int,
|
||||
) -> torch.Tensor:
|
||||
"""Binary labels: disease k occurs within tau after context (any occurrence)."""
|
||||
B, L = event_seq.shape
|
||||
b = torch.arange(B, device=event_seq.device)
|
||||
t0 = time_seq[b, t_ctx]
|
||||
t1 = t0 + (tau_years * DAYS_PER_YEAR)
|
||||
|
||||
idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
|
||||
# Include same-day events after context, exclude any token at/before context index.
|
||||
in_window = (
|
||||
(idxs > t_ctx.unsqueeze(1))
|
||||
& (time_seq >= t0.unsqueeze(1))
|
||||
& (time_seq <= t1.unsqueeze(1))
|
||||
& (event_seq >= 2)
|
||||
& (event_seq != 0)
|
||||
)
|
||||
|
||||
if not in_window.any():
|
||||
return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
|
||||
|
||||
b_idx, t_idx = in_window.nonzero(as_tuple=True)
|
||||
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
|
||||
|
||||
y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
|
||||
y[b_idx, disease_ids] = True
|
||||
return y
|
||||
|
||||
|
||||
def multi_hot_selected_causes_within_horizon(
|
||||
event_seq: torch.Tensor,
|
||||
time_seq: torch.Tensor,
|
||||
t_ctx: torch.Tensor,
|
||||
tau_years: float,
|
||||
cause_ids: torch.Tensor,
|
||||
n_disease: int,
|
||||
) -> torch.Tensor:
|
||||
"""Labels for selected causes only: does cause k occur within tau after context?"""
|
||||
B, L = event_seq.shape
|
||||
device = event_seq.device
|
||||
b = torch.arange(B, device=device)
|
||||
t0 = time_seq[b, t_ctx]
|
||||
t1 = t0 + (tau_years * DAYS_PER_YEAR)
|
||||
|
||||
idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
|
||||
in_window = (
|
||||
(idxs > t_ctx.unsqueeze(1))
|
||||
& (time_seq >= t0.unsqueeze(1))
|
||||
& (time_seq <= t1.unsqueeze(1))
|
||||
& (event_seq >= 2)
|
||||
& (event_seq != 0)
|
||||
)
|
||||
|
||||
out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
|
||||
if not in_window.any():
|
||||
return out
|
||||
|
||||
b_idx, t_idx = in_window.nonzero(as_tuple=True)
|
||||
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
|
||||
|
||||
# Filter to selected causes via a boolean membership mask over the global disease space.
|
||||
selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
|
||||
selected[cause_ids] = True
|
||||
keep = selected[disease_ids]
|
||||
if not keep.any():
|
||||
return out
|
||||
|
||||
b_idx = b_idx[keep]
|
||||
disease_ids = disease_ids[keep]
|
||||
|
||||
# Map disease_id -> local index in cause_ids
|
||||
# Build a lookup table (global disease space) where lookup[disease_id] = local_index
|
||||
lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
|
||||
lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
|
||||
local = lookup[disease_ids]
|
||||
out[b_idx, local] = True
|
||||
return out
|
||||
Reference in New Issue
Block a user