Refactor LogNormalBasisHazardLoss to LogNormalBasisBinnedHazardCIFNLLLoss and update related configurations
losses.py (264 changed lines)
@@ -260,24 +260,27 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
         return nll, reg


-class LogNormalBasisHazardLoss(nn.Module):
-    """Event-only competing risks loss using lognormal basis (Gaussian on log-time).
+class LogNormalBasisBinnedHazardCIFNLLLoss(nn.Module):
+    r"""Route-3: continuous-time lognormal-basis hazards with discrete-time CIF likelihood.

-    This loss models cause-specific CIF as a mixture of lognormal basis CDFs:
+    This implements a cause-specific continuous-time hazard model:

-        F_j(t) = sum_r w_{j,r} * Phi((log t - mu_r) / sigma)
+        \lambda_j(t) = \sum_r \alpha_{j,r} b_r(t)

-    Training uses *bin probability mass* (Delta CIF / interval mass). There is
-    **no censoring**: every sample is an observed event with a valid cause.
+    where b_r(t) is the lognormal PDF basis implied by a Normal on log-time.

-    Logits interface:
-        logits: (B, 1 + J*R)
-        logits[:, 0] -> w0 (survival mass / never-event)
-        logits[:, 1:] -> flattened (j,r) in row-major order: j then r
-        index = 1 + j*R + r
+    Training objective is IDENTICAL in structure to DiscreteTimeCIFNLLLoss,
+    but per-bin categorical probabilities are derived from integrated hazards.
+
+    Expected logits interface (preferred):
+        logits: (B, J*R)
+        reshaped to (B, J, R)
+
+    For convenience/compatibility, also accepts:
+        logits: (B, 1+J*R) and ignores the first column.

     Forward interface (must match):
-        forward(logits, target_events, dt, reduction)
+        forward(logits, target_events, dt, reduction) -> (nll, reg)
     """

     def __init__(
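For orientation, a minimal usage sketch of the renamed loss. The class name, constructor parameters, and the forward contract are taken from this diff; the shapes (B=4, J=2, R=3), bin edges, and centers are illustrative:

    import torch

    # Illustrative setup: 3 bins with an open last bin, R = 3 lognormal centers.
    loss_fn = LogNormalBasisBinnedHazardCIFNLLLoss(
        bin_edges=[0.0, 1.0, 2.0, float("inf")],
        centers=[-1.0, 0.0, 1.0],
    )
    B, J, R = 4, 2, 3
    logits = torch.randn(B, J * R)             # preferred flat layout; reshaped to (B, J, R)
    target_events = torch.randint(0, J, (B,))  # cause index in [0, J-1]
    dt = torch.rand(B) + 0.1                   # continuous event times, strictly > 0
    nll, reg = loss_fn(logits, target_events, dt, reduction="mean")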
@@ -286,17 +289,20 @@ class LogNormalBasisHazardLoss(nn.Module):
         centers: Sequence[float],
         *,
         eps: float = 1e-8,
-        bandwidth_init: float = 0.5,
+        alpha_floor: float = 0.0,
+        bandwidth_init: float = 0.7,
         bandwidth_min: float = 1e-3,
         bandwidth_max: float = 10.0,
         lambda_sigma_reg: float = 0.0,
         sigma_reg_target: Optional[float] = None,
-        return_dict: bool = False,
+        lambda_reg: float = 0.0,
     ):
         super().__init__()

         if len(bin_edges) < 2:
-            raise ValueError("bin_edges must have length >= 2")
-        if float(bin_edges[0]) != 0.0:
-            raise ValueError("bin_edges[0] must equal 0")
+            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
         # allow last edge to be +inf
         for i in range(1, len(bin_edges)):
             prev = float(bin_edges[i - 1])
@@ -310,20 +316,19 @@ class LogNormalBasisHazardLoss(nn.Module):
             else:
                 if not (cur > prev):
                     raise ValueError("bin_edges must be strictly increasing")
+        if float(bin_edges[0]) < 0.0:
+            raise ValueError("bin_edges[0] must be >= 0")

         if len(centers) < 1:
             raise ValueError("centers must have length >= 1")

         self.eps = float(eps)
+        self.alpha_floor = float(alpha_floor)
         self.bandwidth_min = float(bandwidth_min)
         self.bandwidth_max = float(bandwidth_max)
-        self.bandwidth_init = float(bandwidth_init)
         self.lambda_sigma_reg = float(lambda_sigma_reg)
         self.sigma_reg_target = None if sigma_reg_target is None else float(
             sigma_reg_target)
-        self.return_dict = bool(return_dict)
+        self.bandwidth_init = float(bandwidth_init)
+        self.lambda_reg = float(lambda_reg)

         self.register_buffer(
             "bin_edges",
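A quick sketch of what the edge validation above accepts and rejects (the inputs are illustrative; the error strings are the ones raised in this diff):

    ok_closed = [0.0, 0.5, 1.0, 2.0]    # strictly increasing, first edge >= 0
    ok_open = [0.0, 1.0, float("inf")]  # last edge may be +inf (open last bin)
    # rejected: [0.0, 1.0, 1.0]  -> "bin_edges must be strictly increasing"
    # rejected: [2.0]            -> "bin_edges must have length >= 2 (n_bins >= 1)"
    # rejected: [-1.0, 1.0]      -> "bin_edges[0] must be >= 0"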
@@ -343,69 +348,32 @@ class LogNormalBasisHazardLoss(nn.Module):

     @staticmethod
     def _normal_cdf(z: torch.Tensor) -> torch.Tensor:
         # Stable normal CDF via erf.
         z = torch.clamp(z, -12.0, 12.0)
         return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))

     @staticmethod
     def _normal_sf(z: torch.Tensor) -> torch.Tensor:
         # Stable normal survival function via erfc.
         z = torch.clamp(z, -12.0, 12.0)
         return 0.5 * torch.erfc(z / math.sqrt(2.0))

-    def forward(
+    def _compute_delta_basis_all_bins(
         self,
-        logits: torch.Tensor,
-        target_events: torch.Tensor,
-        dt: torch.Tensor,
-        reduction: str = "mean",
-    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, Any]]:
-        if logits.ndim != 2:
-            raise ValueError(
-                f"logits must be 2D with shape (B, 1+J*R); got {tuple(logits.shape)}")
-        if target_events.ndim != 1 or dt.ndim != 1:
-            raise ValueError("target_events and dt must be 1D tensors")
-        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
-            raise ValueError(
-                "Batch size mismatch among logits, target_events, dt")
-        if reduction not in {"mean", "sum", "none"}:
-            raise ValueError("reduction must be one of {'mean','sum','none'}")
-
-        device = logits.device
-        dtype = logits.dtype
+        *,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """Compute ΔB[k,r] for bins k=1..n_bins (shape: (n_bins, R))."""

         bin_edges = self.bin_edges.to(device=device, dtype=dtype)
         centers = self.centers.to(device=device, dtype=dtype)
-        bsz = logits.shape[0]
-        r = int(centers.numel())
-        jr = int(logits.shape[1] - 1)
-        if jr <= 0:
-            raise ValueError(
-                "logits.shape[1] must be >= 2 (w0 + at least one (j,r) weight)")
-        if jr % r != 0:
-            raise ValueError(
-                f"(logits.shape[1]-1) must be divisible by R={r}; got {jr}")
-        j = jr // r
-
-        # 1) Stable global weights (includes w0).
-        w_all = torch.softmax(logits, dim=-1)  # (B, 1+J*R)
-        w0 = w_all[:, 0]
-        w = w_all[:, 1:].view(bsz, j, r)
-
-        # 2) Determine event bin index.
-        k = int(bin_edges.numel() - 1)
-        if k < 1:
+        n_bins = int(bin_edges.numel() - 1)
+        if n_bins < 1:
             raise ValueError("bin_edges must define at least one bin")

-        # v2: dt is always continuous time (float), map to bin via searchsorted.
-        dt_f = dt.to(device=device, dtype=dtype)
-        bin_idx = torch.searchsorted(bin_edges, dt_f, right=True) - 1
-        bin_idx = torch.clamp(bin_idx, 0, k - 1).to(torch.long)
-
-        left = bin_edges[bin_idx]
-        right = bin_edges[bin_idx + 1]
+        left = bin_edges[:-1]   # (n_bins,)
+        right = bin_edges[1:]   # (n_bins,)

         # 3) Stable log(t) clamp.
         if float(self.bin_edges[1]) > 0.0:
             t_min = float(self.bin_edges[1]) * 1e-6
         else:
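For reference, a standalone sketch of the quantity `_compute_delta_basis_all_bins` produces: ΔB[k, r] = Φ(z_right) − Φ(z_left) per bin, with the open last bin handled via the survival function. This helper is a simplified re-derivation under the conventions above, not the module code:

    import math
    import torch

    def delta_basis_sketch(bin_edges, centers, sigma, t_min=1e-12):
        """ΔB[k, r]: mass of lognormal basis r inside bin k; last bin may be open."""
        left, right = bin_edges[:-1], bin_edges[1:]
        log_left = torch.log(left.clamp(min=t_min))
        is_inf = torch.isinf(right)
        log_right = torch.log(torch.where(is_inf, left.clamp(min=t_min),
                                          right.clamp(min=t_min)))
        z_l = ((log_left.unsqueeze(-1) - centers) / sigma).clamp(-12.0, 12.0)
        z_r = ((log_right.unsqueeze(-1) - centers) / sigma).clamp(-12.0, 12.0)
        cdf = lambda z: 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
        sf = lambda z: 0.5 * torch.erfc(z / math.sqrt(2.0))
        delta = torch.where(is_inf.unsqueeze(-1), sf(z_l), cdf(z_r) - cdf(z_l))
        # Exact boundaries for a left edge at 0: CDF(0) = 0 and SF(0) = 1.
        zero_left = (left <= 0).unsqueeze(-1)
        exact = torch.where(is_inf.unsqueeze(-1), torch.ones_like(delta), cdf(z_r))
        delta = torch.where(zero_left, exact, delta)
        return delta.clamp(min=0.0)

    # With edges covering [0, inf), each basis column telescopes to total mass 1:
    edges = torch.tensor([0.0, 0.5, 2.0, float("inf")])
    db = delta_basis_sketch(edges, torch.tensor([-0.5, 0.5]), sigma=0.7)
    assert torch.allclose(db.sum(dim=0), torch.ones(2))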
@@ -413,19 +381,19 @@ class LogNormalBasisHazardLoss(nn.Module):
             t_min_t = torch.tensor(t_min, device=device, dtype=dtype)

         left_is_zero = left <= 0

         # For log() we still need a positive clamp, but we will treat CDF(left)=0 exactly
         # when left<=0 (instead of approximating via t_min).
         left_clamped = torch.clamp(left, min=t_min_t)
         log_left = torch.log(left_clamped)

         is_inf = torch.isinf(right)
         # right might be +inf for last bin; avoid log(+inf) by substituting a safe finite value.
-        right_safe = torch.where(is_inf, left_clamped,
-                                 torch.clamp(right, min=t_min_t))
+        right_safe = torch.where(
+            is_inf, left_clamped, torch.clamp(right, min=t_min_t))
         log_right = torch.log(right_safe)

-        sigma = torch.clamp(self.log_sigma.to(
-            device=device, dtype=dtype).exp(), self.bandwidth_min, self.bandwidth_max)
+        sigma = torch.clamp(
+            self.log_sigma.to(device=device, dtype=dtype).exp(),
+            self.bandwidth_min,
+            self.bandwidth_max,
+        )

         z_left = (log_left.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
         z_right = (log_right.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
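The bandwidth σ is a learnable scalar stored in log-space; the `nn.Parameter` registration itself is outside this hunk, so the sketch below assumes it, while the name `log_sigma`, the initialization at `bandwidth_init = 0.7`, and the clamp bounds come from the code above:

    import math
    import torch
    import torch.nn as nn

    log_sigma = nn.Parameter(torch.tensor(math.log(0.7)))  # bandwidth_init = 0.7
    sigma = log_sigma.exp().clamp(1e-3, 10.0)              # bandwidth_min, bandwidth_max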
@@ -433,58 +401,152 @@ class LogNormalBasisHazardLoss(nn.Module):
         z_right = torch.clamp(z_right, -12.0, 12.0)

         cdf_left = self._normal_cdf(z_left)
         # Treat the first-bin left boundary exactly as 0 in CDF.
         if left_is_zero.any():
             cdf_left = torch.where(
                 left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left)

         cdf_right = self._normal_cdf(z_right)
         delta_finite = cdf_right - cdf_left

         # Last bin: ΔB = 1 - CDF(left) = SF(left), computed via erfc for stability.
         delta_last = self._normal_sf(z_left)
         # If left<=0, SF(left)=1 exactly.
         if left_is_zero.any():
             delta_last = torch.where(
                 left_is_zero.unsqueeze(-1), torch.ones_like(delta_last), delta_last)

         delta_basis = torch.where(
             is_inf.unsqueeze(-1), delta_last, delta_finite)
         delta_basis = torch.clamp(delta_basis, min=0.0)
+        return delta_basis
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        target_events: torch.Tensor,
+        dt: torch.Tensor,
+        reduction: str = "mean",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if logits.ndim not in {2, 3}:
+            raise ValueError(
+                f"logits must be 2D (B, J*R) (or (B, 1+J*R)) or 3D (B, J, R); got {tuple(logits.shape)}"
+            )
+        if target_events.ndim != 1 or dt.ndim != 1:
+            raise ValueError(
+                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
+            )
+        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
+            raise ValueError(
+                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
+            )
+        if reduction not in {"mean", "sum", "none"}:
+            raise ValueError("reduction must be one of {'mean','sum','none'}")
+
+        if not torch.all(dt > 0):
+            raise ValueError("dt must be strictly positive")
+
+        device = logits.device
+        dtype = logits.dtype
+
+        centers = self.centers.to(device=device, dtype=dtype)
+        r = int(centers.numel())
+        if r < 1:
+            raise ValueError("centers must have length >= 1")
+
+        if logits.ndim == 3:
+            if logits.shape[2] != r:
+                raise ValueError(
+                    f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
+                )
+            j = int(logits.shape[1])
+            if j < 1:
+                raise ValueError("Inferred number of causes J must be >= 1")
+            alpha = F.softplus(logits) + self.alpha_floor  # (B, J, R)
+        else:
+            d = int(logits.shape[1])
+            offset = 0
+            if d % r == 0:
+                jr = d
+            elif (d - 1) % r == 0:
+                offset = 1
+                jr = d - 1
+            else:
+                raise ValueError(
+                    f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}"
+                )
+
+            j = jr // r
+            if j < 1:
+                raise ValueError("Inferred number of causes J must be >= 1")
+
+            logits_used = logits[:, offset:]
+            alpha = F.softplus(logits_used).view(-1, j, r) + \
+                self.alpha_floor  # (B, J, R)
+
+        delta_basis = self._compute_delta_basis_all_bins(
+            device=device,
+            dtype=dtype,
+        )  # (n_bins, R)
+        n_bins = int(delta_basis.shape[0])
+
+        # H_{j,k} = sum_r alpha_{j,r} * ΔB_{k,r}
+        h_jk = torch.einsum("mjr,kr->mjk", alpha,
+                            delta_basis)  # (B, J, n_bins)
+        h_k = h_jk.sum(dim=1)  # (B, n_bins)
+
+        # Map continuous dt to discrete event bin index k* in {1..n_bins}.
+        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
+        time_bin = torch.bucketize(dt, bin_edges)
+        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(torch.long)

-        # 4) Gather per-sample cause weights and compute event mass.
         cause = target_events.to(device=device, dtype=torch.long)
         if (cause < 0).any() or (cause >= j).any():
             raise ValueError(f"target_events must be in [0, J-1] where J={j}")

-        b_idx = torch.arange(bsz, device=device)
-        w_cause = w[b_idx, cause, :]  # (B, R)
-
-        m = (w_cause * delta_basis).sum(dim=-1)
-        m = torch.clamp(m, min=self.eps)
-        nll_vec = -torch.log(m)
+        # Previous survival term: sum_{u<k*} H_u
+        bins = torch.arange(
+            1, n_bins + 1, device=device).unsqueeze(0)  # (1, n_bins)
+        mask_prev = bins < time_bin.unsqueeze(1)  # (B, n_bins)
+        loss_prev = (h_k * mask_prev.to(h_k.dtype)).sum(dim=1)  # (B,)
+
+        # Event term at k*: -log p_{k*}(cause)
+        b_idx = torch.arange(target_events.shape[0], device=device)
+        k0 = time_bin - 1  # (B,) index into 0..n_bins-1
+        h_event_total = h_k[b_idx, k0]
+        h_event_total = torch.clamp(h_event_total, min=self.eps)
+
+        h_event_cause = h_jk[b_idx, cause, k0]
+        h_event_cause = torch.clamp(h_event_cause, min=self.eps)
+
+        # log(1 - exp(-H)) stably
+        log1mexp = torch.log(-torch.expm1(-h_event_total))
+        loss_event = -log1mexp - \
+            torch.log(h_event_cause) + torch.log(h_event_total)
+
+        loss = loss_prev + loss_event

         if reduction == "mean":
-            nll: torch.Tensor = nll_vec.mean()
+            nll = loss.mean()
         elif reduction == "sum":
-            nll = nll_vec.sum()
+            nll = loss.sum()
         else:
-            nll = nll_vec
+            nll = loss

         reg = torch.zeros((), device=device, dtype=dtype)

+        if self.lambda_reg > 0.0:
+            # Regularize the within-bin cause competition via NLL on log ratios.
+            # ratio_j = H_{j,k*} / H_{k*}
+            h_event_all = h_jk[b_idx, :, k0]  # (B, J)
+            denom = torch.clamp(h_event_total, min=self.eps).unsqueeze(1)
+            ratio = torch.clamp(h_event_all / denom, min=self.eps)
+            log_ratio = torch.log(ratio)
+            reg = reg + self.lambda_reg * F.nll_loss(
+                log_ratio, cause, reduction="mean")
+
         sigma_penalty = torch.zeros((), device=device, dtype=dtype)
         if self.lambda_sigma_reg > 0.0:
             target = self.bandwidth_init if self.sigma_reg_target is None else self.sigma_reg_target
             sigma_penalty = (self.log_sigma.to(
                 device=device, dtype=dtype) - math.log(float(target))) ** 2
-            reg = sigma_penalty * float(self.lambda_sigma_reg)
+            reg = reg + sigma_penalty * self.lambda_sigma_reg

-        if not self.return_dict:
-            return nll, reg
-
-        return {
-            "nll": nll,
-            "reg": reg,
-            "nll_vec": nll_vec,
-            "sigma": sigma.detach(),
-            "avg_w0": w0.mean().detach(),
-            "min_delta_basis": delta_basis.min().detach(),
-            "mean_m": m.mean().detach(),
-            "sigma_penalty": sigma_penalty.detach(),
-            "bin_idx": bin_idx.detach(),
-        }
+        return nll, reg
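Why the `expm1` form in the event term matters: for small total hazard H, `1 - exp(-H)` underflows to 0 in float32, while `-expm1(-H)` keeps full precision. A small numeric illustration of the two forms used above:

    import torch

    H = torch.tensor([2.0, 1e-3, 1e-9], dtype=torch.float32)
    naive = torch.log(1.0 - torch.exp(-H))   # last entry: log(0) = -inf
    stable = torch.log(-torch.expm1(-H))     # last entry: ~ log(1e-9), finite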
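End-to-end, the per-sample loss assembled in `forward` is the cumulative hazard of all earlier bins plus the event-bin term. A self-contained sketch mirroring that computation; the names `h_jk`, `cause`, `k0` follow the diff, while the function itself is hypothetical:

    import torch

    def per_sample_nll(h_jk, cause, k0, eps=1e-8):
        """loss_i = sum_{u < k0_i} H_u - log(1 - exp(-H_{k0})) - log(H_{cause,k0} / H_{k0})."""
        B, _, n_bins = h_jk.shape
        h_k = h_jk.sum(dim=1)                                   # (B, n_bins) total hazard per bin
        b = torch.arange(B)
        mask_prev = torch.arange(n_bins).unsqueeze(0) < k0.unsqueeze(1)
        loss_prev = (h_k * mask_prev.to(h_k.dtype)).sum(dim=1)  # survive bins before k0
        h_tot = h_k[b, k0].clamp(min=eps)
        h_cause = h_jk[b, cause, k0].clamp(min=eps)
        loss_event = -torch.log(-torch.expm1(-h_tot)) - torch.log(h_cause) + torch.log(h_tot)
        return loss_prev + loss_event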