import math from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F # ============================================================ # Pair extraction (utility; not used by the losses below) # ============================================================ def get_valid_pairs_and_dt( event_seqs: torch.Tensor, time_seqs: torch.Tensor, n_tech_tokens: int ) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """ Extract valid event pairs (prev -> next) and compute dt in years. Args: event_seqs (torch.Tensor): Event sequences. time_seqs (torch.Tensor): Time sequences. n_tech_tokens (int): Number of technical tokens. Returns: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: (dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None. Notes: - Assumes strict right-padding. - Filters to next events that are disease tokens: token_id >= n_tech_tokens. - Filters to strictly positive dt. """ real_mask = event_seqs >= 1 idx = real_mask.nonzero(as_tuple=False) if idx.size(0) <= 1: return None same_batch = idx[1:, 0] == idx[:-1, 0] if not same_batch.any(): return None prev_idx = idx[:-1][same_batch] next_idx = idx[1:][same_batch] b_next, t_next = next_idx[:, 0], next_idx[:, 1] valid_target = event_seqs[b_next, t_next] >= n_tech_tokens if not valid_target.any(): return None prev_idx = prev_idx[valid_target] next_idx = next_idx[valid_target] b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1] b_next, t_next = next_idx[:, 0], next_idx[:, 1] dt = (time_seqs[b_next, t_next] - time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25 valid_dt = dt > 0 if not valid_dt.any(): return None dt = dt[valid_dt] b_prev = b_prev[valid_dt] t_prev = t_prev[valid_dt] b_next = b_next[valid_dt] t_next = t_next[valid_dt] return dt, b_prev, t_prev, b_next, t_next # ============================================================ # Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization) # ============================================================ class ExponentialNLLLoss(nn.Module): """ Competing risks exponential likelihood. The negative log-likelihood is given by: .. math:: \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt Args: eps (float): Small epsilon for numerical stability. """ def __init__( self, lambda_reg: float = 0.0, eps: float = 1e-6, ): super().__init__() self.eps = eps self.lambda_reg = lambda_reg def forward( self, logits: torch.Tensor, target_events: torch.Tensor, dt: torch.Tensor, reduction: str = "mean", ) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass. Args: logits (torch.Tensor): (M, K) tensor of logits. target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K). dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive. reduction (str): 'mean', 'sum', or 'none'. Returns: Tuple[torch.Tensor, torch.Tensor]: (nll, regularization) where regularization is 0. """ logits = logits.squeeze(-1) if logits.dim() == 3 else logits hazards = F.softplus(logits) + self.eps # (M,K) hazard_event = hazards.gather( 1, target_events.unsqueeze(1)).squeeze(1) # (M,) total_hazard = hazards.sum(dim=1) # (M,) log_hazards = torch.log(hazards) # (M, K) nll = -torch.log(hazard_event) + total_hazard * dt if reduction == "mean": nll = nll.mean() elif reduction == "sum": nll = nll.sum() reg = F.cross_entropy(log_hazards, target_events, reduction="mean") * self.lambda_reg return nll, reg class DiscreteTimeCIFNLLLoss(nn.Module): """Direct discrete-time CIF negative log-likelihood (no censoring). This loss assumes the model outputs per-bin logits over (K causes + 1 complement) channels, where the complement channel (index K) represents survival across bins. Per-sample likelihood for observed cause k at time bin j: p = \\prod_{u=1}^{j-1} p(comp at u) * p(k at j) Args: bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0. eps: Unused; kept for interface compatibility / future numerical tweaks. lambda_reg: Optional regularization strength. """ def __init__( self, bin_edges: Sequence[float], eps: float = 1e-6, lambda_reg: float = 0.0, ): super().__init__() if len(bin_edges) < 2: raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)") if float(bin_edges[0]) != 0.0: raise ValueError("bin_edges[0] must equal 0") for i in range(1, len(bin_edges)): if not (float(bin_edges[i]) > float(bin_edges[i - 1])): raise ValueError("bin_edges must be strictly increasing") self.eps = float(eps) self.lambda_reg = float(lambda_reg) self.register_buffer( "bin_edges", torch.tensor(bin_edges, dtype=torch.float32), persistent=False, ) def forward( self, logits: torch.Tensor, target_events: torch.Tensor, dt: torch.Tensor, reduction: str = "mean", ) -> Tuple[torch.Tensor, torch.Tensor]: if logits.ndim != 3: raise ValueError( f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}" ) if target_events.ndim != 1 or dt.ndim != 1: raise ValueError( f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}" ) if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]: raise ValueError( "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]" ) if reduction not in {"mean", "sum", "none"}: raise ValueError("reduction must be one of {'mean','sum','none'}") if not torch.all(dt > 0): raise ValueError("dt must be strictly positive") # Infer K and n_bins from logits and bin_edges. m, k_plus_1, n_bins_plus_1 = logits.shape k_comp = k_plus_1 - 1 if k_comp < 1: raise ValueError( "logits.shape[1] must be at least 2 (K>=1 plus complement channel)") n_bins = int(self.bin_edges.numel() - 1) if n_bins_plus_1 != n_bins + 1: raise ValueError( f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}" ) if target_events.dtype != torch.long: target_events = target_events.to(torch.long) if (target_events < 0).any() or (target_events >= k_comp).any(): raise ValueError( f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}" ) # Map continuous dt to discrete bins j in {1..n_bins}. bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype) # (M,), may be n_bins+1 if dt > bin_edges[-1] time_bin = torch.bucketize(dt, bin_edges) time_bin = torch.clamp(time_bin, min=1, max=n_bins).to( torch.long) # ensure valid event bins # Log-probabilities across causes+complement for each bin. logp = F.log_softmax(logits, dim=1) # (M, K+1, n_bins+1) # Previous survival term: sum_{u=1}^{j-1} -log p(comp at u) bins = torch.arange(n_bins + 1, device=logits.device) # (n_bins+1,) mask = (bins.unsqueeze(0) >= 1) & (bins.unsqueeze( 0) < time_bin.unsqueeze(1)) # (M, n_bins+1) logp_comp = logp[:, k_comp, :] # (M, n_bins+1) loss_prev = -(logp_comp * mask.to(logp_comp.dtype)).sum(dim=1) # (M,) # Event term at bin j: -log p(k at j) m_idx = torch.arange(m, device=logits.device) loss_event = -logp[m_idx, target_events, time_bin] # (M,) loss = loss_prev + loss_event if reduction == "mean": nll = loss.mean() elif reduction == "sum": nll = loss.sum() else: nll = loss reg = torch.zeros((), device=logits.device, dtype=loss.dtype) if self.lambda_reg > 0.0: # Regularize the cause distribution at the event bin using NLL on log-probs. logp_causes = logp[:, :k_comp, :] # (M, K, n_bins+1) idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1) logp_at_event_bin = logp_causes.gather( dim=2, index=idx).squeeze(2) # (M, K) reg = self.lambda_reg * \ F.nll_loss(logp_at_event_bin, target_events, reduction="mean") return nll, reg class LogNormalBasisBinnedHazardCIFNLLLoss(nn.Module): r"""Route-3: continuous-time lognormal-basis hazards with discrete-time CIF likelihood. This implements a cause-specific continuous-time hazard model: \lambda_j(t) = \sum_r \alpha_{j,r} b_r(t) where b_r(t) is the lognormal PDF basis implied by a Normal on log-time. Training objective is IDENTICAL in structure to DiscreteTimeCIFNLLLoss, but per-bin categorical probabilities are derived from integrated hazards. Expected logits interface (preferred): logits: (B, J*R) reshaped to (B, J, R) For convenience/compatibility, also accepts: logits: (B, 1+J*R) and ignores the first column. Forward interface (must match): forward(logits, target_events, dt, reduction) -> (nll, reg) """ def __init__( self, bin_edges: Sequence[float], centers: Sequence[float], *, eps: float = 1e-8, alpha_floor: float = 0.0, bandwidth_init: float = 0.7, bandwidth_min: float = 1e-3, bandwidth_max: float = 10.0, lambda_sigma_reg: float = 0.0, sigma_reg_target: Optional[float] = None, lambda_reg: float = 0.0, ): super().__init__() if len(bin_edges) < 2: raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)") if float(bin_edges[0]) != 0.0: raise ValueError("bin_edges[0] must equal 0") # allow last edge to be +inf for i in range(1, len(bin_edges)): prev = float(bin_edges[i - 1]) cur = float(bin_edges[i]) if math.isinf(prev): raise ValueError( "bin_edges cannot have +inf except possibly as the last edge") if i == len(bin_edges) - 1 and math.isinf(cur): if not (cur > prev): raise ValueError("bin_edges must be strictly increasing") else: if not (cur > prev): raise ValueError("bin_edges must be strictly increasing") if len(centers) < 1: raise ValueError("centers must have length >= 1") self.eps = float(eps) self.alpha_floor = float(alpha_floor) self.bandwidth_min = float(bandwidth_min) self.bandwidth_max = float(bandwidth_max) self.bandwidth_init = float(bandwidth_init) self.lambda_sigma_reg = float(lambda_sigma_reg) self.sigma_reg_target = None if sigma_reg_target is None else float( sigma_reg_target) self.lambda_reg = float(lambda_reg) self.register_buffer( "bin_edges", torch.tensor([float(x) for x in bin_edges], dtype=torch.float32), persistent=False, ) self.register_buffer( "centers", torch.tensor([float(x) for x in centers], dtype=torch.float32), persistent=False, ) if self.bandwidth_init <= 0: raise ValueError("bandwidth_init must be > 0") self.log_sigma = nn.Parameter(torch.tensor( math.log(self.bandwidth_init), dtype=torch.float32)) @staticmethod def _normal_cdf(z: torch.Tensor) -> torch.Tensor: z = torch.clamp(z, -12.0, 12.0) return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0))) @staticmethod def _normal_sf(z: torch.Tensor) -> torch.Tensor: z = torch.clamp(z, -12.0, 12.0) return 0.5 * torch.erfc(z / math.sqrt(2.0)) def _compute_delta_basis_all_bins( self, *, device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: """Compute ΔB[k,r] for bins k=1..n_bins (shape: (n_bins, R)).""" bin_edges = self.bin_edges.to(device=device, dtype=dtype) centers = self.centers.to(device=device, dtype=dtype) n_bins = int(bin_edges.numel() - 1) if n_bins < 1: raise ValueError("bin_edges must define at least one bin") left = bin_edges[:-1] # (n_bins,) right = bin_edges[1:] # (n_bins,) if float(self.bin_edges[1]) > 0.0: t_min = float(self.bin_edges[1]) * 1e-6 else: t_min = 1e-12 t_min_t = torch.tensor(t_min, device=device, dtype=dtype) left_is_zero = left <= 0 left_clamped = torch.clamp(left, min=t_min_t) log_left = torch.log(left_clamped) is_inf = torch.isinf(right) right_safe = torch.where( is_inf, left_clamped, torch.clamp(right, min=t_min_t)) log_right = torch.log(right_safe) sigma = torch.clamp( self.log_sigma.to(device=device, dtype=dtype).exp(), self.bandwidth_min, self.bandwidth_max, ) z_left = (log_left.unsqueeze(-1) - centers.unsqueeze(0)) / sigma z_right = (log_right.unsqueeze(-1) - centers.unsqueeze(0)) / sigma z_left = torch.clamp(z_left, -12.0, 12.0) z_right = torch.clamp(z_right, -12.0, 12.0) cdf_left = self._normal_cdf(z_left) if left_is_zero.any(): cdf_left = torch.where( left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left) cdf_right = self._normal_cdf(z_right) delta_finite = cdf_right - cdf_left # Last bin: ΔB = 1 - CDF(left) = SF(left), computed via erfc for stability. delta_last = self._normal_sf(z_left) if left_is_zero.any(): delta_last = torch.where( left_is_zero.unsqueeze(-1), torch.ones_like(delta_last), delta_last) delta_basis = torch.where( is_inf.unsqueeze(-1), delta_last, delta_finite) delta_basis = torch.clamp(delta_basis, min=0.0) return delta_basis def forward( self, logits: torch.Tensor, target_events: torch.Tensor, dt: torch.Tensor, reduction: str = "mean", ) -> Tuple[torch.Tensor, torch.Tensor]: if logits.ndim not in {2, 3}: raise ValueError( f"logits must be 2D (B, J*R) (or (B, 1+J*R)) or 3D (B, J, R); got {tuple(logits.shape)}" ) if target_events.ndim != 1 or dt.ndim != 1: raise ValueError( f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}" ) if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]: raise ValueError( "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]" ) if reduction not in {"mean", "sum", "none"}: raise ValueError("reduction must be one of {'mean','sum','none'}") if not torch.all(dt > 0): raise ValueError("dt must be strictly positive") device = logits.device dtype = logits.dtype centers = self.centers.to(device=device, dtype=dtype) r = int(centers.numel()) if r < 1: raise ValueError("centers must have length >= 1") if logits.ndim == 3: if logits.shape[2] != r: raise ValueError( f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}" ) j = int(logits.shape[1]) if j < 1: raise ValueError("Inferred number of causes J must be >= 1") alpha = F.softplus(logits) + self.alpha_floor # (B, J, R) else: d = int(logits.shape[1]) offset = 0 if d % r == 0: jr = d elif (d - 1) % r == 0: offset = 1 jr = d - 1 else: raise ValueError( f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}" ) j = jr // r if j < 1: raise ValueError("Inferred number of causes J must be >= 1") logits_used = logits[:, offset:] alpha = F.softplus(logits_used).view(-1, j, r) + \ self.alpha_floor # (B, J, R) delta_basis = self._compute_delta_basis_all_bins( device=device, dtype=dtype, ) # (n_bins, R) n_bins = int(delta_basis.shape[0]) # H_{j,k} = sum_r alpha_{j,r} * ΔB_{k,r} h_jk = torch.einsum("mjr,kr->mjk", alpha, delta_basis) # (B, J, n_bins) h_k = h_jk.sum(dim=1) # (B, n_bins) # Map continuous dt to discrete event bin index k* in {1..n_bins}. bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype) time_bin = torch.bucketize(dt, bin_edges) time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(torch.long) cause = target_events.to(device=device, dtype=torch.long) if (cause < 0).any() or (cause >= j).any(): raise ValueError(f"target_events must be in [0, J-1] where J={j}") # Previous survival term: sum_{u 0.0: # Regularize the within-bin cause competition via NLL on log ratios. # ratio_j = H_{j,k*} / H_{k*} h_event_all = h_jk[b_idx, :, k0] # (B, J) denom = torch.clamp(h_event_total, min=self.eps).unsqueeze(1) ratio = torch.clamp(h_event_all / denom, min=self.eps) log_ratio = torch.log(ratio) reg = reg + self.lambda_reg * F.nll_loss( log_ratio, cause, reduction="mean") if self.lambda_sigma_reg > 0.0: target = self.bandwidth_init if self.sigma_reg_target is None else self.sigma_reg_target sigma_penalty = (self.log_sigma.to( device=device, dtype=dtype) - math.log(float(target))) ** 2 reg = reg + sigma_penalty * self.lambda_sigma_reg return nll, reg