import math
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================
def get_valid_pairs_and_dt(
    event_seqs: torch.Tensor,
    time_seqs: torch.Tensor,
    n_tech_tokens: int
) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Extract valid event pairs (prev -> next) and compute dt in years.

    Args:
        event_seqs (torch.Tensor): (B, T) integer token ids; 0 is treated as
            padding (real events have id >= 1).
        time_seqs (torch.Tensor): (B, T) event times in days.
        n_tech_tokens (int): Number of technical (non-disease) tokens; disease
            tokens are those with id >= n_tech_tokens.

    Returns:
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
            (dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None.

    Notes:
        - Assumes strict right-padding.
        - Filters to next events that are disease tokens: token_id >= n_tech_tokens.
        - Filters to strictly positive dt.
    """
    real_mask = event_seqs >= 1
    idx = real_mask.nonzero(as_tuple=False)  # (N, 2) rows of (batch, time)

    if idx.size(0) <= 1:
        return None

    # Consecutive rows of idx form (prev, next) candidates; keep only pairs
    # that stay within the same batch row.
    same_batch = idx[1:, 0] == idx[:-1, 0]
    if not same_batch.any():
        return None

    prev_idx = idx[:-1][same_batch]
    next_idx = idx[1:][same_batch]

    # Keep only pairs whose next event is a disease token.
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]
    valid_target = event_seqs[b_next, t_next] >= n_tech_tokens
    if not valid_target.any():
        return None

    prev_idx = prev_idx[valid_target]
    next_idx = next_idx[valid_target]

    b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]

    # Convert day differences to years and drop non-positive intervals.
    dt = (time_seqs[b_next, t_next] -
          time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25
    valid_dt = dt > 0
    if not valid_dt.any():
        return None

    dt = dt[valid_dt]
    b_prev = b_prev[valid_dt]
    t_prev = t_prev[valid_dt]
    b_next = b_next[valid_dt]
    t_next = t_next[valid_dt]

    return dt, b_prev, t_prev, b_next, t_next
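

# Minimal usage sketch for get_valid_pairs_and_dt (illustrative toy values;
# 0 is assumed to be the padding id and times are in days, per the docstring).
if __name__ == "__main__":
    _events = torch.tensor([[3, 7, 8, 0],
                            [5, 9, 0, 0]])  # (B=2, T=4), right-padded
    _times = torch.tensor([[0, 100, 500, 0],
                           [0, 365, 0, 0]])  # days since baseline
    _pairs = get_valid_pairs_and_dt(_events, _times, n_tech_tokens=6)
    if _pairs is not None:
        _dt, _b_prev, _t_prev, _b_next, _t_next = _pairs
        # dt is in years, e.g. (500 - 100) / 365.25 for the (7 -> 8) pair.
        print(_dt, _b_next, _t_next)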


# ============================================================
# Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization)
# ============================================================
class ExponentialNLLLoss(nn.Module):
    """
    Competing risks exponential likelihood.

    The negative log-likelihood for observed cause k* over interval dt is:

    .. math::
        \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt

    Args:
        lambda_reg (float): Weight of the auxiliary cause cross-entropy
            regularizer (0 disables it).
        eps (float): Small epsilon for numerical stability.
    """

    def __init__(
        self,
        lambda_reg: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K) tensor of logits.
            target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
            dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
            reduction (str): 'mean', 'sum', or 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization); the
            regularization term is lambda_reg times a cause cross-entropy and
            is zero when lambda_reg == 0.
        """
        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
        hazards = F.softplus(logits) + self.eps  # (M, K)
        hazard_event = hazards.gather(
            1, target_events.unsqueeze(1)).squeeze(1)  # (M,)
        total_hazard = hazards.sum(dim=1)  # (M,)
        log_hazards = torch.log(hazards)  # (M, K)
        nll = -torch.log(hazard_event) + total_hazard * dt

        if reduction == "mean":
            nll = nll.mean()
        elif reduction == "sum":
            nll = nll.sum()

        # softmax over log-hazards normalizes the hazards, so this is the
        # cross-entropy of the cause distribution lambda_k / sum_k lambda_k.
        reg = F.cross_entropy(log_hazards, target_events,
                              reduction="mean") * self.lambda_reg
        return nll, reg
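

# Minimal sanity sketch for ExponentialNLLLoss (assumed toy values): with zero
# logits every hazard is softplus(0) + eps = log(2) + 1e-6, so the per-sample
# NLL should match the closed form -log(lambda) + K * lambda * dt.
if __name__ == "__main__":
    _exp_loss = ExponentialNLLLoss(lambda_reg=0.0)
    _logits = torch.zeros(2, 3)  # (M=2, K=3)
    _targets = torch.tensor([0, 2])
    _dt = torch.tensor([0.5, 2.0])
    _nll, _reg = _exp_loss(_logits, _targets, _dt, reduction="none")
    _lam = math.log(2.0) + 1e-6
    _expected = -math.log(_lam) + 3 * _lam * 0.5  # first sample
    print(float(_nll[0]), _expected)  # should agree to float precision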


class DiscreteTimeCIFNLLLoss(nn.Module):
    """Direct discrete-time CIF negative log-likelihood (no censoring).

    This loss assumes the model outputs per-bin logits over (K causes + 1 complement)
    channels, where the complement channel (index K) represents survival across bins.

    Per-sample likelihood for observed cause k at time bin j:

        p = \\prod_{u=1}^{j-1} p(comp at u) * p(k at j)

    Args:
        bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0.
        eps: Unused; kept for interface compatibility / future numerical tweaks.
        lambda_reg: Optional regularization strength.
    """

    def __init__(
        self,
        bin_edges: Sequence[float],
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()

        if len(bin_edges) < 2:
            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
        if float(bin_edges[0]) != 0.0:
            raise ValueError("bin_edges[0] must equal 0")
        for i in range(1, len(bin_edges)):
            if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
                raise ValueError("bin_edges must be strictly increasing")

        self.eps = float(eps)
        self.lambda_reg = float(lambda_reg)
        self.register_buffer(
            "bin_edges",
            torch.tensor(bin_edges, dtype=torch.float32),
            persistent=False,
        )

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if logits.ndim != 3:
            raise ValueError(
                f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
            )
        if target_events.ndim != 1 or dt.ndim != 1:
            raise ValueError(
                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
            )
        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
            raise ValueError(
                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
            )
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError("reduction must be one of {'mean','sum','none'}")

        if not torch.all(dt > 0):
            raise ValueError("dt must be strictly positive")

        # Infer K and n_bins from logits and bin_edges.
        m, k_plus_1, n_bins_plus_1 = logits.shape
        k_comp = k_plus_1 - 1
        if k_comp < 1:
            raise ValueError(
                "logits.shape[1] must be at least 2 (K>=1 plus complement channel)")

        n_bins = int(self.bin_edges.numel() - 1)
        if n_bins_plus_1 != n_bins + 1:
            raise ValueError(
                f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
            )

        if target_events.dtype != torch.long:
            target_events = target_events.to(torch.long)

        if (target_events < 0).any() or (target_events >= k_comp).any():
            raise ValueError(
                f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}"
            )

        # Map continuous dt to discrete bins j in {1..n_bins}.
        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
        # bucketize may return n_bins+1 if dt > bin_edges[-1]; clamp into range.
        time_bin = torch.bucketize(dt, bin_edges)
        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(torch.long)

        # Log-probabilities across causes+complement for each bin.
        logp = F.log_softmax(logits, dim=1)  # (M, K+1, n_bins+1)

        # Previous survival term: sum_{u=1}^{j-1} -log p(comp at u)
        bins = torch.arange(n_bins + 1, device=logits.device)  # (n_bins+1,)
        mask = (bins.unsqueeze(0) >= 1) & (
            bins.unsqueeze(0) < time_bin.unsqueeze(1))  # (M, n_bins+1)
        logp_comp = logp[:, k_comp, :]  # (M, n_bins+1)
        loss_prev = -(logp_comp * mask.to(logp_comp.dtype)).sum(dim=1)  # (M,)

        # Event term at bin j: -log p(k at j)
        m_idx = torch.arange(m, device=logits.device)
        loss_event = -logp[m_idx, target_events, time_bin]  # (M,)

        loss = loss_prev + loss_event

        if reduction == "mean":
            nll = loss.mean()
        elif reduction == "sum":
            nll = loss.sum()
        else:
            nll = loss

        reg = torch.zeros((), device=logits.device, dtype=loss.dtype)
        if self.lambda_reg > 0.0:
            # Regularize the cause distribution at the event bin using NLL on log-probs.
            logp_causes = logp[:, :k_comp, :]  # (M, K, n_bins+1)
            idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1)
            logp_at_event_bin = logp_causes.gather(
                dim=2, index=idx).squeeze(2)  # (M, K)
            reg = self.lambda_reg * \
                F.nll_loss(logp_at_event_bin, target_events, reduction="mean")

        return nll, reg
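

# Minimal usage sketch for DiscreteTimeCIFNLLLoss (assumed toy shapes):
# K=2 causes plus one complement channel, 3 bins over edges [0, 1, 2, 5] years.
if __name__ == "__main__":
    _cif_loss = DiscreteTimeCIFNLLLoss(bin_edges=[0.0, 1.0, 2.0, 5.0])
    _logits = torch.randn(4, 3, 4)  # (M=4, K+1=3, n_bins+1=4)
    _targets = torch.randint(0, 2, (4,))
    _dt = torch.tensor([0.3, 1.5, 4.0, 10.0])  # 10.0 is clamped into the last bin
    _nll, _reg = _cif_loss(_logits, _targets, _dt)
    print(float(_nll), float(_reg))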


class LogNormalBasisBinnedHazardCIFNLLLoss(nn.Module):
    r"""Route-3: continuous-time lognormal-basis hazards with a discrete-time CIF likelihood.

    This implements a cause-specific continuous-time hazard model

        \lambda_j(t) = \sum_r \alpha_{j,r} b_r(t)

    where b_r(t) is the lognormal PDF basis implied by a Normal on log-time.

    The training objective is identical in structure to DiscreteTimeCIFNLLLoss,
    but the per-bin categorical probabilities are derived from integrated hazards.

    Expected logits interface (preferred):
        logits: (B, J*R), reshaped to (B, J, R)

    For convenience/compatibility, also accepts:
        logits: (B, 1+J*R), ignoring the first column.

    Forward interface (must match):
        forward(logits, target_events, dt, reduction) -> (nll, reg)
    """

    def __init__(
        self,
        bin_edges: Sequence[float],
        centers: Sequence[float],
        *,
        eps: float = 1e-8,
        alpha_floor: float = 0.0,
        bandwidth_init: float = 0.7,
        bandwidth_min: float = 1e-3,
        bandwidth_max: float = 10.0,
        lambda_sigma_reg: float = 0.0,
        sigma_reg_target: Optional[float] = None,
        lambda_reg: float = 0.0,
    ):
        super().__init__()

        if len(bin_edges) < 2:
            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
        if float(bin_edges[0]) != 0.0:
            raise ValueError("bin_edges[0] must equal 0")
        # Edges must be strictly increasing; +inf is allowed only as the last edge.
        for i in range(1, len(bin_edges)):
            prev = float(bin_edges[i - 1])
            cur = float(bin_edges[i])
            if math.isinf(prev):
                raise ValueError(
                    "bin_edges cannot have +inf except possibly as the last edge")
            if not (cur > prev):
                raise ValueError("bin_edges must be strictly increasing")

        if len(centers) < 1:
            raise ValueError("centers must have length >= 1")

        self.eps = float(eps)
        self.alpha_floor = float(alpha_floor)
        self.bandwidth_min = float(bandwidth_min)
        self.bandwidth_max = float(bandwidth_max)
        self.bandwidth_init = float(bandwidth_init)
        self.lambda_sigma_reg = float(lambda_sigma_reg)
        self.sigma_reg_target = None if sigma_reg_target is None else float(
            sigma_reg_target)
        self.lambda_reg = float(lambda_reg)

        self.register_buffer(
            "bin_edges",
            torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            "centers",
            torch.tensor([float(x) for x in centers], dtype=torch.float32),
            persistent=False,
        )

        if self.bandwidth_init <= 0:
            raise ValueError("bandwidth_init must be > 0")
        # Shared, learnable log-bandwidth of the lognormal basis.
        self.log_sigma = nn.Parameter(torch.tensor(
            math.log(self.bandwidth_init), dtype=torch.float32))

    @staticmethod
    def _normal_cdf(z: torch.Tensor) -> torch.Tensor:
        z = torch.clamp(z, -12.0, 12.0)
        return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))

    @staticmethod
    def _normal_sf(z: torch.Tensor) -> torch.Tensor:
        z = torch.clamp(z, -12.0, 12.0)
        return 0.5 * torch.erfc(z / math.sqrt(2.0))

    def _compute_delta_basis_all_bins(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Compute ΔB[k,r], the integral of basis b_r over bin k, for k=1..n_bins (shape: (n_bins, R))."""
        bin_edges = self.bin_edges.to(device=device, dtype=dtype)
        centers = self.centers.to(device=device, dtype=dtype)

        n_bins = int(bin_edges.numel() - 1)
        if n_bins < 1:
            raise ValueError("bin_edges must define at least one bin")

        left = bin_edges[:-1]  # (n_bins,)
        right = bin_edges[1:]  # (n_bins,)

        # Guard against log(0) at the left edge of the first bin.
        if float(self.bin_edges[1]) > 0.0:
            t_min = float(self.bin_edges[1]) * 1e-6
        else:
            t_min = 1e-12
        t_min_t = torch.tensor(t_min, device=device, dtype=dtype)

        left_is_zero = left <= 0
        left_clamped = torch.clamp(left, min=t_min_t)
        log_left = torch.log(left_clamped)

        is_inf = torch.isinf(right)
        right_safe = torch.where(
            is_inf, left_clamped, torch.clamp(right, min=t_min_t))
        log_right = torch.log(right_safe)

        sigma = torch.clamp(
            self.log_sigma.to(device=device, dtype=dtype).exp(),
            self.bandwidth_min,
            self.bandwidth_max,
        )

        z_left = (log_left.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
        z_right = (log_right.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
        z_left = torch.clamp(z_left, -12.0, 12.0)
        z_right = torch.clamp(z_right, -12.0, 12.0)

        cdf_left = self._normal_cdf(z_left)
        if left_is_zero.any():
            cdf_left = torch.where(
                left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left)

        cdf_right = self._normal_cdf(z_right)
        delta_finite = cdf_right - cdf_left

        # Last bin: ΔB = 1 - CDF(left) = SF(left), computed via erfc for stability.
        delta_last = self._normal_sf(z_left)
        if left_is_zero.any():
            delta_last = torch.where(
                left_is_zero.unsqueeze(-1), torch.ones_like(delta_last), delta_last)

        delta_basis = torch.where(
            is_inf.unsqueeze(-1), delta_last, delta_finite)
        delta_basis = torch.clamp(delta_basis, min=0.0)
        return delta_basis

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if logits.ndim not in {2, 3}:
            raise ValueError(
                f"logits must be 2D (B, J*R) (or (B, 1+J*R)) or 3D (B, J, R); got {tuple(logits.shape)}"
            )
        if target_events.ndim != 1 or dt.ndim != 1:
            raise ValueError(
                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
            )
        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
            raise ValueError(
                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
            )
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError("reduction must be one of {'mean','sum','none'}")

        if not torch.all(dt > 0):
            raise ValueError("dt must be strictly positive")

        device = logits.device
        dtype = logits.dtype

        centers = self.centers.to(device=device, dtype=dtype)
        r = int(centers.numel())
        if r < 1:
            raise ValueError("centers must have length >= 1")

        if logits.ndim == 3:
            if logits.shape[2] != r:
                raise ValueError(
                    f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
                )
            j = int(logits.shape[1])
            if j < 1:
                raise ValueError("Inferred number of causes J must be >= 1")
            alpha = F.softplus(logits) + self.alpha_floor  # (B, J, R)
        else:
            d = int(logits.shape[1])
            offset = 0
            if d % r == 0:
                jr = d
            elif (d - 1) % r == 0:
                offset = 1
                jr = d - 1
            else:
                raise ValueError(
                    f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}"
                )

            j = jr // r
            if j < 1:
                raise ValueError("Inferred number of causes J must be >= 1")

            logits_used = logits[:, offset:]
            alpha = F.softplus(logits_used).view(-1, j, r) + \
                self.alpha_floor  # (B, J, R)

        delta_basis = self._compute_delta_basis_all_bins(
            device=device,
            dtype=dtype,
        )  # (n_bins, R)
        n_bins = int(delta_basis.shape[0])

        # Integrated hazard per cause and bin: H_{j,k} = sum_r alpha_{j,r} * ΔB_{k,r}
        h_jk = torch.einsum("mjr,kr->mjk", alpha,
                            delta_basis)  # (B, J, n_bins)
        h_k = h_jk.sum(dim=1)  # (B, n_bins)

        # Map continuous dt to discrete event bin index k* in {1..n_bins}.
        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
        time_bin = torch.bucketize(dt, bin_edges)
        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(torch.long)

        cause = target_events.to(device=device, dtype=torch.long)
        if (cause < 0).any() or (cause >= j).any():
            raise ValueError(f"target_events must be in [0, J-1] where J={j}")

        # Previous survival term: sum_{u < k*} H_u
        bins = torch.arange(
            1, n_bins + 1, device=device).unsqueeze(0)  # (1, n_bins)
        mask_prev = bins < time_bin.unsqueeze(1)  # (B, n_bins)
        loss_prev = (h_k * mask_prev.to(h_k.dtype)).sum(dim=1)  # (B,)

        # Event term at k*: -log p_{k*}(cause)
        b_idx = torch.arange(target_events.shape[0], device=device)
        k0 = time_bin - 1  # (B,) index into 0..n_bins-1
        h_event_total = h_k[b_idx, k0]
        h_event_total = torch.clamp(h_event_total, min=self.eps)

        h_event_cause = h_jk[b_idx, cause, k0]
        h_event_cause = torch.clamp(h_event_cause, min=self.eps)

        # log(1 - exp(-H)) computed stably via expm1.
        log1mexp = torch.log(-torch.expm1(-h_event_total))
        loss_event = -log1mexp - \
            torch.log(h_event_cause) + torch.log(h_event_total)

        loss = loss_prev + loss_event

        if reduction == "mean":
            nll = loss.mean()
        elif reduction == "sum":
            nll = loss.sum()
        else:
            nll = loss

        reg = torch.zeros((), device=device, dtype=dtype)

        if self.lambda_reg > 0.0:
            # Regularize the within-bin cause competition via NLL on log ratios:
            # ratio_j = H_{j,k*} / H_{k*}
            h_event_all = h_jk[b_idx, :, k0]  # (B, J)
            denom = torch.clamp(h_event_total, min=self.eps).unsqueeze(1)
            ratio = torch.clamp(h_event_all / denom, min=self.eps)
            log_ratio = torch.log(ratio)
            reg = reg + self.lambda_reg * F.nll_loss(
                log_ratio, cause, reduction="mean")

        if self.lambda_sigma_reg > 0.0:
            target = self.bandwidth_init if self.sigma_reg_target is None else self.sigma_reg_target
            sigma_penalty = (self.log_sigma.to(
                device=device, dtype=dtype) - math.log(float(target))) ** 2
            reg = reg + sigma_penalty * self.lambda_sigma_reg

        return nll, reg
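

# Minimal usage sketch for LogNormalBasisBinnedHazardCIFNLLLoss (assumed toy
# configuration): J=2 causes, R=3 lognormal basis functions centred on
# log-times, and an open-ended last bin. Logits arrive flattened as (B, J*R);
# the (B, 1+J*R) variant is also accepted, ignoring the first column.
if __name__ == "__main__":
    _route3 = LogNormalBasisBinnedHazardCIFNLLLoss(
        bin_edges=[0.0, 1.0, 5.0, float("inf")],
        centers=[math.log(0.5), math.log(2.0), math.log(10.0)],
    )
    _logits = torch.randn(4, 2 * 3)  # (B=4, J*R=6)
    _targets = torch.randint(0, 2, (4,))
    _dt = torch.tensor([0.2, 0.9, 3.0, 40.0])
    _nll, _reg = _route3(_logits, _targets, _dt)
    print(float(_nll), float(_reg))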