import math
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================
def get_valid_pairs_and_dt(
    event_seqs: torch.Tensor,
    time_seqs: torch.Tensor,
    n_tech_tokens: int
) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Extract valid (prev -> next) event pairs and compute dt in years.

    Args:
        event_seqs (torch.Tensor): (B, L) integer token ids, right-padded
            (padding ids < 1 are excluded by the real-token mask).
        time_seqs (torch.Tensor): (B, L) event times aligned with event_seqs,
            in days (dt is converted to years via /365.25).
        n_tech_tokens (int): Number of technical (non-disease) tokens; ids
            >= n_tech_tokens are disease tokens.

    Returns:
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
            (dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None.

    Notes:
        - Assumes strict right-padding.
        - Keeps only pairs whose next event is a disease token: token_id >= n_tech_tokens.
        - Keeps only pairs with strictly positive dt.
    """
    real_mask = event_seqs >= 1
    idx = real_mask.nonzero(as_tuple=False)  # (N, 2) rows of (batch, position)

    if idx.size(0) <= 1:
        return None

    # Consecutive real tokens within the same batch row form (prev, next) pairs.
    same_batch = idx[1:, 0] == idx[:-1, 0]
    if not same_batch.any():
        return None

    prev_idx = idx[:-1][same_batch]
    next_idx = idx[1:][same_batch]

    # Keep only pairs whose next event is a disease token.
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]
    valid_target = event_seqs[b_next, t_next] >= n_tech_tokens
    if not valid_target.any():
        return None

    prev_idx = prev_idx[valid_target]
    next_idx = next_idx[valid_target]

    b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]

    # Convert the waiting time from days to years.
    dt = (time_seqs[b_next, t_next] -
          time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25
    valid_dt = dt > 0
    if not valid_dt.any():
        return None

    dt = dt[valid_dt]
    b_prev = b_prev[valid_dt]
    t_prev = t_prev[valid_dt]
    b_next = b_next[valid_dt]
    t_next = t_next[valid_dt]

    return dt, b_prev, t_prev, b_next, t_next
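

# A minimal usage sketch (illustrative only, not part of the training
# pipeline): toy right-padded sequences with padding id 0, times in days, and
# n_tech_tokens=4 so ids >= 4 count as disease tokens. The shapes and values
# below are assumptions chosen only to demonstrate the call.
def _example_get_valid_pairs_and_dt() -> None:
    event_seqs = torch.tensor([[2, 5, 6, 0],
                               [1, 7, 0, 0]])
    time_seqs = torch.tensor([[0.0, 100.0, 500.0, 0.0],
                              [0.0, 30.0, 0.0, 0.0]])
    out = get_valid_pairs_and_dt(event_seqs, time_seqs, n_tech_tokens=4)
    assert out is not None
    dt, b_prev, t_prev, b_next, t_next = out
    # dt is in years, e.g. (100 - 0) / 365.25 for the first pair.
    print(dt, b_prev, t_prev, b_next, t_next)
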

# ============================================================
# Losses (clean interface): loss_fn(logits, target_events, dt) -> (nll, regularization)
# ============================================================
class ExponentialNLLLoss(nn.Module):
    """
    Competing-risks exponential likelihood.

    Each cause k has a constant hazard lambda_k = softplus(logit_k) + eps.
    For an observed cause k^* after waiting time dt, the negative
    log-likelihood is:

    .. math::
        \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt

    Args:
        lambda_reg (float): Weight of the cross-entropy regularizer on the
            normalized cause distribution (0 disables its effect).
        eps (float): Small epsilon for numerical stability.
    """

    def __init__(
        self,
        lambda_reg: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K) tensor of logits (a trailing
                singleton dimension is squeezed away).
            target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
            dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
            reduction (str): 'mean', 'sum', or 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization), where the
            regularization is the mean cross-entropy of the normalized cause
            distribution scaled by lambda_reg (zero when lambda_reg == 0).
        """
        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
        hazards = F.softplus(logits) + self.eps  # (M, K)
        hazard_event = hazards.gather(
            1, target_events.unsqueeze(1)).squeeze(1)  # (M,)
        total_hazard = hazards.sum(dim=1)  # (M,)
        log_hazards = torch.log(hazards)  # (M, K)
        nll = -torch.log(hazard_event) + total_hazard * dt

        if reduction == "mean":
            nll = nll.mean()
        elif reduction == "sum":
            nll = nll.sum()
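
        # Note: F.cross_entropy applies log_softmax to its input, and
        # softmax(log_hazards) == hazards / hazards.sum(dim=1, keepdim=True),
        # so this term is the NLL of the normalized cause distribution
        # lambda_k / sum_j lambda_j.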
        reg = F.cross_entropy(log_hazards, target_events,
                              reduction="mean") * self.lambda_reg
        return nll, reg

    def calculate_cifs(
        self,
        logits: torch.Tensor,
        taus: torch.Tensor,
        eps: Optional[float] = None,
        return_survival: bool = False,
    ):
        """Compute CIFs for a competing-risks exponential model.

        Model assumptions:
        - Cause-specific hazards are constant in time within a sample.
        - Hazards are obtained via softplus(logits) + eps.

        With total hazard Lambda = sum_k lambda_k, the closed forms are
        S(tau) = exp(-Lambda * tau) and
        F_k(tau) = (lambda_k / Lambda) * (1 - exp(-Lambda * tau)).

        Args:
            logits: (M, K) or (M, K, 1) tensor.
            taus: scalar, (T,), (M,), or (M, T) times (>=0 recommended).
            eps: overrides self.eps for numerical stability.
            return_survival: if True, also return survival S(tau).

        Returns:
            cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
            survival (optional): (M,) if taus is scalar or (M,), else (M, T).
        """

        def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
            t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
            scalar_out = False
            kind = "T"  # one of: 'T', 'per_sample', 'MT'
            if t.ndim == 0:
                t = t.view(1, 1)  # (1, 1)
                scalar_out = True
                kind = "T"
            elif t.ndim == 1:
                if t.shape[0] == batch_size:
                    t = t.view(batch_size, 1)  # (M, 1)
                    kind = "per_sample"
                else:
                    t = t.view(1, -1)  # (1, T)
                    kind = "T"
            elif t.ndim == 2:
                if t.shape[0] != batch_size:
                    raise ValueError(
                        f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
                    )
                kind = "MT"
            else:
                raise ValueError(
                    f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
            return t, kind, scalar_out

        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
        if logits.ndim != 2:
            raise ValueError(
                f"logits must be 2D (M,K) (or 3D with last dim 1); got shape={tuple(logits.shape)}")

        M, K = logits.shape
        used_eps = float(self.eps if eps is None else eps)

        hazards = F.softplus(logits) + used_eps  # (M, K)
        total_hazard = hazards.sum(dim=1, keepdim=True)  # (M, 1)
        total_hazard = torch.clamp(total_hazard, min=used_eps)

        # Probability that the first event is of cause k: lambda_k / Lambda.
        frac = hazards / total_hazard  # (M, K)

        taus_t, kind, scalar_out = _prepare_taus(
            taus, M, logits.device, hazards.dtype)
        taus_t = torch.clamp(taus_t, min=0)

        if kind == "T":
            # taus_t: (1, T), broadcast against (M, 1) total hazards.
            exp_term = 1.0 - torch.exp(-total_hazard * taus_t)  # (M, T)
            cifs = frac.unsqueeze(-1) * exp_term.unsqueeze(1)  # (M, K, T)
            survival = torch.exp(-total_hazard * taus_t)  # (M, T)
        else:
            # taus_t: (M, 1) or (M, T)
            exp_term = 1.0 - torch.exp(-total_hazard * taus_t)  # (M, 1) or (M, T)
            cifs = frac.unsqueeze(-1) * exp_term.unsqueeze(1)  # (M, K, 1) or (M, K, T)
            survival = torch.exp(-total_hazard * taus_t)  # (M, 1) or (M, T)

        if kind == "per_sample" or scalar_out:
            cifs = cifs.squeeze(-1)  # (M, K)
            survival = survival.squeeze(-1)  # (M,)

        return (cifs, survival) if return_survival else cifs
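

# A small illustrative check (assumed toy numbers, not fixtures from the
# original codebase): with one sample and constant hazards the loss follows
# nll = -log(lambda_{k*}) + Lambda * dt, and calculate_cifs returns the
# closed-form CIFs F_k(tau) = (lambda_k / Lambda) * (1 - exp(-Lambda * tau)).
def _example_exponential_nll() -> None:
    loss_fn = ExponentialNLLLoss(lambda_reg=0.0)
    logits = torch.tensor([[0.5, -1.0, 0.0]])   # (M=1, K=3)
    target_events = torch.tensor([0])           # observed cause k* = 0
    dt = torch.tensor([2.0])                    # years
    nll, reg = loss_fn(logits, target_events, dt)
    cifs, survival = loss_fn.calculate_cifs(
        logits, taus=torch.tensor([1.0, 5.0]), return_survival=True)
    print(nll.item(), reg.item(), cifs.shape, survival.shape)  # (1,3,2), (1,2)
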
class DiscreteTimeCIFNLLLoss(nn.Module):
    """Direct discrete-time CIF negative log-likelihood (no censoring).

    This loss assumes the model outputs per-bin logits over (K causes + 1 complement)
    channels, where the complement channel (index K) represents surviving the bin.

    Per-sample likelihood for observed cause k at time bin j:

        p = \\prod_{u=1}^{j-1} p(comp at u) * p(k at j)

    Note that bin 0 of the time axis is never an event bin (events map to
    bins 1..n_bins), so column 0 of the logits does not enter the likelihood.

    Args:
        bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0.
        eps: Unused; kept for interface compatibility / future numerical tweaks.
        lambda_reg: Weight of an optional NLL regularizer on the cause
            distribution at the event bin (0 disables it).
    """

    def __init__(
        self,
        bin_edges: Sequence[float],
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()

        if len(bin_edges) < 2:
            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
        if float(bin_edges[0]) != 0.0:
            raise ValueError("bin_edges[0] must equal 0")
        for i in range(1, len(bin_edges)):
            if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
                raise ValueError("bin_edges must be strictly increasing")

        self.eps = float(eps)
        self.lambda_reg = float(lambda_reg)
        self.register_buffer(
            "bin_edges",
            torch.tensor(bin_edges, dtype=torch.float32),
            persistent=False,
        )

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if logits.ndim != 3:
            raise ValueError(
                f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
            )
        if target_events.ndim != 1 or dt.ndim != 1:
            raise ValueError(
                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
            )
        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
            raise ValueError(
                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
            )
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError("reduction must be one of {'mean','sum','none'}")

        if not torch.all(dt > 0):
            raise ValueError("dt must be strictly positive")

        # Infer K and n_bins from logits and bin_edges.
        m, k_plus_1, n_bins_plus_1 = logits.shape
        k_comp = k_plus_1 - 1
        if k_comp < 1:
            raise ValueError(
                "logits.shape[1] must be at least 2 (K>=1 plus complement channel)")

        n_bins = int(self.bin_edges.numel() - 1)
        if n_bins_plus_1 != n_bins + 1:
            raise ValueError(
                f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
            )

        if target_events.dtype != torch.long:
            target_events = target_events.to(torch.long)

        if (target_events < 0).any() or (target_events >= k_comp).any():
            raise ValueError(
                f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}"
            )

        # Map continuous dt to discrete bins j in {1..n_bins}.
        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
        # bucketize may return n_bins+1 if dt > bin_edges[-1]; clamp back
        # to the valid event bins.
        time_bin = torch.bucketize(dt, bin_edges)
        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(torch.long)
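
        # Binning convention (illustrative): with bin_edges = [0, 1, 5],
        # torch.bucketize maps dt=0.5 -> bin 1, dt=1.0 -> bin 1, dt=3.0 -> bin 2,
        # and dt=7.0 -> 3, which the clamp above folds back into bin 2.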

        # Log-probabilities across causes+complement for each bin.
        logp = F.log_softmax(logits, dim=1)  # (M, K+1, n_bins+1)

        # Previous survival term: sum_{u=1}^{j-1} -log p(comp at u)
        bins = torch.arange(n_bins + 1, device=logits.device)  # (n_bins+1,)
        mask = (bins.unsqueeze(0) >= 1) & (
            bins.unsqueeze(0) < time_bin.unsqueeze(1))  # (M, n_bins+1)
        logp_comp = logp[:, k_comp, :]  # (M, n_bins+1)
        loss_prev = -(logp_comp * mask.to(logp_comp.dtype)).sum(dim=1)  # (M,)

        # Event term at bin j: -log p(k at j)
        m_idx = torch.arange(m, device=logits.device)
        loss_event = -logp[m_idx, target_events, time_bin]  # (M,)

        loss = loss_prev + loss_event

        if reduction == "mean":
            nll = loss.mean()
        elif reduction == "sum":
            nll = loss.sum()
        else:
            nll = loss

        reg = torch.zeros((), device=logits.device, dtype=loss.dtype)
        if self.lambda_reg > 0.0:
            # Regularize the cause distribution at the event bin using NLL on log-probs.
            logp_causes = logp[:, :k_comp, :]  # (M, K, n_bins+1)
            idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1)
            logp_at_event_bin = logp_causes.gather(
                dim=2, index=idx).squeeze(2)  # (M, K)
            reg = self.lambda_reg * \
                F.nll_loss(logp_at_event_bin, target_events, reduction="mean")

        return nll, reg

    def calculate_cifs(
        self,
        logits: torch.Tensor,
        taus: torch.Tensor,
        eps: Optional[float] = None,
        return_survival: bool = False,
    ):
        """Compute discrete-time CIFs implied by per-bin (K causes + complement) logits.

        This matches the likelihood used in forward():

            p(event=cause k at bin j) = Π_{u=1}^{j-1} p(comp at u) * p(k at j)

        Args:
            logits: (M, K+1, n_bins+1) where channel K is complement.
            taus: scalar, (T,), (M,), or (M,T) continuous times.
            eps: unused (kept for signature compatibility).
            return_survival: if True, also return survival probability up to the mapped bin.

        Returns:
            cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
            survival (optional): (M,) if taus is scalar or (M,), else (M, T).
        """

        def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
            t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
            scalar_out = False
            kind = "T"
            if t.ndim == 0:
                t = t.view(1, 1)
                scalar_out = True
                kind = "T"
            elif t.ndim == 1:
                if t.shape[0] == batch_size:
                    t = t.view(batch_size, 1)
                    kind = "per_sample"
                else:
                    t = t.view(1, -1)
                    kind = "T"
            elif t.ndim == 2:
                if t.shape[0] != batch_size:
                    raise ValueError(
                        f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
                    )
                kind = "MT"
            else:
                raise ValueError(
                    f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
            return t, kind, scalar_out

        if logits.ndim != 3:
            raise ValueError(
                f"logits must have shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
            )

        M, k_plus_1, n_bins_plus_1 = logits.shape
        K = k_plus_1 - 1
        if K < 1:
            raise ValueError(
                "logits.shape[1] must be at least 2 (K>=1 plus complement)")

        n_bins = int(self.bin_edges.numel() - 1)
        if n_bins_plus_1 != n_bins + 1:
            raise ValueError(
                f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
            )

        # Probabilities over causes+complement per bin (bin 0 is unused).
        probs = F.softmax(logits, dim=1)  # (M, K+1, n_bins+1)
        p_causes = probs[:, :K, 1:]  # (M, K, n_bins)
        p_comp = probs[:, K, 1:]  # (M, n_bins)

        # Survival up to the end of each bin (1..n_bins).
        surv_end = torch.cumprod(p_comp, dim=1)  # (M, n_bins)
        ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
        surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1)  # (M, n_bins)

        # CIF increment in bin j: P(survive to bin j) * p(cause k at bin j).
        inc = surv_start.unsqueeze(1) * p_causes  # (M, K, n_bins)
        cif_full = torch.cumsum(inc, dim=2)  # (M, K, n_bins)

        taus_t, kind, scalar_out = _prepare_taus(
            taus, M, logits.device, surv_end.dtype)
        taus_t = torch.clamp(taus_t, min=0)

        bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
        time_bin = torch.bucketize(taus_t, bin_edges)
        time_bin = torch.clamp(time_bin, min=0, max=n_bins).to(torch.long)

        if kind == "T":
            # (1,T) -> expand to (M,T); per_sample gives (M,1), MT gives (M,T).
            time_bin = time_bin.expand(M, -1)

        idx = torch.clamp(time_bin - 1, min=0)  # (M,T)

        gathered_cif = cif_full.gather(
            dim=2,
            index=idx.unsqueeze(1).expand(-1, K, -1),
        )  # (M,K,T)
        gathered_surv = surv_end.gather(dim=1, index=idx)  # (M,T)

        # tau mapped to bin 0 => CIF=0, survival=1
        zero_mask = (time_bin == 0)
        if zero_mask.any():
            gathered_cif = gathered_cif.masked_fill(zero_mask.unsqueeze(1), 0.0)
            gathered_surv = gathered_surv.masked_fill(zero_mask, 1.0)

        if kind == "per_sample" or scalar_out:
            gathered_cif = gathered_cif.squeeze(-1)  # (M,K)
            gathered_surv = gathered_surv.squeeze(-1)  # (M,)

        return (gathered_cif, gathered_surv) if return_survival else gathered_cif
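

# A sanity-check sketch with assumed toy numbers: the per-bin probabilities
# telescope, so at the last bin edge sum_k F_k(t_end) + S(t_end) == 1.
def _example_discrete_time_cif() -> None:
    loss_fn = DiscreteTimeCIFNLLLoss(bin_edges=[0.0, 1.0, 5.0])  # n_bins = 2
    logits = torch.randn(4, 3, 3)  # (M=4, K=2 causes + complement, n_bins+1=3)
    target_events = torch.tensor([0, 1, 0, 1])
    dt = torch.tensor([0.5, 2.0, 4.0, 0.1])
    nll, reg = loss_fn(logits, target_events, dt)
    cifs, survival = loss_fn.calculate_cifs(
        logits, taus=torch.tensor(5.0), return_survival=True)  # (4, 2), (4,)
    total = cifs.sum(dim=1) + survival
    assert torch.allclose(total, torch.ones_like(total), atol=1e-5)
    print(nll.item(), reg.item())
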
class PiecewiseExponentialCIFNLLLoss(nn.Module):
    """
    Piecewise-Exponential (PWE) cause-specific hazards with discrete-time CIF likelihood.

    - No censoring.
    - Optional regularization: when lambda_reg > 0, a second-difference
      smoothness penalty on log-hazards across bins is returned as reg
      (otherwise reg is 0).
    - Forward signature matches DiscreteTimeCIFNLLLoss:
      forward(logits, target_events, dt, reduction) -> (nll, reg)

    Expected shapes:
        logits: (M, K, n_bins)      # hazard logits per cause per bin
        target_events: (M,) long in [0, K-1]
        dt: (M,) event times (strictly > 0)

    bin_edges:
        length n_bins+1, strictly increasing, bin_edges[0]==0,
        and MUST be finite at the last edge (no +inf) for PWE.
    """

    def __init__(
        self,
        bin_edges: Sequence[float],
        eps: float = 1e-6,
        lambda_reg: float = 0.0,  # weight of the log-hazard smoothness penalty
    ):
        super().__init__()

        if len(bin_edges) < 2:
            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
        if float(bin_edges[0]) != 0.0:
            raise ValueError("bin_edges[0] must equal 0.0")
        for i in range(1, len(bin_edges)):
            if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
                raise ValueError("bin_edges must be strictly increasing")
        if math.isinf(float(bin_edges[-1])):
            raise ValueError(
                "PiecewiseExponentialCIFNLLLoss requires a finite last bin edge (no +inf). "
                "Use a finite truncation horizon for PWE."
            )

        self.eps = float(eps)
        self.lambda_reg = float(lambda_reg)

        self.register_buffer(
            "bin_edges",
            torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
            persistent=False,
        )

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError("reduction must be one of {'mean','sum','none'}")

        if logits.ndim != 3:
            raise ValueError(
                f"logits must be 3D (M, K, n_bins); got shape={tuple(logits.shape)}")
        if target_events.ndim != 1 or dt.ndim != 1:
            raise ValueError("target_events and dt must be 1D tensors")
        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
            raise ValueError(
                "Batch size mismatch among logits, target_events, dt")

        if not torch.all(dt > 0):
            raise ValueError(
                "dt must be strictly positive (no censoring supported here)")

        M, K, n_bins = logits.shape

        if target_events.dtype != torch.long:
            target_events = target_events.to(torch.long)
        if (target_events < 0).any() or (target_events >= K).any():
            raise ValueError(f"target_events must be in [0, {K-1}]")

        # Prepare bin edges / bin widths.
        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
        if bin_edges.numel() != n_bins + 1:
            raise ValueError(
                f"bin_edges length must be n_bins+1={n_bins+1}; got {bin_edges.numel()}"
            )

        dt_bins = (bin_edges[1:] - bin_edges[:-1]
                   ).to(device=logits.device, dtype=logits.dtype)  # (n_bins,)
        if not torch.isfinite(dt_bins).all():
            raise ValueError("All bin widths must be finite for PWE.")
        if not (dt_bins > 0).all():
            raise ValueError(
                "All bin widths must be strictly positive for PWE.")

        # Map event time -> bin index k* in {1..n_bins}
        # (same clamping convention as DiscreteTimeCIFNLLLoss).
        time_bin = torch.bucketize(dt, bin_edges)
        time_bin = torch.clamp(
            time_bin, min=1, max=n_bins).to(torch.long)  # (M,)
        k0 = time_bin - 1  # 0..n_bins-1

        # Nonnegative hazards per cause per bin.
        hazards = F.softplus(logits) + self.eps  # (M, K, n_bins)

        # Integrated hazards H_{j,k} = lambda_{j,k} * Δt_k
        H_jk = hazards * dt_bins.view(1, 1, n_bins)  # (M, K, n_bins)
        H_k = H_jk.sum(dim=1)  # (M, n_bins)

        # Previous survival term: Σ_{u<k*} H_u
        bins = torch.arange(
            1, n_bins + 1, device=logits.device).unsqueeze(0)  # (1, n_bins)
        mask_prev = bins < time_bin.unsqueeze(1)  # (M, n_bins)
        loss_prev = (H_k * mask_prev.to(H_k.dtype)).sum(dim=1)  # (M,)

        # Event term at k*: -log p_{k*}(cause)
        m_idx = torch.arange(M, device=logits.device)

        H_event_total = torch.clamp(H_k[m_idx, k0], min=self.eps)  # (M,)
        H_event_cause = torch.clamp(
            H_jk[m_idx, target_events, k0], min=self.eps)  # (M,)
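
        # Event-term derivation (for reference): the probability of an event
        # in bin k* with cause k is (1 - exp(-H_total)) * (H_cause / H_total),
        # i.e. "some event happens in the bin" times "it is cause k", so
        # -log p = -log(1 - exp(-H_total)) - log(H_cause) + log(H_total).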
        # log(1 - exp(-H)) computed stably via expm1.
        log1mexp = torch.log(-torch.expm1(-H_event_total))  # (M,)
        loss_event = -log1mexp - \
            torch.log(H_event_cause) + torch.log(H_event_total)

        loss_vec = loss_prev + loss_event  # (M,)

        if reduction == "mean":
            nll = loss_vec.mean()
        elif reduction == "sum":
            nll = loss_vec.sum()
        else:
            nll = loss_vec

        if self.lambda_reg > 0.0 and n_bins >= 3:
            # Smoothness penalty: second difference of log-hazards across bins.
            log_h = torch.log(hazards)  # (M, K, n_bins)
            d2 = log_h[:, :, 2:] - 2.0 * log_h[:, :, 1:-1] + \
                log_h[:, :, :-2]  # (M, K, n_bins-2)
            reg = self.lambda_reg * (d2.pow(2).mean())
        else:
            reg = torch.zeros((), device=logits.device, dtype=loss_vec.dtype)

        return nll, reg

    def calculate_cifs(
        self,
        logits: torch.Tensor,
        taus: torch.Tensor,
        eps: Optional[float] = None,
        return_survival: bool = False,
    ):
        """Compute CIFs for piecewise-constant cause-specific hazards.

        Uses the same binning convention as forward(): taus are mapped to a bin via
        torch.bucketize(taus, bin_edges), clamped to [0, n_bins]. tau <= 0 maps to 0.
        Within the mapped bin, the CIF accrues exactly for the elapsed fraction of
        the bin; beyond the last finite edge, the final bin's hazards are
        extrapolated (see the tail-extension block below).

        Args:
            logits: (M, K, n_bins) hazard logits per cause per bin.
            taus: scalar, (T,), (M,), or (M,T) times.
            eps: overrides self.eps for numerical stability.
            return_survival: if True, also return survival S(tau).

        Returns:
            cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
            survival (optional): (M,) if taus is scalar or (M,), else (M, T).
        """

        def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
            t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
            scalar_out = False
            kind = "T"
            if t.ndim == 0:
                t = t.view(1, 1)
                scalar_out = True
                kind = "T"
            elif t.ndim == 1:
                if t.shape[0] == batch_size:
                    t = t.view(batch_size, 1)
                    kind = "per_sample"
                else:
                    t = t.view(1, -1)
                    kind = "T"
            elif t.ndim == 2:
                if t.shape[0] != batch_size:
                    raise ValueError(
                        f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
                    )
                kind = "MT"
            else:
                raise ValueError(
                    f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
            return t, kind, scalar_out
        if logits.ndim != 3:
            raise ValueError(
                f"logits must be 3D (M,K,n_bins); got shape={tuple(logits.shape)}")

        M, K, n_bins = logits.shape
        if self.bin_edges.numel() != n_bins + 1:
            raise ValueError(
                f"bin_edges length must be n_bins+1={n_bins+1}; got {self.bin_edges.numel()}"
            )

        used_eps = float(self.eps if eps is None else eps)

        taus_t, kind, scalar_out = _prepare_taus(
            taus, M, logits.device, logits.dtype)
        taus_t = torch.clamp(taus_t, min=0)

        bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
        dt_bins = (bin_edges[1:] - bin_edges[:-1]
                   ).to(device=logits.device, dtype=logits.dtype)  # (n_bins,)

        hazards = F.softplus(logits) + used_eps  # (M, K, n_bins)
        total_h = hazards.sum(dim=1)  # (M, n_bins)
        total_h = torch.clamp(total_h, min=used_eps)

        # Precompute full-bin CIF increments.
        H_total_bin = total_h * dt_bins.view(1, n_bins)  # (M, n_bins)
        cum_H_end = torch.cumsum(H_total_bin, dim=1)  # (M, n_bins)
        surv_end = torch.exp(-cum_H_end)  # (M, n_bins)
        ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
        surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1)  # (M, n_bins)

        frac = hazards / total_h.unsqueeze(1)  # (M, K, n_bins)
        one_minus = 1.0 - \
            torch.exp(-total_h * dt_bins.view(1, n_bins))  # (M, n_bins)
        inc_full = surv_start.unsqueeze(
            1) * frac * one_minus.unsqueeze(1)  # (M, K, n_bins)
        cif_full = torch.cumsum(inc_full, dim=2)  # (M, K, n_bins)

        # Map taus -> bin index b in [0..n_bins]
        time_bin = torch.bucketize(taus_t, bin_edges)
        time_bin = torch.clamp(time_bin, min=0, max=n_bins).to(torch.long)

        if kind == "T":
            time_bin = time_bin.expand(M, -1)  # (M,T)

        # Compute within-bin length l and indices.
        b = time_bin  # (M,T)
        idx_bin0 = torch.clamp(b - 1, min=0)  # 0..n_bins-1

        # Start-of-bin survival for the current bin (for b==0 it's unused).
        S_start_b = surv_start.gather(dim=1, index=idx_bin0)  # (M,T)

        # Length into bin: l = tau - edge[b-1], clamped to [0, dt_bin]
        left_edge = bin_edges.gather(
            dim=0, index=idx_bin0.view(-1)).view_as(idx_bin0).to(taus_t.dtype)
        l = taus_t.expand_as(b) - left_edge
        l = torch.clamp(l, min=0)
        width_b = dt_bins.gather(
            dim=0, index=idx_bin0.view(-1)).view_as(idx_bin0)
        l = torch.min(l, width_b.to(l.dtype))

        # CIF up to previous full bins: 0 if b <= 1, else cif_full at (b-2).
        prev_idx = torch.clamp(b - 2, min=0)
        cif_before = cif_full.gather(
            dim=2,
            index=prev_idx.unsqueeze(1).expand(-1, K, -1),
        )  # (M,K,T)
        if (b <= 1).any():
            cif_before = cif_before.masked_fill((b <= 1).unsqueeze(1), 0.0)

        # Partial increment in the current bin.
        total_h_b = total_h.gather(dim=1, index=idx_bin0)  # (M,T)
        haz_b = hazards.gather(
            dim=2,
            index=idx_bin0.unsqueeze(1).expand(-1, K, -1),
        )  # (M,K,T)
        frac_b = haz_b / total_h_b.unsqueeze(1)  # (M,K,T)

        one_minus_partial = 1.0 - torch.exp(-total_h_b * l)  # (M,T)
        inc_partial = S_start_b.unsqueeze(
            1) * frac_b * one_minus_partial.unsqueeze(1)  # (M,K,T)

        cifs = cif_before + inc_partial

        survival = S_start_b * torch.exp(-total_h_b * l)  # (M,T)

        # Inference-only tail extension beyond the last finite edge.
        # For tau > t_B (t_B = bin_edges[-1]), extend survival and CIFs using
        # constant hazards from the final bin B:
        #   S(tau)   = S(t_B) * exp(-Λ_B * (tau - t_B))
        #   F_k(tau) = F_k(t_B) + S(t_B) * (λ_{k,B}/Λ_B) * (1 - exp(-Λ_B*(tau-t_B)))
        last_edge = bin_edges[-1]
        tau_full = taus_t.expand_as(b)  # (M,T)
        tail_mask = tau_full > last_edge
        if tail_mask.any():
            delta = torch.clamp(tau_full - last_edge, min=0)  # (M,T)

            S_B = surv_end[:, -1].unsqueeze(1)  # (M,1)
            F_B = cif_full[:, :, -1].unsqueeze(-1)  # (M,K,1)

            lambda_last = hazards[:, :, -1]  # (M,K)
            Lambda_last = torch.clamp(
                total_h[:, -1], min=used_eps).unsqueeze(1)  # (M,1)

            exp_tail = torch.exp(-Lambda_last * delta)  # (M,T)
            survival_tail = S_B * exp_tail  # (M,T)
            cifs_tail = F_B + \
                S_B.unsqueeze(
                    1) * (lambda_last / Lambda_last).unsqueeze(-1) * (1.0 - exp_tail).unsqueeze(1)

            survival = torch.where(tail_mask, survival_tail, survival)
            cifs = torch.where(tail_mask.unsqueeze(1), cifs_tail, cifs)

        # tau mapped to bin 0 => CIF=0, survival=1
        zero_mask = (b == 0)
        if zero_mask.any():
            cifs = cifs.masked_fill(zero_mask.unsqueeze(1), 0.0)
            survival = survival.masked_fill(zero_mask, 1.0)

        if kind == "per_sample" or scalar_out:
            cifs = cifs.squeeze(-1)  # (M,K)
            survival = survival.squeeze(-1)  # (M,)

        return (cifs, survival) if return_survival else cifs
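

# A quick consistency sketch with assumed toy numbers: at any tau, the PWE
# model's probability mass is conserved, so sum_k F_k(tau) + S(tau) == 1.
def _example_piecewise_exponential() -> None:
    loss_fn = PiecewiseExponentialCIFNLLLoss(bin_edges=[0.0, 1.0, 2.0, 5.0])
    logits = torch.randn(3, 2, 3)  # (M=3, K=2 causes, n_bins=3)
    target_events = torch.tensor([0, 1, 0])
    dt = torch.tensor([0.4, 1.5, 7.0])  # dt beyond the last edge clamps to bin 3
    nll, reg = loss_fn(logits, target_events, dt)
    # Mid-bin and beyond-horizon taus; the latter exercises the tail extension.
    cifs, survival = loss_fn.calculate_cifs(
        logits, taus=torch.tensor([0.5, 10.0]), return_survival=True)  # (3,2,2), (3,2)
    total = cifs.sum(dim=1) + survival
    assert torch.allclose(total, torch.ones_like(total), atol=1e-5)
    print(nll.item(), reg.item())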