import math from typing import Optional, Sequence, Tuple import torch import torch.nn as nn import torch.nn.functional as F # ============================================================ # Pair extraction (utility; not used by the losses below) # ============================================================ def get_valid_pairs_and_dt( event_seqs: torch.Tensor, time_seqs: torch.Tensor, n_tech_tokens: int ) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """ Extract valid event pairs (prev -> next) and compute dt in years. Args: event_seqs (torch.Tensor): Event sequences. time_seqs (torch.Tensor): Time sequences. n_tech_tokens (int): Number of technical tokens. Returns: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: (dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None. Notes: - Assumes strict right-padding. - Filters to next events that are disease tokens: token_id >= n_tech_tokens. - Filters to strictly positive dt. """ real_mask = event_seqs >= 1 idx = real_mask.nonzero(as_tuple=False) if idx.size(0) <= 1: return None same_batch = idx[1:, 0] == idx[:-1, 0] if not same_batch.any(): return None prev_idx = idx[:-1][same_batch] next_idx = idx[1:][same_batch] b_next, t_next = next_idx[:, 0], next_idx[:, 1] valid_target = event_seqs[b_next, t_next] >= n_tech_tokens if not valid_target.any(): return None prev_idx = prev_idx[valid_target] next_idx = next_idx[valid_target] b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1] b_next, t_next = next_idx[:, 0], next_idx[:, 1] dt = (time_seqs[b_next, t_next] - time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25 valid_dt = dt > 0 if not valid_dt.any(): return None dt = dt[valid_dt] b_prev = b_prev[valid_dt] t_prev = t_prev[valid_dt] b_next = b_next[valid_dt] t_next = t_next[valid_dt] return dt, b_prev, t_prev, b_next, t_next # ============================================================ # Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization) # ============================================================ class ExponentialNLLLoss(nn.Module): """ Competing risks exponential likelihood. The negative log-likelihood is given by: .. math:: \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt Args: eps (float): Small epsilon for numerical stability. """ def __init__( self, lambda_reg: float = 0.0, eps: float = 1e-6, ): super().__init__() self.eps = eps self.lambda_reg = lambda_reg def forward( self, logits: torch.Tensor, target_events: torch.Tensor, dt: torch.Tensor, reduction: str = "mean", ) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass. Args: logits (torch.Tensor): (M, K) tensor of logits. target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K). dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive. reduction (str): 'mean', 'sum', or 'none'. Returns: Tuple[torch.Tensor, torch.Tensor]: (nll, regularization) where regularization is 0. """ logits = logits.squeeze(-1) if logits.dim() == 3 else logits hazards = F.softplus(logits) + self.eps # (M,K) hazard_event = hazards.gather( 1, target_events.unsqueeze(1)).squeeze(1) # (M,) total_hazard = hazards.sum(dim=1) # (M,) log_hazards = torch.log(hazards) # (M, K) nll = -torch.log(hazard_event) + total_hazard * dt if reduction == "mean": nll = nll.mean() elif reduction == "sum": nll = nll.sum() reg = F.cross_entropy(log_hazards, target_events, reduction="mean") * self.lambda_reg return nll, reg class PiecewiseExponentialLoss(nn.Module): """ Piecewise-constant competing risks exponential likelihood. Lightweight numerical protections: - Does NOT mask/skip any samples. - Uses nan_to_num for dt/logits/targets to avoid NaN/Inf propagation. - Clamps logits and dt to keep softplus/log operations finite. """ def __init__( self, bin_edges: Sequence[float], eps: float = 1e-6, lambda_reg: float = 0.0, logit_clip: float = 30.0, ): super().__init__() if len(bin_edges) < 2: raise ValueError("bin_edges must have length >= 2") if bin_edges[0] != 0: raise ValueError("bin_edges must start at 0") for i in range(1, len(bin_edges)): if not (bin_edges[i] > bin_edges[i - 1]): raise ValueError("bin_edges must be strictly increasing") self.eps = float(eps) self.lambda_reg = float(lambda_reg) self.logit_clip = float(logit_clip) edges = torch.tensor(list(bin_edges), dtype=torch.float32) self.register_buffer("bin_edges", edges, persistent=False) def forward( self, logits: torch.Tensor, target_events: torch.Tensor, dt: torch.Tensor, reduction: str = "mean", ) -> Tuple[torch.Tensor, torch.Tensor]: if logits.dim() != 3: raise ValueError("logits must have shape (M, K, B)") M, K, B = logits.shape if self.bin_edges.numel() != B + 1: raise ValueError( f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})" ) device = logits.device dt = dt.to(device=device, dtype=torch.float32) target_events = target_events.to(device=device) # No masking/skipping: coerce invalid values to safe defaults. logits_v = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0) logits_v = torch.clamp( logits_v, min=-self.logit_clip, max=self.logit_clip) dt_v = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0) target_v = torch.nan_to_num( target_events, nan=0.0, posinf=0.0, neginf=0.0) target_v = target_v.to(dtype=torch.long) target_v = torch.clamp(target_v, min=0, max=K - 1) # Keep structural clamping to prevent index-out-of-bounds errors # (Necessary for searchsorted/gather to work at all) eps = self.eps max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype) dt_max = torch.nextafter(max_edge, max_edge.new_zeros(())) dt_v = torch.clamp(dt_v, min=eps, max=dt_max) # Hazards: (M, K, B) hazards = F.softplus(logits_v) + eps hazards = torch.clamp(hazards, min=eps) total_hazard = hazards.sum(dim=1) # (M, B) edges = self.bin_edges.to(device=device, dtype=dt_v.dtype) widths = edges[1:] - edges[:-1] # (B,) # Bin index b* in [0, B-1]. b_star = torch.searchsorted(edges[1:], dt_v, right=False) # (M,) b_star = torch.clamp(b_star, min=0, max=B - 1) # 1. Hazard at event (M,) # gather needs matching dims. # hazards: (M, K, B) -> select target_event -> (M, B) -> select b_star -> (M,) # Alternative: hazards[m, k, b] ar = torch.arange(M, device=device) hazard_event = hazards[ar, target_v, b_star] # (M,) hazard_event = torch.clamp(hazard_event, min=eps) # 2. Integral part # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left) # Full bins accumulation weighted = total_hazard * widths.unsqueeze(0) # (M, B) cum = weighted.cumsum(dim=1) # (M, B) full_bins_int = torch.zeros_like(dt_v) # We process 'has_full' logic generally. # If b_star is 0, gather on index -1 would fail or wrap, so we mask carefully or use conditional has_full = b_star > 0 # NOTE: Even without protection, we need valid indices for gather. # We use a temporary index that is safe (0) for the 'False' cases, then mask the result. safe_indices = (b_star - 1).clamp(min=0) gathered_cum = cum.gather(1, safe_indices.unsqueeze(1)).squeeze(1) full_bins_int = torch.where(has_full, gathered_cum, full_bins_int) # Partial bin accumulation edge_left = edges[b_star] # (M,) partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1) partial = partial_hazard * (dt_v - edge_left) integral = full_bins_int + partial # Final NLL nll = -torch.log(hazard_event) + integral # Reduction if reduction == "none": nll_out = nll elif reduction == "sum": nll_out = nll.sum() elif reduction == "mean": nll_out = nll.mean() else: raise ValueError("reduction must be one of: 'mean', 'sum', 'none'") reg = logits.new_zeros(()) if self.lambda_reg != 0.0: reg = reg + (self.lambda_reg * logits_v.pow(2).mean()) return nll_out, reg class WeibullNLLLoss(nn.Module): """ Weibull hazard in t with lightweight numerical protections. Does NOT mask/skip any samples. Instead: - nan_to_num for logits/dt/targets - clamps logits to keep softplus outputs reasonable - computes t^shape in log-space with clamped exponent to prevent overflow """ def __init__( self, eps: float = 1e-6, lambda_reg: float = 0.0, logit_clip: float = 30.0, max_shape: float = 30.0, max_dt: float = 1.0e3, max_exp: float = 80.0, ): super().__init__() self.eps = eps self.lambda_reg = lambda_reg self.logit_clip = float(logit_clip) self.max_shape = float(max_shape) self.max_dt = float(max_dt) self.max_exp = float(max_exp) def forward(self, logits, target_events, dt, reduction="mean"): if logits.dim() != 3 or logits.size(-1) != 2: raise ValueError("logits must have shape (M, K, 2)") M, K, _ = logits.shape device = logits.device logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0) logits = torch.clamp(logits, min=-self.logit_clip, max=self.logit_clip) dt = dt.to(device=device, dtype=torch.float32) dt = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0) dt = torch.clamp(dt, min=self.eps, max=self.max_dt) target_events = target_events.to(device=device) target_events = torch.nan_to_num( target_events, nan=0.0, posinf=0.0, neginf=0.0) target_events = target_events.to(dtype=torch.long) target_events = torch.clamp(target_events, min=0, max=K - 1) shapes = F.softplus(logits[..., 0]) + self.eps scales = F.softplus(logits[..., 1]) + self.eps shapes = torch.clamp(shapes, min=self.eps, max=self.max_shape) scales = torch.clamp(scales, min=self.eps) t_mat = dt.unsqueeze(1) # (M,1) log_t = torch.log(torch.clamp(t_mat, min=self.eps)) # Compute t^shape and t^(shape-1) in log-space with exponent clamp. pow_shape = torch.exp(torch.clamp(shapes * log_t, max=self.max_exp)) pow_shape_minus_1 = torch.exp( torch.clamp((shapes - 1.0) * log_t, max=self.max_exp) ) cum_hazard = scales * pow_shape hazard = shapes * scales * pow_shape_minus_1 hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1) hazard_event = torch.clamp(hazard_event, min=self.eps) nll = -torch.log(hazard_event) + cum_hazard.sum(dim=1) if reduction == "mean": nll = nll.mean() elif reduction == "sum": nll = nll.sum() elif reduction != "none": raise ValueError("reduction must be one of: 'mean', 'sum', 'none'") reg = shapes.new_zeros(()) if self.lambda_reg > 0: reg = self.lambda_reg * ( (torch.log(scales + self.eps) ** 2).mean() + (torch.log(shapes + self.eps) ** 2).mean() ) return nll, reg