- Introduced `evaluate.py` for time-dependent evaluation of models, including data loading and model inference.
- Added `evaluation_time_dependent.py` to compute various evaluation metrics such as AUC, average precision, and precision/recall at specified thresholds.
- Implemented CIF calculation methods in `losses.py` for different loss types, including exponential and piecewise-exponential models.
- Created utility functions in `utils.py` for context selection and multi-hot encoding of events within specified horizons.
131 lines
4.3 KiB
Python
131 lines
4.3 KiB
Python
import torch
|
|
from typing import Tuple
|
|
|
|
# Average calendar-year length in days (365 + 1/4, i.e. one leap day every four
# years). Year-valued arguments in this module are multiplied by this factor so
# they can be compared with sequence times, which are expressed in days.
DAYS_PER_YEAR = 365 + 1 / 4
|
|
|
|
|
|
def _last_true_index(mask: torch.Tensor) -> torch.Tensor:
    """Return the column index of the last True entry in each row of a 2-D bool mask.

    Rows that contain no True entry get -1. The result is a (B,) long tensor.
    """
    # Score each True cell by (position + 1) and each False cell by 0; the
    # row-wise maximum then marks the right-most True cell (0 when there is none).
    positions = torch.arange(1, mask.size(1) + 1, device=mask.device)
    return (mask.to(torch.long) * positions).max(dim=1).values - 1


def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
    tau_years: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select per-sample prediction context index.

    IMPORTANT SEMANTICS:
    - The last observed token time is treated as the FOLLOW-UP END time.
    - We pick the last valid token with time <= (followup_end_time - offset).
    - We do NOT interpret followup_end_time as an event time.

    "Last" means the highest sequence position. Tokens are located by position
    rather than by counting, so valid tokens do not have to form a contiguous
    prefix (e.g. left-padded rows are handled correctly).

    Args:
        event_seq: (B, L) event tokens; 0 marks padding.
        time_seq: (B, L) token times in days, aligned with ``event_seq``.
        offset_years: How far before the follow-up end the context must lie,
            in years.
        tau_years: Minimum follow-up (in years) required after the context
            time; samples with less follow-up get ``keep == False``.

    Returns:
        keep_mask: (B,) bool, which samples have a valid context
        t_ctx: (B,) long, index into sequence (0 where keep_mask is False)
        t_ctx_time: (B,) float, time (days) at context
    """
    # Valid tokens are event != 0 (padding is 0).
    valid = event_seq != 0
    rows = torch.arange(event_seq.size(0), device=event_seq.device)

    # Follow-up end = time of the right-most valid token. Rows with no valid
    # token fall back to index 0; they have no eligible token below, so they
    # end up with keep == False.
    last_idx = _last_true_index(valid).clamp(min=0)
    followup_end_time = time_seq[rows, last_idx]
    t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)

    eligible = valid & (time_seq <= t_cut.unsqueeze(1))
    ctx_idx = _last_true_index(eligible)
    keep = ctx_idx >= 0

    # Rows without an eligible token still need an in-range index for the
    # gather; they are masked out via `keep`.
    t_ctx = ctx_idx.clamp(min=0)
    t_ctx_time = time_seq[rows, t_ctx]

    # Horizon-aligned eligibility: require enough follow-up time after the
    # selected context. All times are in days.
    keep = keep & (followup_end_time >= (t_ctx_time + (tau_years * DAYS_PER_YEAR)))

    return keep, t_ctx, t_ctx_time
|
|
|
|
|
|
def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs within tau after context (any occurrence).

    Args:
        event_seq: (B, L) event tokens. Tokens >= 2 are diseases with
            disease id = token - 2; smaller tokens (including padding 0) are
            ignored. Every disease id is expected to be < ``n_disease``.
        time_seq: (B, L) token times in days.
        t_ctx: (B,) context position per sample.
        tau_years: Horizon length in years.
        n_disease: Number of disease label columns.

    Returns:
        (B, n_disease) bool tensor. Entry [b, k] is True when disease k appears
        at a position after ``t_ctx[b]`` whose time lies in
        [context time, context time + tau] (both ends inclusive).
    """
    batch_size, seq_len = event_seq.shape
    dev = event_seq.device
    rows = torch.arange(batch_size, device=dev)

    # (B, 1) window bounds in days, shaped to broadcast against (B, L) times.
    window_start = time_seq[rows, t_ctx].unsqueeze(1)
    window_end = window_start + tau_years * DAYS_PER_YEAR

    # Only positions strictly after the context count. Because the lower time
    # bound is inclusive, same-day events following the context token are kept.
    after_context = torch.arange(seq_len, device=dev).unsqueeze(0) > t_ctx.unsqueeze(1)
    in_time_range = (time_seq >= window_start) & (time_seq <= window_end)
    is_disease = event_seq >= 2  # also rules out padding (0)
    hits = after_context & in_time_range & is_disease

    labels = torch.zeros((batch_size, n_disease), dtype=torch.bool, device=dev)
    if hits.any():
        hit_rows, hit_cols = hits.nonzero(as_tuple=True)
        hit_diseases = (event_seq[hit_rows, hit_cols] - 2).to(torch.long)
        labels[hit_rows, hit_diseases] = True
    return labels
|
|
|
|
|
|
def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Labels for selected causes only: does cause k occur within tau after context?

    Uses exactly the same horizon window as ``multi_hot_ever_within_horizon``
    and keeps only the requested disease columns.

    Args:
        event_seq: (B, L) event tokens; disease id = token - 2 for tokens >= 2.
        time_seq: (B, L) token times in days.
        t_ctx: (B,) context position per sample.
        tau_years: Horizon length in years.
        cause_ids: Disease ids to label (flattened); column j of the result
            corresponds to ``cause_ids[j]``. Repeated ids are allowed and each
            repeated column receives the same label.
        n_disease: Size of the global disease space; ids must index into it.

    Returns:
        (B, cause_ids.numel()) bool tensor on ``event_seq``'s device.
    """
    # Compute labels over the full disease space with the shared window rules,
    # then gather the requested columns. Gathering (instead of scattering
    # through a disease -> column lookup table) gives every output column its
    # own answer, so duplicate entries in cause_ids are all labelled correctly;
    # a scatter with duplicate indices would set only one of them.
    ever = multi_hot_ever_within_horizon(
        event_seq, time_seq, t_ctx, tau_years, n_disease
    )
    columns = cause_ids.reshape(-1).to(device=ever.device, dtype=torch.long)
    return ever[:, columns]
|