Add LogNormalBasisHazardLoss implementation and update training configuration

2026-01-13 15:59:20 +08:00
parent d8b322cbee
commit 1df02d85d7
5 changed files with 431 additions and 15 deletions

losses.py

@@ -1,5 +1,5 @@
import math
from typing import Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
@@ -258,3 +258,233 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
F.nll_loss(logp_at_event_bin, target_events, reduction="mean")
return nll, reg
class LogNormalBasisHazardLoss(nn.Module):
"""Event-only competing risks loss using lognormal basis (Gaussian on log-time).
This loss models cause-specific CIF as a mixture of lognormal basis CDFs:
F_j(t) = sum_r w_{j,r} * Phi((log t - mu_r) / sigma)
Training uses *bin probability mass* (Delta CIF / interval mass). There is
**no censoring**: every sample is an observed event with a valid cause.
Logits interface:
logits: (B, 1 + J*R)
logits[:, 0] -> w0 (survival mass / never-event)
logits[:, 1:] -> flattened (j,r) in row-major order: j then r
index = 1 + j*R + r
Forward interface (must match):
forward(logits, target_events, dt, reduction)
"""
def __init__(
self,
bin_edges: Sequence[float],
centers: Sequence[float],
*,
eps: float = 1e-8,
bandwidth_init: float = 0.5,
bandwidth_min: float = 1e-3,
bandwidth_max: float = 10.0,
lambda_sigma_reg: float = 0.0,
sigma_reg_target: Optional[float] = None,
return_dict: bool = False,
):
super().__init__()
if len(bin_edges) < 2:
raise ValueError("bin_edges must have length >= 2")
        # Edges must be strictly increasing; +inf is allowed only as the last edge.
        for i in range(1, len(bin_edges)):
            prev = float(bin_edges[i - 1])
            cur = float(bin_edges[i])
            if math.isinf(prev):
                raise ValueError(
                    "bin_edges cannot have +inf except possibly as the last edge")
            if not (cur > prev):
                raise ValueError("bin_edges must be strictly increasing")
if float(bin_edges[0]) < 0.0:
raise ValueError("bin_edges[0] must be >= 0")
if len(centers) < 1:
raise ValueError("centers must have length >= 1")
self.eps = float(eps)
self.bandwidth_min = float(bandwidth_min)
self.bandwidth_max = float(bandwidth_max)
self.lambda_sigma_reg = float(lambda_sigma_reg)
self.sigma_reg_target = None if sigma_reg_target is None else float(
sigma_reg_target)
self.bandwidth_init = float(bandwidth_init)
self.return_dict = bool(return_dict)
self.register_buffer(
"bin_edges",
torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"centers",
torch.tensor([float(x) for x in centers], dtype=torch.float32),
persistent=False,
)
if self.bandwidth_init <= 0:
raise ValueError("bandwidth_init must be > 0")
self.log_sigma = nn.Parameter(torch.tensor(
math.log(self.bandwidth_init), dtype=torch.float32))
@staticmethod
def _normal_cdf(z: torch.Tensor) -> torch.Tensor:
# Stable normal CDF via erf.
z = torch.clamp(z, -12.0, 12.0)
return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
@staticmethod
def _normal_sf(z: torch.Tensor) -> torch.Tensor:
# Stable normal survival function via erfc.
z = torch.clamp(z, -12.0, 12.0)
return 0.5 * torch.erfc(z / math.sqrt(2.0))
def forward(
self,
logits: torch.Tensor,
target_events: torch.Tensor,
dt: torch.Tensor,
reduction: str = "mean",
) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, Any]]:
if logits.ndim != 2:
raise ValueError(
f"logits must be 2D with shape (B, 1+J*R); got {tuple(logits.shape)}")
if target_events.ndim != 1 or dt.ndim != 1:
raise ValueError("target_events and dt must be 1D tensors")
if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
raise ValueError(
"Batch size mismatch among logits, target_events, dt")
if reduction not in {"mean", "sum", "none"}:
raise ValueError("reduction must be one of {'mean','sum','none'}")
device = logits.device
dtype = logits.dtype
bin_edges = self.bin_edges.to(device=device, dtype=dtype)
centers = self.centers.to(device=device, dtype=dtype)
bsz = logits.shape[0]
r = int(centers.numel())
jr = int(logits.shape[1] - 1)
if jr <= 0:
raise ValueError(
"logits.shape[1] must be >= 2 (w0 + at least one (j,r) weight)")
if jr % r != 0:
raise ValueError(
f"(logits.shape[1]-1) must be divisible by R={r}; got {jr}")
j = jr // r
# 1) Stable global weights (includes w0).
w_all = torch.softmax(logits, dim=-1) # (B, 1+J*R)
w0 = w_all[:, 0]
w = w_all[:, 1:].view(bsz, j, r)
# 2) Determine event bin index.
k = int(bin_edges.numel() - 1)
if k < 1:
raise ValueError("bin_edges must define at least one bin")
# v2: dt is always continuous time (float), map to bin via searchsorted.
dt_f = dt.to(device=device, dtype=dtype)
bin_idx = torch.searchsorted(bin_edges, dt_f, right=True) - 1
bin_idx = torch.clamp(bin_idx, 0, k - 1).to(torch.long)
left = bin_edges[bin_idx]
right = bin_edges[bin_idx + 1]
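        # Worked example (hypothetical edges): with bin_edges = [0, 1, 7, inf]
        # and dt = 3.2, searchsorted(..., right=True) returns 2, so bin_idx = 1
        # and the event falls in the interval [1, 7).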
        # 3) Stable log(t) clamp. Guard against a non-finite second edge
        # (e.g., bin_edges = [0, inf]) so t_min itself stays finite.
        edge1 = float(self.bin_edges[1])
        if math.isfinite(edge1) and edge1 > 0.0:
            t_min = edge1 * 1e-6
        else:
            t_min = 1e-12
t_min_t = torch.tensor(t_min, device=device, dtype=dtype)
left_is_zero = left <= 0
# For log() we still need a positive clamp, but we will treat CDF(left)=0 exactly
# when left<=0 (instead of approximating via t_min).
left_clamped = torch.clamp(left, min=t_min_t)
log_left = torch.log(left_clamped)
is_inf = torch.isinf(right)
# right might be +inf for last bin; avoid log(+inf) by substituting a safe finite value.
right_safe = torch.where(is_inf, left_clamped,
torch.clamp(right, min=t_min_t))
log_right = torch.log(right_safe)
sigma = torch.clamp(self.log_sigma.to(
device=device, dtype=dtype).exp(), self.bandwidth_min, self.bandwidth_max)
z_left = (log_left.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
z_right = (log_right.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
z_left = torch.clamp(z_left, -12.0, 12.0)
z_right = torch.clamp(z_right, -12.0, 12.0)
cdf_left = self._normal_cdf(z_left)
# Treat the first-bin left boundary exactly as 0 in CDF.
if left_is_zero.any():
cdf_left = torch.where(
left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left)
cdf_right = self._normal_cdf(z_right)
delta_finite = cdf_right - cdf_left
delta_last = self._normal_sf(z_left)
# If left<=0, SF(left)=1 exactly.
if left_is_zero.any():
delta_last = torch.where(
left_is_zero.unsqueeze(-1), torch.ones_like(delta_last), delta_last)
delta_basis = torch.where(
is_inf.unsqueeze(-1), delta_last, delta_finite)
delta_basis = torch.clamp(delta_basis, min=0.0)
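        # Per basis r, delta_basis holds the lognormal interval mass
        #   Phi((log right - mu_r) / sigma) - Phi((log left - mu_r) / sigma),
        # switching to the survival form SF(z_left) when right = +inf.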
# 4) Gather per-sample cause weights and compute event mass.
cause = target_events.to(device=device, dtype=torch.long)
if (cause < 0).any() or (cause >= j).any():
raise ValueError(f"target_events must be in [0, J-1] where J={j}")
b_idx = torch.arange(bsz, device=device)
w_cause = w[b_idx, cause, :] # (B, R)
m = (w_cause * delta_basis).sum(dim=-1)
m = torch.clamp(m, min=self.eps)
nll_vec = -torch.log(m)
if reduction == "mean":
nll: torch.Tensor = nll_vec.mean()
elif reduction == "sum":
nll = nll_vec.sum()
else:
nll = nll_vec
sigma_penalty = torch.zeros((), device=device, dtype=dtype)
if self.lambda_sigma_reg > 0.0:
target = self.bandwidth_init if self.sigma_reg_target is None else self.sigma_reg_target
sigma_penalty = (self.log_sigma.to(
device=device, dtype=dtype) - math.log(float(target))) ** 2
reg = sigma_penalty * float(self.lambda_sigma_reg)
if not self.return_dict:
return nll, reg
return {
"nll": nll,
"reg": reg,
"nll_vec": nll_vec,
"sigma": sigma.detach(),
"avg_w0": w0.mean().detach(),
"min_delta_basis": delta_basis.min().detach(),
"mean_m": m.mean().detach(),
"sigma_penalty": sigma_penalty.detach(),
"bin_idx": bin_idx.detach(),
}
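
A minimal usage sketch (not part of the diff; the edges, centers, batch size, and event times below are made-up values for illustration):

import math
import torch

J, R, B = 2, 3, 4  # hypothetical causes, basis centers, batch size
criterion = LogNormalBasisHazardLoss(
    bin_edges=[0.0, 1.0, 7.0, 30.0, float("inf")],
    centers=[math.log(0.5), math.log(3.0), math.log(14.0)],  # mu_r on log-time
    bandwidth_init=0.5,
)
logits = torch.randn(B, 1 + J * R, requires_grad=True)
target_events = torch.randint(0, J, (B,))  # every sample is an observed event
dt = torch.rand(B) * 20.0                  # continuous event times
nll, reg = criterion(logits, target_events, dt, reduction="mean")
(nll + reg).backward()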