import math from typing import Optional, Sequence, Tuple import torch import torch.nn as nn import torch.nn.functional as F # ============================================================ # Pair extraction (utility; not used by the losses below) # ============================================================ def get_valid_pairs_and_dt( event_seqs: torch.Tensor, time_seqs: torch.Tensor, n_tech_tokens: int ) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """ Extract valid event pairs (prev -> next) and compute dt in years. Args: event_seqs (torch.Tensor): Event sequences. time_seqs (torch.Tensor): Time sequences. n_tech_tokens (int): Number of technical tokens. Returns: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: (dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None. Notes: - Assumes strict right-padding. - Filters to next events that are disease tokens: token_id >= n_tech_tokens. - Filters to strictly positive dt. """ real_mask = event_seqs >= 1 idx = real_mask.nonzero(as_tuple=False) if idx.size(0) <= 1: return None same_batch = idx[1:, 0] == idx[:-1, 0] if not same_batch.any(): return None prev_idx = idx[:-1][same_batch] next_idx = idx[1:][same_batch] b_next, t_next = next_idx[:, 0], next_idx[:, 1] valid_target = event_seqs[b_next, t_next] >= n_tech_tokens if not valid_target.any(): return None prev_idx = prev_idx[valid_target] next_idx = next_idx[valid_target] b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1] b_next, t_next = next_idx[:, 0], next_idx[:, 1] dt = (time_seqs[b_next, t_next] - time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25 valid_dt = dt > 0 if not valid_dt.any(): return None dt = dt[valid_dt] b_prev = b_prev[valid_dt] t_prev = t_prev[valid_dt] b_next = b_next[valid_dt] t_next = t_next[valid_dt] return dt, b_prev, t_prev, b_next, t_next # ============================================================ # Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization) # ============================================================ class ExponentialNLLLoss(nn.Module): """ Competing risks exponential likelihood. The negative log-likelihood is given by: .. math:: \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt Args: eps (float): Small epsilon for numerical stability. """ def __init__( self, lambda_reg: float = 0.0, eps: float = 1e-6, ): super().__init__() self.eps = eps self.lambda_reg = lambda_reg def forward( self, logits: torch.Tensor, target_events: torch.Tensor, dt: torch.Tensor, reduction: str = "mean", ) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass. Args: logits (torch.Tensor): (M, K) tensor of logits. target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K). dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive. reduction (str): 'mean', 'sum', or 'none'. Returns: Tuple[torch.Tensor, torch.Tensor]: (nll, regularization) where regularization is 0. """ logits = logits.squeeze(-1) if logits.dim() == 3 else logits hazards = F.softplus(logits) + self.eps # (M,K) hazard_event = hazards.gather( 1, target_events.unsqueeze(1)).squeeze(1) # (M,) total_hazard = hazards.sum(dim=1) # (M,) log_hazards = torch.log(hazards) # (M, K) nll = -torch.log(hazard_event) + total_hazard * dt if reduction == "mean": nll = nll.mean() elif reduction == "sum": nll = nll.sum() reg = F.cross_entropy(log_hazards, target_events, reduction="mean") * self.lambda_reg return nll, reg class PiecewiseExponentialLoss(nn.Module): """Piecewise-constant competing risks exponential likelihood. Uses B time bins defined by `bin_edges` (length B+1, strictly increasing, starting at 0). Within each bin b, hazards are constant and parameterized as: hazards = softplus(logits) + eps with logits shape (M, K, B) For each sample i, dt is bucketized to bin b* and the NLL is: nll_i = -log(hazard_{k*}(b*)) + \int_0^{dt} sum_k hazard_k(u) du The integral is computed in closed form by summing full bins plus the partial bin b*. """ def __init__( self, bin_edges: Sequence[float], eps: float = 1e-6, lambda_reg: float = 0.0, ): super().__init__() if len(bin_edges) < 2: raise ValueError("bin_edges must have length >= 2") if bin_edges[0] != 0: raise ValueError("bin_edges must start at 0") for i in range(1, len(bin_edges)): if not (bin_edges[i] > bin_edges[i - 1]): raise ValueError("bin_edges must be strictly increasing") self.eps = float(eps) self.lambda_reg = float(lambda_reg) edges = torch.tensor(list(bin_edges), dtype=torch.float32) self.register_buffer("bin_edges", edges, persistent=False) def forward( self, logits: torch.Tensor, target_events: torch.Tensor, dt: torch.Tensor, reduction: str = "mean", ) -> Tuple[torch.Tensor, torch.Tensor]: if logits.dim() != 3: raise ValueError("logits must have shape (M, K, B)") M, K, B = logits.shape if self.bin_edges.numel() != B + 1: raise ValueError( f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})" ) device = logits.device dt = dt.to(device=device) target_events = target_events.to(device=device) # Build a per-sample finite mask to avoid NaN/Inf propagation. logits_finite = torch.isfinite(logits).view(M, -1).all(dim=1) dt_finite = torch.isfinite(dt) target_finite = torch.isfinite(target_events) finite_mask = logits_finite & dt_finite & target_finite nll_full = logits.new_zeros((M,)) if not finite_mask.any(): nll_out = nll_full if reduction == "none" else logits.new_zeros(()) reg_out = logits.new_zeros(()) return nll_out, reg_out idx = finite_mask.nonzero(as_tuple=False).squeeze(1) logits_v = logits[idx] target_v = target_events[idx].to(torch.long) dt_v = dt[idx].to(torch.float32) # Clamp dt into [eps, max_edge) to keep bucket indices valid. eps = self.eps max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype) dt_max = torch.nextafter(max_edge, max_edge.new_zeros(())) dt_v = torch.clamp(dt_v, min=eps, max=dt_max) hazards = F.softplus(logits_v) + eps # (Mv, K, B) total_hazard = hazards.sum(dim=1) # (Mv, B) edges = self.bin_edges.to(device=device, dtype=dt_v.dtype) widths = edges[1:] - edges[:-1] # (B,) # Bin index b* in [0, B-1]. boundaries are edges[1:] (length B). b_star = torch.searchsorted(edges[1:], dt_v, right=False) # (Mv,) b_star = torch.clamp(b_star, min=0, max=B - 1) ar = torch.arange(logits_v.size(0), device=device) hazard_event = hazards[ar, target_v, b_star] # (Mv,) # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left) weighted = total_hazard * widths.unsqueeze(0) # (Mv, B) cum = weighted.cumsum(dim=1) # (Mv, B) full_bins_int = torch.zeros_like(dt_v) has_full = b_star > 0 if has_full.any(): full_bins_int[has_full] = cum.gather( 1, (b_star[has_full] - 1).unsqueeze(1) ).squeeze(1) edge_left = edges[b_star] # (Mv,) partial = total_hazard.gather( 1, b_star.unsqueeze(1)).squeeze(1) * (dt_v - edge_left) integral = full_bins_int + partial nll_v = -torch.log(hazard_event) + integral nll_full[idx] = nll_v if reduction == "none": nll_out = nll_full elif reduction == "sum": nll_out = nll_v.sum() elif reduction == "mean": nll_out = nll_v.mean() if nll_v.numel() > 0 else logits.new_zeros(()) else: raise ValueError("reduction must be one of: 'mean', 'sum', 'none'") reg = logits.new_zeros(()) if self.lambda_reg != 0.0: reg = reg + (self.lambda_reg * logits_v.pow(2).mean()) return nll_out, reg class WeibullNLLLoss(nn.Module): """ Weibull hazard in t. .. math:: \\Lambda_k(t) = \\text{scale}_k \\cdot t^{\\text{shape}_k} \\lambda_k(t) = \\text{shape}_k \\cdot \\text{scale}_k \\cdot t^{\\text{shape}_k-1} Args: eps (float): Small epsilon for numerical stability. lambda_reg (float): Regularization weight. use_interval_near_integer (bool): If True, use interval likelihood for near-integer-year samples. near_integer_eps_years (float): Near-integer threshold in years. interval_half_width_years (float): Half-width \u0394 for interval [t-\u0394, t+\u0394] in years. min_integer_year (float): Only apply near-integer logic when round(t) >= min_integer_year. """ def __init__( self, eps: float = 1e-6, lambda_reg: float = 0.0, ): super().__init__() self.eps = eps self.lambda_reg = lambda_reg def forward( self, logits: torch.Tensor, target_events: torch.Tensor, dt: torch.Tensor, reduction: str = "mean", ) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass. Args: logits (torch.Tensor): (M, K, 2) tensor of logits. target_events (torch.Tensor): (M,) tensor of target events. dt (torch.Tensor): (M,) tensor of time intervals. reduction (str): 'mean', 'sum', or 'none'. Returns: Tuple[torch.Tensor, torch.Tensor]: (nll, regularization). """ shapes = F.softplus(logits[..., 0]) + self.eps # (M,K) scales = F.softplus(logits[..., 1]) + self.eps # (M,K) eps = self.eps t = torch.clamp(dt, min=eps) t_mat = t.unsqueeze(1) # (M,1) # cumulative hazard (M,K) cum_hazard = scales * t_mat.pow(shapes) # hazard (M,K) hazard = shapes * scales * t_mat.pow(shapes - 1.0) hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1) # Point-event likelihood: f_k(t) = \lambda_k(t) * exp(-\Lambda_total(t)) # NLL_point = -log \lambda_{k*}(t) + \Lambda_total(t) nll = -torch.log(hazard_event + eps) + cum_hazard.sum(dim=1) if reduction == "mean": nll = nll.mean() elif reduction == "sum": nll = nll.sum() reg = shapes.new_zeros(()) if self.lambda_reg > 0: reg = self.lambda_reg * ( (torch.log(scales + eps) ** 2).mean() + (torch.log(shapes + eps) ** 2).mean() ) return nll, reg