import math
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================
def get_valid_pairs_and_dt(
    event_seqs: torch.Tensor,
    time_seqs: torch.Tensor,
    n_tech_tokens: int,
) -> Optional[
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
]:
    """
    Extract valid event pairs (prev -> next) and compute dt in years.

    A pair is formed from two consecutive real (token id >= 1) positions of
    the same sequence. It is kept only if the next token is a disease token
    (token id >= ``n_tech_tokens``) and the elapsed time is strictly positive.

    Args:
        event_seqs (torch.Tensor): Event token ids, indexed as
            ``[batch, position]``. Ids below 1 are treated as padding.
        time_seqs (torch.Tensor): Event times, indexed like ``event_seqs``.
            The time difference is divided by 365.25 to obtain years, so the
            input is presumably expressed in days.
        n_tech_tokens (int): Number of technical tokens; ids at or above this
            value count as disease tokens.

    Returns:
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
            ``(dt, b_prev, t_prev, b_next, t_next)`` if valid pairs exist,
            else ``None``. ``dt`` is float32; the other four are the batch and
            position indices of the previous and next event of each pair.

    Notes:
        - Assumes strict right-padding.
        - Filters to next events that are disease tokens:
          token_id >= n_tech_tokens.
        - Filters to strictly positive dt.
    """
    # nonzero() lists (batch, position) rows in row-major order, so adjacent
    # rows with the same batch index are consecutive real events.
    real_idx = (event_seqs >= 1).nonzero(as_tuple=False)
    if real_idx.size(0) <= 1:
        return None

    same_batch = real_idx[1:, 0] == real_idx[:-1, 0]
    if not same_batch.any():
        return None

    prev_idx = real_idx[:-1][same_batch]
    next_idx = real_idx[1:][same_batch]
    b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]

    dt = (
        time_seqs[b_next, t_next] - time_seqs[b_prev, t_prev]
    ).to(torch.float32) / 365.25

    # Both filters are elementwise, so applying them as one mask yields the
    # same pairs as filtering by target token first and by dt second.
    keep = (event_seqs[b_next, t_next] >= n_tech_tokens) & (dt > 0)
    if not keep.any():
        return None

    return dt[keep], b_prev[keep], t_prev[keep], b_next[keep], t_next[keep]


# ============================================================
# Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization)
# ============================================================
def _reduce_loss(values: torch.Tensor, reduction: str) -> torch.Tensor:
    """Apply 'mean' or 'sum' to per-sample values; anything else is a no-op."""
    if reduction == "mean":
        return values.mean()
    if reduction == "sum":
        return values.sum()
    return values


class ExponentialNLLLoss(nn.Module):
    """
    Competing risks exponential likelihood.

    The negative log-likelihood is given by:

    .. math::
        \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt

    Args:
        lambda_reg (float): Weight of the cross-entropy regularization term.
            With the default of 0 the term is skipped and a zero scalar is
            returned in its place.
        eps (float): Small epsilon for numerical stability.
    """

    def __init__(
        self,
        lambda_reg: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K) tensor of logits. A 3-D input has
                its last dimension squeezed first.
            target_events (torch.Tensor): (M,) int64 tensor of target events
                in [0, K).
            dt (torch.Tensor): (M,) float tensor of time intervals (years),
                strictly positive.
            reduction (str): 'mean', 'sum', or 'none'. Applies to the NLL
                only; any other value behaves like 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization).
            The regularization is ``lambda_reg`` times the mean cross-entropy
            of the log-hazards against ``target_events``; it is a zero scalar
            when ``lambda_reg`` is 0.
        """
        if logits.dim() == 3:
            logits = logits.squeeze(-1)

        # softplus keeps hazards positive; eps keeps the logs below finite.
        hazards = F.softplus(logits) + self.eps  # (M, K)
        hazard_event = hazards.gather(
            1, target_events.unsqueeze(1)
        ).squeeze(1)  # (M,)
        total_hazard = hazards.sum(dim=1)  # (M,)

        nll = _reduce_loss(-torch.log(hazard_event) + total_hazard * dt, reduction)

        # Only pay for the log + cross-entropy when the term is actually used.
        if self.lambda_reg != 0:
            reg = self.lambda_reg * F.cross_entropy(
                torch.log(hazards), target_events, reduction="mean"
            )
        else:
            reg = hazards.new_zeros(())
        return nll, reg


class WeibullNLLLoss(nn.Module):
    """
    Weibull hazard in t.

    .. math::
        \\Lambda_k(t) = \\text{scale}_k \\cdot t^{\\text{shape}_k}

        \\lambda_k(t) = \\text{shape}_k \\cdot \\text{scale}_k \\cdot t^{\\text{shape}_k-1}

    Args:
        eps (float): Small epsilon for numerical stability.
        lambda_reg (float): Regularization weight. The regularization term is
            only computed when this is strictly positive.
    """

    def __init__(
        self,
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K, 2) tensor of logits; index 0 of the
                last dimension parameterizes the shape, index 1 the scale.
            target_events (torch.Tensor): (M,) tensor of target events.
            dt (torch.Tensor): (M,) tensor of time intervals. Values below
                ``eps`` are clamped up to ``eps``.
            reduction (str): 'mean', 'sum', or 'none'. Applies to the NLL
                only; any other value behaves like 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization).
            The regularization is ``lambda_reg`` times the sum of the mean
            squared log-scale and mean squared log-shape, or a zero scalar
            when ``lambda_reg`` is not positive.
        """
        eps = self.eps
        shapes = F.softplus(logits[..., 0]) + eps  # (M, K)
        scales = F.softplus(logits[..., 1]) + eps  # (M, K)

        t = torch.clamp(dt, min=eps)  # (M,)

        # cumulative hazard (M, K)
        cum_hazard = scales * t.unsqueeze(1).pow(shapes)

        # Only the target event's hazard enters the likelihood, so pick its
        # parameters first instead of building the full (M, K) hazard matrix.
        event_col = target_events.unsqueeze(1)
        shape_event = shapes.gather(1, event_col).squeeze(1)  # (M,)
        scale_event = scales.gather(1, event_col).squeeze(1)  # (M,)
        hazard_event = shape_event * scale_event * t.pow(shape_event - 1.0)

        # Point-event likelihood: f_k(t) = \lambda_k(t) * exp(-\Lambda_total(t))
        # NLL_point = -log \lambda_{k*}(t) + \Lambda_total(t)
        nll = _reduce_loss(
            -torch.log(hazard_event + eps) + cum_hazard.sum(dim=1), reduction
        )

        reg = shapes.new_zeros(())
        if self.lambda_reg > 0:
            reg = self.lambda_reg * (
                (torch.log(scales + eps) ** 2).mean()
                + (torch.log(shapes + eps) ** 2).mean()
            )
        return nll, reg