import math
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# Shared helpers
# ============================================================

def _check_reduction(reduction: str) -> None:
    """Raise ValueError unless ``reduction`` is 'mean', 'sum', or 'none'."""
    if reduction not in {"mean", "sum", "none"}:
        raise ValueError("reduction must be one of {'mean','sum','none'}")


def _reduce(values: torch.Tensor, reduction: str) -> torch.Tensor:
    """Apply an already-validated reduction ('mean', 'sum', 'none') to per-sample values."""
    if reduction == "mean":
        return values.mean()
    if reduction == "sum":
        return values.sum()
    return values


def _prepare_taus(taus, batch_size: int, device, dtype) -> Tuple[torch.Tensor, str, bool]:
    """Normalize evaluation times to a 2D tensor.

    Returns:
        (t, kind, scalar_out) where
          - kind == "T":          t is (1, T), a time grid shared by all samples
                                  (a scalar tau gives T == 1 and scalar_out=True);
          - kind == "per_sample": t is (M, 1), one time per sample (1D taus of length M);
          - kind == "MT":         t is (M, T), a separate grid per sample.

    Note:
        A 1D ``taus`` whose length equals M is always read as per-sample times,
        even if it was meant as a shared grid with T == M.

    Raises:
        ValueError: if taus is 2D with a first dim != M, or has ndim > 2.
    """
    t = torch.as_tensor(taus, device=device, dtype=dtype)
    if t.ndim == 0:
        return t.view(1, 1), "T", True
    if t.ndim == 1:
        if t.shape[0] == batch_size:
            return t.view(batch_size, 1), "per_sample", False
        return t.view(1, -1), "T", False
    if t.ndim == 2:
        if t.shape[0] != batch_size:
            raise ValueError(
                f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
            )
        return t, "MT", False
    raise ValueError(
        f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")


def _finalize_cif_outputs(
    cifs: torch.Tensor,
    survival: torch.Tensor,
    kind: str,
    scalar_out: bool,
    return_survival: bool,
):
    """Drop the trailing time axis for scalar / per-sample taus and select the return value."""
    if kind == "per_sample" or scalar_out:
        cifs = cifs.squeeze(-1)  # (M, K)
        survival = survival.squeeze(-1)  # (M,)
    return (cifs, survival) if return_survival else cifs


# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================

def get_valid_pairs_and_dt(
    event_seqs: torch.Tensor, time_seqs: torch.Tensor, n_tech_tokens: int
) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Extract valid event pairs (prev -> next) and compute dt in years.

    Args:
        event_seqs (torch.Tensor): (B, L) event token sequences; token id 0 is
            treated as padding (only ids >= 1 count as real events).
        time_seqs (torch.Tensor): (B, L) event times. dt is divided by 365.25,
            so times are presumably in days — confirm against the caller.
        n_tech_tokens (int): Number of technical tokens; ids >= n_tech_tokens
            are disease tokens.

    Returns:
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
            (dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None.

    Notes:
        - Assumes strict right-padding.
        - Pairs are consecutive real events within the same sequence.
        - Filters to next events that are disease tokens: token_id >= n_tech_tokens.
        - Filters to strictly positive dt.
    """
    real_mask = event_seqs >= 1
    # nonzero() yields (batch, position) rows in row-major order, so adjacent
    # rows are consecutive real events whenever they share a batch index.
    idx = real_mask.nonzero(as_tuple=False)
    if idx.size(0) <= 1:
        return None

    same_batch = idx[1:, 0] == idx[:-1, 0]
    prev_idx = idx[:-1][same_batch]
    next_idx = idx[1:][same_batch]
    if prev_idx.size(0) == 0:
        return None

    b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]

    dt = (time_seqs[b_next, t_next] - time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25
    keep = (event_seqs[b_next, t_next] >= n_tech_tokens) & (dt > 0)
    if not keep.any():
        return None

    return dt[keep], b_prev[keep], t_prev[keep], b_next[keep], t_next[keep]


# ============================================================
# Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization)
# ============================================================

class ExponentialNLLLoss(nn.Module):
    """
    Competing risks exponential likelihood.

    Each sample has constant cause-specific hazards
    lambda_k = softplus(logit_k) + eps. The negative log-likelihood of
    observing cause k* after time dt is:

    .. math::

        \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt

    Args:
        lambda_reg (float): Weight of an optional regularizer, the cross-entropy
            of the cause fractions lambda_k / sum_j lambda_j against the target
            cause. 0 disables it.
        eps (float): Small epsilon added to the hazards for numerical stability.
    """

    def __init__(
        self,
        lambda_reg: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K) tensor of logits (or (M, K, 1)).
            target_events (torch.Tensor): (M,) integer tensor of target events in [0, K).
            dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
            reduction (str): 'mean', 'sum', or 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization); the
            regularization is a 0-dim zero when lambda_reg == 0.

        Raises:
            ValueError: if reduction is not one of 'mean', 'sum', 'none'.
        """
        _check_reduction(reduction)
        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
        target_events = target_events.to(torch.long)  # gather requires int64

        hazards = F.softplus(logits) + self.eps  # (M, K)
        hazard_event = hazards.gather(1, target_events.unsqueeze(1)).squeeze(1)  # (M,)
        total_hazard = hazards.sum(dim=1)  # (M,)

        nll = _reduce(-torch.log(hazard_event) + total_hazard * dt, reduction)

        if self.lambda_reg != 0.0:
            # softmax(log lambda) == lambda / sum(lambda): cross-entropy on cause fractions.
            reg = F.cross_entropy(torch.log(hazards), target_events, reduction="mean") * self.lambda_reg
        else:
            reg = torch.zeros((), device=logits.device, dtype=hazards.dtype)
        return nll, reg

    def calculate_cifs(
        self,
        logits: torch.Tensor,
        taus: torch.Tensor,
        eps: Optional[float] = None,
        return_survival: bool = False,
    ):
        """Compute CIFs for a competing-risks exponential model.

        Model assumptions:
            - cause-specific hazards are constant in time within a sample.
            - hazards are obtained via softplus(logits) + eps.

        F_k(tau) = (lambda_k / Lambda) * (1 - exp(-Lambda * tau)),  S(tau) = exp(-Lambda * tau).

        Args:
            logits: (M, K) or (M, K, 1) tensor.
            taus: scalar, (T,), (M,), or (M, T) times; negative values are clamped to 0.
            eps: overrides self.eps for numerical stability.
            return_survival: if True, also return survival S(tau).

        Returns:
            cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
            survival (optional): (M,) if taus is scalar or (M,), else (M, T).

        Raises:
            ValueError: on unsupported logits / taus shapes.
        """
        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
        if logits.ndim != 2:
            raise ValueError(
                f"logits must be 2D (M,K) (or 3D with last dim 1); got shape={tuple(logits.shape)}")
        M = logits.shape[0]
        used_eps = float(self.eps if eps is None else eps)

        hazards = F.softplus(logits) + used_eps  # (M, K)
        total_hazard = torch.clamp(hazards.sum(dim=1, keepdim=True), min=used_eps)  # (M, 1)
        frac = hazards / total_hazard  # (M, K)

        taus_t, kind, scalar_out = _prepare_taus(taus, M, logits.device, hazards.dtype)
        taus_t = torch.clamp(taus_t, min=0)  # (1, T), (M, 1) or (M, T)

        neg_cum_hazard = -total_hazard * taus_t  # broadcasts to (M, T) / (M, 1)
        survival = torch.exp(neg_cum_hazard)
        # -expm1(-x) == 1 - exp(-x) without cancellation for small x.
        cifs = frac.unsqueeze(-1) * (-torch.expm1(neg_cum_hazard)).unsqueeze(1)  # (M, K, T)

        return _finalize_cif_outputs(cifs, survival, kind, scalar_out, return_survival)


class DiscreteTimeCIFNLLLoss(nn.Module):
    """Direct discrete-time CIF negative log-likelihood (no censoring).

    This loss assumes the model outputs per-bin logits over (K causes + 1 complement)
    channels, where the complement channel (index K) represents survival across bins.
    Bin index 0 of the logits is never used; event bins are 1..n_bins.

    Per-sample likelihood for observed cause k at time bin j:
        p = \\prod_{u=1}^{j-1} p(comp at u) * p(k at j)

    Args:
        bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0.
            Bin j covers (bin_edges[j-1], bin_edges[j]].
        eps: Unused; kept for interface compatibility / future numerical tweaks.
        lambda_reg: Optional regularization strength (cross-entropy of the
            conditional cause distribution at the event bin).
    """

    def __init__(
        self,
        bin_edges: Sequence[float],
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()
        if len(bin_edges) < 2:
            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
        if float(bin_edges[0]) != 0.0:
            raise ValueError("bin_edges[0] must equal 0")
        for i in range(1, len(bin_edges)):
            if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
                raise ValueError("bin_edges must be strictly increasing")

        self.eps = float(eps)
        self.lambda_reg = float(lambda_reg)
        self.register_buffer(
            "bin_edges",
            torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
            persistent=False,
        )

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Discrete-time CIF NLL.

        Args:
            logits: (M, K+1, n_bins+1) per-bin logits; channel K is the complement.
            target_events: (M,) integer causes in [0, K-1].
            dt: (M,) strictly positive event times. Mapped to bin j with
                bin_edges[j-1] < dt <= bin_edges[j]; times past the last edge
                fall into the last bin.
            reduction: 'mean', 'sum', or 'none'.

        Returns:
            (nll, reg); reg is a 0-dim zero when lambda_reg == 0.

        Raises:
            ValueError: on shape mismatches, out-of-range targets, non-positive dt,
                or an unknown reduction.
        """
        if logits.ndim != 3:
            raise ValueError(
                f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
            )
        if target_events.ndim != 1 or dt.ndim != 1:
            raise ValueError(
                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
            )
        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
            raise ValueError(
                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
            )
        _check_reduction(reduction)
        if not torch.all(dt > 0):
            raise ValueError("dt must be strictly positive")

        # Infer K and n_bins from logits and bin_edges.
        m, k_plus_1, n_bins_plus_1 = logits.shape
        k_comp = k_plus_1 - 1
        if k_comp < 1:
            raise ValueError(
                "logits.shape[1] must be at least 2 (K>=1 plus complement channel)")
        n_bins = int(self.bin_edges.numel() - 1)
        if n_bins_plus_1 != n_bins + 1:
            raise ValueError(
                f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
            )

        if target_events.dtype != torch.long:
            target_events = target_events.to(torch.long)
        if (target_events < 0).any() or (target_events >= k_comp).any():
            raise ValueError(
                f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}"
            )
        target_events = target_events.to(logits.device)

        # Map continuous dt to discrete bins j in {1..n_bins}; bucketize returns
        # n_bins+1 for dt beyond the last edge, which is clamped into the last bin.
        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
        time_bin = torch.bucketize(dt, bin_edges)
        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(
            device=logits.device, dtype=torch.long)  # (M,)

        # Log-probabilities across causes+complement for each bin.
        logp = F.log_softmax(logits, dim=1)  # (M, K+1, n_bins+1)

        # Previous survival term: -sum_{u=1}^{j-1} log p(comp at u).
        bins = torch.arange(n_bins + 1, device=logits.device)  # (n_bins+1,)
        before_event = (bins.unsqueeze(0) >= 1) & (bins.unsqueeze(0) < time_bin.unsqueeze(1))
        logp_comp = logp[:, k_comp, :]  # (M, n_bins+1)
        # masked_fill (not a 0/1 multiply) so a -inf outside the mask cannot produce NaN.
        loss_prev = -logp_comp.masked_fill(~before_event, 0.0).sum(dim=1)  # (M,)

        # Event term at bin j: -log p(k at j).
        m_idx = torch.arange(m, device=logits.device)
        loss_event = -logp[m_idx, target_events, time_bin]  # (M,)

        loss = loss_prev + loss_event
        nll = _reduce(loss, reduction)

        reg = torch.zeros((), device=logits.device, dtype=loss.dtype)
        if self.lambda_reg > 0.0:
            # Cross-entropy of the conditional cause distribution p(k | event at j):
            # cross_entropy renormalizes over the K cause channels. (An nll_loss on
            # the K+1-normalized log-probs would just duplicate loss_event.)
            logp_causes = logp[:, :k_comp, :]  # (M, K, n_bins+1)
            idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1)
            logp_at_event_bin = logp_causes.gather(dim=2, index=idx).squeeze(2)  # (M, K)
            reg = self.lambda_reg * \
                F.cross_entropy(logp_at_event_bin, target_events, reduction="mean")

        return nll, reg

    def calculate_cifs(
        self,
        logits: torch.Tensor,
        taus: torch.Tensor,
        eps: Optional[float] = None,
        return_survival: bool = False,
    ):
        """Compute discrete-time CIFs implied by per-bin (K causes + complement) logits.

        This matches the likelihood used in forward():
            p(event=cause k at bin j) = Π_{u=1}^{j-1} p(comp at u) * p(k at j)

        A tau in (bin_edges[j-1], bin_edges[j]] returns the CIF through the end of
        bin j (step function); tau <= 0 gives CIF 0 / survival 1; tau beyond the
        last edge is clamped to the last bin.

        Args:
            logits: (M, K+1, n_bins+1) where channel K is complement.
            taus: scalar, (T,), (M,), or (M,T) continuous times.
            eps: unused (kept for signature compatibility).
            return_survival: if True, also return survival probability up to the mapped bin.

        Returns:
            cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
            survival (optional): (M,) if taus is scalar or (M,), else (M, T).

        Raises:
            ValueError: on unsupported logits / taus shapes.
        """
        if logits.ndim != 3:
            raise ValueError(
                f"logits must have shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
            )
        M, k_plus_1, n_bins_plus_1 = logits.shape
        K = k_plus_1 - 1
        if K < 1:
            raise ValueError(
                "logits.shape[1] must be at least 2 (K>=1 plus complement)")
        n_bins = int(self.bin_edges.numel() - 1)
        if n_bins_plus_1 != n_bins + 1:
            raise ValueError(
                f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
            )

        # Probabilities over causes+complement per bin (bin 0 dropped).
        probs = F.softmax(logits, dim=1)  # (M, K+1, n_bins+1)
        p_causes = probs[:, :K, 1:]  # (M, K, n_bins)
        p_comp = probs[:, K, 1:]  # (M, n_bins)

        surv_end = torch.cumprod(p_comp, dim=1)  # survival to end of bins 1..n_bins
        ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
        surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1)  # (M, n_bins)
        cif_full = torch.cumsum(surv_start.unsqueeze(1) * p_causes, dim=2)  # (M, K, n_bins)

        taus_t, kind, scalar_out = _prepare_taus(taus, M, logits.device, surv_end.dtype)
        taus_t = torch.clamp(taus_t, min=0)

        bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
        time_bin = torch.clamp(torch.bucketize(taus_t, bin_edges), min=0, max=n_bins).to(torch.long)
        if kind == "T":
            time_bin = time_bin.expand(M, -1)  # (1, T) -> (M, T)

        idx = torch.clamp(time_bin - 1, min=0)  # (M, T)
        cifs = cif_full.gather(dim=2, index=idx.unsqueeze(1).expand(-1, K, -1))  # (M, K, T)
        survival = surv_end.gather(dim=1, index=idx)  # (M, T)

        # tau mapped to bin 0 => CIF=0, survival=1
        zero_mask = time_bin == 0
        if zero_mask.any():
            cifs = cifs.masked_fill(zero_mask.unsqueeze(1), 0.0)
            survival = survival.masked_fill(zero_mask, 1.0)

        return _finalize_cif_outputs(cifs, survival, kind, scalar_out, return_survival)


class PiecewiseExponentialCIFNLLLoss(nn.Module):
    """
    Piecewise-Exponential (PWE) cause-specific hazards with discrete-time CIF likelihood.

    Within bin j = (bin_edges[j-1], bin_edges[j]] of width w_j, cause k has a
    constant hazard lambda_{k,j} = softplus(logit_{k,j}) + eps, and
    Lambda_j = sum_k lambda_{k,j}. For an event of cause k* whose time falls in
    bin j the per-sample NLL is the negative log-probability of cause k*
    occurring somewhere inside bin j:

        sum_{u<j} Lambda_u * w_u - log(lambda_{k*,j} / Lambda_j) - log(1 - exp(-Lambda_j * w_j))

    - No censoring.
    - Optional smoothness regularization: lambda_reg * mean of the squared second
      differences of log-hazards across bins (active when lambda_reg > 0 and n_bins >= 3).
    - Forward signature matches DiscreteTimeCIFNLLLoss:
        forward(logits, target_events, dt, reduction) -> (nll, reg)

    Expected shapes:
      logits: (M, K, n_bins)        # hazard logits per cause per bin
      target_events: (M,) long in [0, K-1]
      dt: (M,) event times (strictly > 0)
      bin_edges: length n_bins+1, strictly increasing, bin_edges[0]==0, and MUST be finite
                at the last edge (no +inf) for PWE.
    """

    def __init__(
        self,
        bin_edges: Sequence[float],
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()
        if len(bin_edges) < 2:
            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
        if float(bin_edges[0]) != 0.0:
            raise ValueError("bin_edges[0] must equal 0.0")
        for i in range(1, len(bin_edges)):
            if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
                raise ValueError("bin_edges must be strictly increasing")
        if math.isinf(float(bin_edges[-1])):
            raise ValueError(
                "PiecewiseExponentialCIFNLLLoss requires a finite last bin edge (no +inf). "
                "Use a finite truncation horizon for PWE."
            )

        self.eps = float(eps)  # added to softplus(logits) so hazards stay strictly positive
        self.lambda_reg = float(lambda_reg)  # weight of the log-hazard smoothness penalty
        self.register_buffer(
            "bin_edges",
            torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
            persistent=False,
        )

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PWE discrete-time CIF NLL.

        Args:
            logits: (M, K, n_bins) hazard logits.
            target_events: (M,) integer causes in [0, K-1].
            dt: (M,) strictly positive event times; mapped to bin j with
                bin_edges[j-1] < dt <= bin_edges[j], clamped into [1, n_bins].
            reduction: 'mean', 'sum', or 'none'.

        Returns:
            (nll, reg); reg is a 0-dim zero unless lambda_reg > 0 and n_bins >= 3.

        Raises:
            ValueError: on shape mismatches, out-of-range targets, non-positive dt,
                invalid bin widths, or an unknown reduction.
        """
        _check_reduction(reduction)
        if logits.ndim != 3:
            raise ValueError(
                f"logits must be 3D (M, K, n_bins); got shape={tuple(logits.shape)}")
        if target_events.ndim != 1 or dt.ndim != 1:
            raise ValueError("target_events and dt must be 1D tensors")
        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
            raise ValueError(
                "Batch size mismatch among logits, target_events, dt")
        if not torch.all(dt > 0):
            raise ValueError(
                "dt must be strictly positive (no censoring supported here)")

        M, K, n_bins = logits.shape

        if target_events.dtype != torch.long:
            target_events = target_events.to(torch.long)
        if (target_events < 0).any() or (target_events >= K).any():
            raise ValueError(f"target_events must be in [0, {K-1}]")
        target_events = target_events.to(logits.device)

        # Prepare bin_edges / bin widths
        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
        if bin_edges.numel() != n_bins + 1:
            raise ValueError(
                f"bin_edges length must be n_bins+1={n_bins+1}; got {bin_edges.numel()}"
            )
        dt_bins = (bin_edges[1:] - bin_edges[:-1]
                   ).to(device=logits.device, dtype=logits.dtype)  # (n_bins,)
        if not torch.isfinite(dt_bins).all():
            raise ValueError("All bin widths must be finite for PWE.")
        if not (dt_bins > 0).all():
            raise ValueError(
                "All bin widths must be strictly positive for PWE.")

        # Map event time -> 1-based bin j in {1..n_bins} (same convention as
        # DiscreteTimeCIFNLLLoss), then to a 0-based index.
        time_bin = torch.clamp(torch.bucketize(dt, bin_edges), min=1, max=n_bins)
        k0 = (time_bin - 1).to(device=logits.device, dtype=torch.long)  # (M,)

        # Nonnegative hazards per cause per bin
        hazards = F.softplus(logits) + self.eps  # (M, K, n_bins)
        total_hazard = hazards.sum(dim=1)  # (M, n_bins) Lambda_j
        H_k = total_hazard * dt_bins.view(1, n_bins)  # (M, n_bins) Lambda_j * w_j

        # Previous survival term: -log S(start of bin j) = sum_{u<j} Lambda_u * w_u
        bins = torch.arange(n_bins, device=logits.device)
        before_event = bins.unsqueeze(0) < k0.unsqueeze(1)  # (M, n_bins)
        loss_prev = H_k.masked_fill(~before_event, 0.0).sum(dim=1)  # (M,)

        # Event term: -log[(lambda_{k*,j} / Lambda_j) * (1 - exp(-Lambda_j * w_j))]
        m_idx = torch.arange(M, device=logits.device)
        hazard_event = hazards[m_idx, target_events, k0]  # (M,)
        total_hazard_event = total_hazard[m_idx, k0]  # (M,)
        H_event = H_k[m_idx, k0]  # (M,)
        tiny = torch.finfo(H_event.dtype).tiny
        # -expm1(-H) == 1 - exp(-H) without cancellation for small H.
        log_p_in_bin = torch.log(torch.clamp(-torch.expm1(-H_event), min=tiny))
        loss_event = -(torch.log(hazard_event) - torch.log(total_hazard_event) + log_p_in_bin)

        loss_vec = loss_prev + loss_event  # (M,)
        nll = _reduce(loss_vec, reduction)

        if self.lambda_reg > 0.0 and n_bins >= 3:
            # Smoothness penalty: squared second differences of log-hazards across bins.
            log_h = torch.log(hazards)  # (M, K, n_bins)
            d2 = log_h[:, :, 2:] - 2.0 * log_h[:, :, 1:-1] + \
                log_h[:, :, :-2]  # (M, K, n_bins-2)
            reg = self.lambda_reg * (d2.pow(2).mean())
        else:
            reg = torch.zeros((), device=logits.device, dtype=loss_vec.dtype)

        return nll, reg

    def calculate_cifs(
        self,
        logits: torch.Tensor,
        taus: torch.Tensor,
        eps: Optional[float] = None,
        return_survival: bool = False,
    ):
        """Compute CIFs for piecewise-constant cause-specific hazards.

        Uses the same binning convention as forward(): taus are mapped to a bin via
        torch.bucketize(taus, bin_edges), clamped to [0, n_bins]. tau<=0 maps to 0.
        Within a bin the CIF is evaluated exactly at tau (continuous, not a step).
        For tau beyond the last edge, the final bin's hazards are extended.

        Args:
            logits: (M, K, n_bins) hazard logits per cause per bin.
            taus: scalar, (T,), (M,), or (M,T) times.
            eps: overrides self.eps for numerical stability.
            return_survival: if True, also return survival S(tau).

        Returns:
            cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
            survival (optional): (M,) if taus is scalar or (M,), else (M, T).

        Raises:
            ValueError: on unsupported logits / taus shapes.
        """
        if logits.ndim != 3:
            raise ValueError(
                f"logits must be 3D (M,K,n_bins); got shape={tuple(logits.shape)}")
        M, K, n_bins = logits.shape
        if self.bin_edges.numel() != n_bins + 1:
            raise ValueError(
                f"bin_edges length must be n_bins+1={n_bins+1}; got {self.bin_edges.numel()}"
            )
        used_eps = float(self.eps if eps is None else eps)

        taus_t, kind, scalar_out = _prepare_taus(taus, M, logits.device, logits.dtype)
        taus_t = torch.clamp(taus_t, min=0)

        bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
        dt_bins = (bin_edges[1:] - bin_edges[:-1]
                   ).to(device=logits.device, dtype=logits.dtype)  # (n_bins,)

        hazards = F.softplus(logits) + used_eps  # (M, K, n_bins)
        total_h = torch.clamp(hazards.sum(dim=1), min=used_eps)  # (M, n_bins)

        # Full-bin quantities.
        H_total_bin = total_h * dt_bins.view(1, n_bins)  # (M, n_bins)
        surv_end = torch.exp(-torch.cumsum(H_total_bin, dim=1))  # (M, n_bins)
        ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
        surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1)  # (M, n_bins)
        frac = hazards / total_h.unsqueeze(1)  # (M, K, n_bins)
        one_minus = -torch.expm1(-H_total_bin)  # (M, n_bins) == 1 - exp(-H)
        inc_full = surv_start.unsqueeze(1) * frac * one_minus.unsqueeze(1)  # (M, K, n_bins)
        cif_full = torch.cumsum(inc_full, dim=2)  # (M, K, n_bins)

        # Map taus -> bin index b in [0..n_bins]
        time_bin = torch.clamp(torch.bucketize(taus_t, bin_edges), min=0, max=n_bins).to(torch.long)
        if kind == "T":
            time_bin = time_bin.expand(M, -1)  # (M, T)
        b = time_bin
        idx_bin0 = torch.clamp(b - 1, min=0)  # 0..n_bins-1

        # Start-of-bin survival for the current bin (unused for b == 0).
        S_start_b = surv_start.gather(dim=1, index=idx_bin0)  # (M, T)

        # Length into bin: l = tau - edge[b-1], clamped to [0, width_b].
        tau_full = taus_t.expand_as(b)  # (M, T)
        left_edge = bin_edges.gather(
            dim=0, index=idx_bin0.reshape(-1)).view_as(idx_bin0).to(taus_t.dtype)
        width_b = dt_bins.gather(dim=0, index=idx_bin0.reshape(-1)).view_as(idx_bin0)
        l = torch.clamp(tau_full - left_edge, min=0)
        l = torch.min(l, width_b.to(l.dtype))

        # CIF through the previous full bins: 0 if b <= 1, else cif_full[..., b-2].
        prev_idx = torch.clamp(b - 2, min=0)
        cif_before = cif_full.gather(
            dim=2, index=prev_idx.unsqueeze(1).expand(-1, K, -1))  # (M, K, T)
        cif_before = cif_before.masked_fill((b <= 1).unsqueeze(1), 0.0)

        # Partial increment in the current bin.
        total_h_b = total_h.gather(dim=1, index=idx_bin0)  # (M, T)
        haz_b = hazards.gather(
            dim=2, index=idx_bin0.unsqueeze(1).expand(-1, K, -1))  # (M, K, T)
        frac_b = haz_b / total_h_b.unsqueeze(1)  # (M, K, T)
        one_minus_partial = -torch.expm1(-total_h_b * l)  # (M, T)
        cifs = cif_before + S_start_b.unsqueeze(1) * frac_b * one_minus_partial.unsqueeze(1)
        survival = S_start_b * torch.exp(-total_h_b * l)  # (M, T)

        # Inference-only tail extension beyond the last finite edge t_B, using
        # the constant hazards of the final bin B:
        #   S(tau)   = S(t_B) * exp(-Λ_B * (tau - t_B))
        #   F_k(tau) = F_k(t_B) + S(t_B) * (λ_{k,B}/Λ_B) * (1 - exp(-Λ_B*(tau-t_B)))
        last_edge = bin_edges[-1]
        tail_mask = tau_full > last_edge
        if tail_mask.any():
            delta = torch.clamp(tau_full - last_edge, min=0)  # (M, T)
            S_B = surv_end[:, -1].unsqueeze(1)  # (M, 1)
            F_B = cif_full[:, :, -1].unsqueeze(-1)  # (M, K, 1)
            lambda_last = hazards[:, :, -1]  # (M, K)
            Lambda_last = torch.clamp(total_h[:, -1], min=used_eps).unsqueeze(1)  # (M, 1)
            neg_tail = -Lambda_last * delta  # (M, T)
            survival_tail = S_B * torch.exp(neg_tail)
            cifs_tail = F_B + S_B.unsqueeze(1) * (lambda_last / Lambda_last).unsqueeze(-1) \
                * (-torch.expm1(neg_tail)).unsqueeze(1)
            survival = torch.where(tail_mask, survival_tail, survival)
            cifs = torch.where(tail_mask.unsqueeze(1), cifs_tail, cifs)

        # tau mapped to bin 0 => CIF=0, survival=1
        zero_mask = b == 0
        if zero_mask.any():
            cifs = cifs.masked_fill(zero_mask.unsqueeze(1), 0.0)
            survival = survival.masked_fill(zero_mask, 1.0)

        return _finalize_cif_outputs(cifs, survival, kind, scalar_out, return_survival)