from typing import Tuple

import torch

DAYS_PER_YEAR = 365.25


def _followup_end_time(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the valid-token mask and the per-row follow-up end time.

    The follow-up end time of a row is the time stored at index
    ``(number of non-padding tokens) - 1``; this is the last valid token only
    when padding (id 0) sits at the end of the row. Rows without any valid
    token fall back to index 0 (the index is clamped).

    Returns:
        valid: (B, L) bool, True where ``event_seq != 0``.
        followup_end_time: (B,) time taken from ``time_seq``.
    """
    valid = event_seq != 0
    last_idx = torch.clamp(valid.sum(dim=1) - 1, min=0)
    rows = torch.arange(event_seq.size(0), device=event_seq.device)
    return valid, time_seq[rows, last_idx]


def _horizon_window_mask(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
) -> torch.Tensor:
    """Return a (B, L) mask of disease tokens inside the horizon after context.

    A token is inside the window iff its index is strictly greater than the
    context index, its time lies in ``[t0, t0 + tau]`` (both ends inclusive,
    so same-day events after the context count), and its id is >= 2.
    """
    B, L = event_seq.shape
    device = event_seq.device
    rows = torch.arange(B, device=device)
    t0 = time_seq[rows, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)
    positions = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
    # `event_seq >= 2` already rules out padding (id 0), so no separate
    # padding test is needed.
    return (
        (positions > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
    )


def sample_context_in_fixed_age_bin(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    tau_years: float,
    age_bin: Tuple[float, float],
    seed: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample one context token per individual within a fixed age bin.

    Delphi-2M semantics for a specific (tau, age_bin):
    - Token times are interpreted as age in *days* (converted to years).
    - Follow-up end time is the last valid token time per individual.
    - A token index j is eligible iff:
        (token is valid) AND
        (age_years in [age_low, age_high)) AND
        (time_seq[i, j] + tau_days <= followup_end_time[i])
    - For each individual, randomly select exactly one eligible token in
      this bin.

    Args:
        event_seq: (B, L) token ids, 0 is padding.
        time_seq: (B, L) token times in days.
        tau_years: horizon length in years.
        age_bin: (low, high) bounds in years, interpreted as [low, high).
        seed: RNG seed for deterministic sampling.

    Returns:
        keep: (B,) bool, True if a context was sampled for this bin.
        t_ctx: (B,) long, sampled context index (undefined when keep=False;
            set to 0).

    Raises:
        ValueError: if ``age_bin`` does not satisfy ``high > low``.
    """
    low, high = float(age_bin[0]), float(age_bin[1])
    if not (high > low):
        raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}")

    device = event_seq.device
    B = event_seq.size(0)

    valid, followup_end_time = _followup_end_time(event_seq, time_seq)
    tau_days = float(tau_years) * DAYS_PER_YEAR
    age_years = time_seq / DAYS_PER_YEAR
    in_bin = (age_years >= low) & (age_years < high)
    has_followup = (time_seq + tau_days) <= followup_end_time.unsqueeze(1)
    eligible = valid & in_bin & has_followup

    # Move the mask to the host once instead of synchronising per row.
    eligible_cpu = eligible.cpu()
    counts = eligible_cpu.sum(dim=1).tolist()
    # nonzero() yields indices in row-major order, so the eligible columns of
    # row i are the contiguous, ascending slice cols[start:start + counts[i]].
    cols = eligible_cpu.nonzero(as_tuple=True)[1].tolist()

    gen = torch.Generator(device="cpu")
    gen.manual_seed(int(seed))

    keep_flags = [False] * B
    ctx_indices = [0] * B
    start = 0
    for i, n_eligible in enumerate(counts):
        if n_eligible:
            # One draw per row that has candidates, in row order: this keeps
            # the random stream (and thus the result for a given seed) stable.
            pick = int(
                torch.randint(
                    low=0, high=n_eligible, size=(1,), generator=gen
                ).item()
            )
            keep_flags[i] = True
            ctx_indices[i] = cols[start + pick]
        start += n_eligible

    keep = torch.tensor(keep_flags, dtype=torch.bool, device=device)
    t_ctx = torch.tensor(ctx_indices, dtype=torch.long, device=device)
    return keep, t_ctx


def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
    tau_years: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select per-sample prediction context index.

    IMPORTANT SEMANTICS:
    - The last observed token time is treated as the FOLLOW-UP END time.
    - We pick the last valid token with time <= (followup_end_time - offset).
    - We do NOT interpret followup_end_time as an event time.

    The context index is computed as ``(number of eligible tokens) - 1``,
    which equals the last eligible index only when eligible tokens form a
    prefix of the row (i.e. padding is at the end and times are
    non-decreasing) -- NOTE(review): confirm callers guarantee this.

    Args:
        event_seq: (B, L) token ids, 0 is padding.
        time_seq: (B, L) token times in days.
        offset_years: gap, in years, required between the context and the
            follow-up end.
        tau_years: horizon, in years, that must fit between the context time
            and the follow-up end.

    Returns:
        keep_mask: (B,) bool, which samples have a valid context
        t_ctx: (B,) long, index into sequence (0 when no token is eligible)
        t_ctx_time: (B,) float, time (days) at context
    """
    # Valid tokens are event != 0 (padding is 0).
    valid, followup_end_time = _followup_end_time(event_seq, time_seq)
    rows = torch.arange(event_seq.size(0), device=event_seq.device)

    t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
    eligible = valid & (time_seq <= t_cut.unsqueeze(1))
    eligible_counts = eligible.sum(dim=1)

    keep = eligible_counts > 0
    t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
    t_ctx_time = time_seq[rows, t_ctx]

    # Horizon-aligned eligibility: require enough follow-up time after the
    # selected context. All times are in days.
    horizon_end = t_ctx_time + (tau_years * DAYS_PER_YEAR)
    keep = keep & (followup_end_time >= horizon_end)
    return keep, t_ctx, t_ctx_time


def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs within tau after context (any occurrence).

    Same-day events located after the context index are included; every token
    at or before the context index is excluded. Token id ``e >= 2`` maps to
    disease column ``e - 2``.

    Args:
        event_seq: (B, L) token ids, 0 is padding.
        time_seq: (B, L) token times in days.
        t_ctx: (B,) context index per row.
        tau_years: horizon length in years.
        n_disease: number of label columns.

    Returns:
        (B, n_disease) bool tensor on ``event_seq.device``.
    """
    y = torch.zeros(
        (event_seq.size(0), n_disease),
        dtype=torch.bool,
        device=event_seq.device,
    )
    in_window = _horizon_window_mask(event_seq, time_seq, t_ctx, tau_years)
    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    if b_idx.numel() > 0:
        disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
        y[b_idx, disease_ids] = True
    return y


def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Labels for selected causes only: does cause k occur within tau after context?

    Column j of the result answers the question for disease ``cause_ids[j]``.
    The columns are read from the full ``(B, n_disease)`` label matrix, so a
    disease id listed more than once in ``cause_ids`` yields identical columns
    (a per-disease lookup table can only remember one column per disease and
    would leave the other duplicates unset).

    Args:
        event_seq: (B, L) token ids, 0 is padding.
        time_seq: (B, L) token times in days.
        t_ctx: (B,) context index per row.
        tau_years: horizon length in years.
        cause_ids: disease ids (in the global disease space) to report.
        n_disease: size of the global disease space.

    Returns:
        (B, cause_ids.numel()) bool tensor on ``event_seq.device``.
    """
    all_labels = multi_hot_ever_within_horizon(
        event_seq, time_seq, t_ctx, tau_years, int(n_disease)
    )
    columns = cause_ids.reshape(-1).to(
        device=event_seq.device, dtype=torch.long
    )
    # Advanced indexing returns a fresh tensor of shape (B, len(columns)).
    return all_labels[:, columns]