Refactor loss functions and model architecture: replace PiecewiseExponentialLoss with DiscreteTimeCIFNLLLoss, update Trainer to use SimpleHead, and modify argument parsing for new loss type.

2026-01-09 18:31:38 +08:00
parent 880fd53a4b
commit 209dde2299
3 changed files with 172 additions and 349 deletions

losses.py
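For orientation, a minimal usage sketch of the new loss introduced in the diff below. It assumes DiscreteTimeCIFNLLLoss is importable from losses.py and that the model head emits per-bin logits of shape (M, K+1, n_bins+1) as the docstring describes; the batch size, bin edges, and random tensors here are illustrative only and not part of the commit.

import torch

from losses import DiscreteTimeCIFNLLLoss  # as introduced in this commit

K, n_bins, M = 3, 4, 8                      # causes, time bins, batch size
bin_edges = [0.0, 1.0, 2.0, 4.0, 8.0]       # n_bins + 1 increasing edges, starting at 0

loss_fn = DiscreteTimeCIFNLLLoss(bin_edges=bin_edges, lambda_reg=0.1)

logits = torch.randn(M, K + 1, n_bins + 1, requires_grad=True)  # channel K is the survival/complement channel
target_events = torch.randint(0, K, (M,))   # observed cause index per sample
dt = torch.rand(M) * 7.9 + 0.05             # strictly positive event times

nll, reg = loss_fn(logits, target_events, dt, reduction="mean")
(nll + reg).backward()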

@@ -132,9 +132,19 @@ class ExponentialNLLLoss(nn.Module):
return nll, reg
class PiecewiseExponentialLoss(nn.Module):
"""
Piecewise-constant competing risks exponential likelihood.
class DiscreteTimeCIFNLLLoss(nn.Module):
"""Direct discrete-time CIF negative log-likelihood (no censoring).
This loss assumes the model outputs per-bin logits over (K causes + 1 complement)
channels, where the complement channel (index K) represents survival across bins.
Per-sample likelihood for observed cause k at time bin j:
p = \prod_{u=1}^{j-1} p(comp at u) * p(k at j)
Args:
bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0.
eps: Unused; kept for interface compatibility / future numerical tweaks.
lambda_reg: Optional regularization strength.
"""
def __init__(
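As a concrete reading of the docstring's factorization above, the snippet below spells out the per-sample likelihood for one made-up case (K = 2 causes, n_bins = 3, event of cause 0 in bin 3); the numbers are random and only the shapes follow the diff.

import torch
import torch.nn.functional as F

logits = torch.randn(1, 3, 4)        # (M=1, K+1=3, n_bins+1=4); channel 2 is the complement
logp = F.log_softmax(logits, dim=1)  # per-bin log-probabilities over the K+1 channels

k, j, comp = 0, 3, 2                 # observed cause 0 at time bin 3; complement channel index
log_lik = logp[0, comp, 1] + logp[0, comp, 2] + logp[0, k, j]  # survive bins 1 and 2, then cause k at bin 3
nll = -log_lik                       # corresponds to loss_prev + loss_event in forward()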
@@ -146,18 +156,20 @@ class PiecewiseExponentialLoss(nn.Module):
super().__init__()
if len(bin_edges) < 2:
raise ValueError("bin_edges must have length >= 2")
if bin_edges[0] != 0:
raise ValueError("bin_edges must start at 0")
raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
if float(bin_edges[0]) != 0.0:
raise ValueError("bin_edges[0] must equal 0")
for i in range(1, len(bin_edges)):
if not (bin_edges[i] > bin_edges[i - 1]):
if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
raise ValueError("bin_edges must be strictly increasing")
self.eps = float(eps)
self.lambda_reg = float(lambda_reg)
edges = torch.tensor(list(bin_edges), dtype=torch.float32)
self.register_buffer("bin_edges", edges, persistent=False)
self.register_buffer(
"bin_edges",
torch.tensor(bin_edges, dtype=torch.float32),
persistent=False,
)
def forward(
self,
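A note on the registration above: a non-persistent buffer follows the module across .to(device) calls but is not written to the state_dict, which suits fixed hyperparameters like bin edges. Given edges registered this way, the forward pass in the next hunk discretizes each continuous dt with torch.bucketize and clamps the result into [1, n_bins]; a small illustration of that mapping, assuming the default right=False behaviour used in the diff:

import torch

bin_edges = torch.tensor([0.0, 1.0, 2.0, 4.0])          # n_bins = 3
dt = torch.tensor([0.3, 1.0, 2.5, 9.0])

time_bin = torch.bucketize(dt, bin_edges)                # tensor([1, 1, 3, 4]); 4 means "past the last edge"
time_bin = torch.clamp(time_bin, min=1, max=3).long()    # tensor([1, 1, 3, 3]); valid event bins only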
@@ -166,145 +178,83 @@ class PiecewiseExponentialLoss(nn.Module):
dt: torch.Tensor,
reduction: str = "mean",
) -> Tuple[torch.Tensor, torch.Tensor]:
if logits.dim() != 3:
raise ValueError("logits must have shape (M, K, B)")
M, K, B = logits.shape
if self.bin_edges.numel() != B + 1:
if logits.ndim != 3:
raise ValueError(
f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
)
if target_events.ndim != 1 or dt.ndim != 1:
raise ValueError(
f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
)
if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
raise ValueError(
"Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
)
if reduction not in {"mean", "sum", "none"}:
raise ValueError("reduction must be one of {'mean','sum','none'}")
device = logits.device
dt = dt.to(device=device, dtype=torch.float32)
target_events = target_events.to(device=device)
if not torch.all(dt > 0):
raise ValueError("dt must be strictly positive")
# Infer K and n_bins from logits and bin_edges.
m, k_plus_1, n_bins_plus_1 = logits.shape
k_comp = k_plus_1 - 1
if k_comp < 1:
raise ValueError(
"logits.shape[1] must be at least 2 (K>=1 plus complement channel)")
n_bins = int(self.bin_edges.numel() - 1)
if n_bins_plus_1 != n_bins + 1:
raise ValueError(
f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
)
if target_events.dtype != torch.long:
target_events = target_events.to(dtype=torch.long)
if target_events.min().item() < 0 or target_events.max().item() >= K:
raise ValueError("target_events must be in [0, K)")
target_events = target_events.to(torch.long)
# Hazards: (M, K, B)
hazards = F.softplus(logits) + self.eps
total_hazard = hazards.sum(dim=1) # (M, B)
if (target_events < 0).any() or (target_events >= k_comp).any():
raise ValueError(
f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}"
)
edges = self.bin_edges.to(device=device, dtype=dt.dtype)
widths = edges[1:] - edges[:-1] # (B,)
# Map continuous dt to discrete bins j in {1..n_bins}.
bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
# (M,), may be n_bins+1 if dt > bin_edges[-1]
time_bin = torch.bucketize(dt, bin_edges)
time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(
torch.long) # ensure valid event bins
if dt.min().item() <= 0:
raise ValueError("dt must be strictly positive")
if dt.max().item() > edges[-1].item():
raise ValueError("dt must be <= last bin edge")
# Log-probabilities across causes+complement for each bin.
logp = F.log_softmax(logits, dim=1) # (M, K+1, n_bins+1)
# Bin index b* in [0, B-1].
b_star = torch.searchsorted(edges[1:], dt, right=False) # (M,)
# Previous survival term: sum_{u=1}^{j-1} -log p(comp at u)
bins = torch.arange(n_bins + 1, device=logits.device) # (n_bins+1,)
mask = (bins.unsqueeze(0) >= 1) & (bins.unsqueeze(
0) < time_bin.unsqueeze(1)) # (M, n_bins+1)
logp_comp = logp[:, k_comp, :] # (M, n_bins+1)
loss_prev = -(logp_comp * mask.to(logp_comp.dtype)).sum(dim=1) # (M,)
# 1. Hazard at event (M,)
# gather needs matching dims.
# hazards: (M, K, B) -> select target_event -> (M, B) -> select b_star -> (M,)
# Alternative: hazards[m, k, b]
ar = torch.arange(M, device=device)
hazard_event = hazards[ar, target_events, b_star] # (M,)
hazard_event = torch.clamp(hazard_event, min=self.eps)
# Event term at bin j: -log p(k at j)
m_idx = torch.arange(m, device=logits.device)
loss_event = -logp[m_idx, target_events, time_bin] # (M,)
# 2. Integral part
# Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
# Full bins accumulation
weighted = total_hazard * widths.unsqueeze(0) # (M, B)
cum = weighted.cumsum(dim=1) # (M, B)
full_bins_int = torch.zeros_like(dt)
# We process 'has_full' logic generally.
# If b_star is 0, gather on index -1 would fail or wrap, so we mask carefully or use conditional
has_full = b_star > 0
# NOTE: Even without protection, we need valid indices for gather.
# We use a temporary index that is safe (0) for the 'False' cases, then mask the result.
safe_indices = (b_star - 1).clamp(min=0)
gathered_cum = cum.gather(1, safe_indices.unsqueeze(1)).squeeze(1)
full_bins_int = torch.where(has_full, gathered_cum, full_bins_int)
# Partial bin accumulation
edge_left = edges[b_star] # (M,)
partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1)
partial = partial_hazard * (dt - edge_left)
integral = full_bins_int + partial
# Final NLL
nll = -torch.log(hazard_event) + integral
# Reduction
if reduction == "none":
nll_out = nll
elif reduction == "sum":
nll_out = nll.sum()
elif reduction == "mean":
nll_out = nll.mean()
else:
raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
reg = logits.new_zeros(())
if self.lambda_reg != 0.0:
reg = reg + (self.lambda_reg * logits.pow(2).mean())
return nll_out, reg
class WeibullNLLLoss(nn.Module):
"""
Weibull hazard in t.
"""
def __init__(
self,
eps: float = 1e-6,
lambda_reg: float = 0.0,
):
super().__init__()
self.eps = eps
self.lambda_reg = lambda_reg
def forward(self, logits, target_events, dt, reduction="mean"):
if logits.dim() != 3 or logits.size(-1) != 2:
raise ValueError("logits must have shape (M, K, 2)")
M, K, _ = logits.shape
device = logits.device
dt = dt.to(device=device, dtype=torch.float32)
if dt.min().item() <= 0:
raise ValueError("dt must be strictly positive")
target_events = target_events.to(device=device)
target_events = target_events.to(dtype=torch.long)
if target_events.min().item() < 0 or target_events.max().item() >= K:
raise ValueError("target_events must be in [0, K)")
shapes = F.softplus(logits[..., 0]) + self.eps
scales = F.softplus(logits[..., 1]) + self.eps
t_mat = dt.unsqueeze(1) # (M,1)
cum_hazard = scales * torch.pow(t_mat, shapes)
hazard = shapes * scales * torch.pow(t_mat, shapes - 1.0)
hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
hazard_event = torch.clamp(hazard_event, min=self.eps)
nll = -torch.log(hazard_event) + cum_hazard.sum(dim=1)
loss = loss_prev + loss_event
if reduction == "mean":
nll = nll.mean()
nll = loss.mean()
elif reduction == "sum":
nll = nll.sum()
elif reduction != "none":
raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
nll = loss.sum()
else:
nll = loss
reg = torch.zeros((), device=logits.device, dtype=loss.dtype)
if self.lambda_reg > 0.0:
# Regularize the cause distribution at the event bin using NLL on log-probs.
logp_causes = logp[:, :k_comp, :] # (M, K, n_bins+1)
idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1)
logp_at_event_bin = logp_causes.gather(
dim=2, index=idx).squeeze(2) # (M, K)
reg = self.lambda_reg * \
F.nll_loss(logp_at_event_bin, target_events, reduction="mean")
reg = shapes.new_zeros(())
if self.lambda_reg > 0:
reg = self.lambda_reg * (
(torch.log(scales + self.eps) ** 2).mean() +
(torch.log(shapes + self.eps) ** 2).mean()
)
return nll, reg
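Condensing the new forward into one place, a simplified sketch of the core computation (the validation, device handling, and dtype coercion shown in the diff are omitted); it mirrors the loss_prev, loss_event, and lambda_reg terms above, but it is a reading of the diff rather than the committed code.

import torch
import torch.nn.functional as F

def discrete_time_cif_nll(logits, target_events, dt, bin_edges, lambda_reg=0.0):
    """Sketch: logits (M, K+1, n_bins+1), target_events (M,), dt (M,) > 0."""
    m, k_plus_1, _ = logits.shape
    k = k_plus_1 - 1                                   # index of the complement channel
    n_bins = bin_edges.numel() - 1
    target_events = target_events.long()

    # Discretize event times into bins 1..n_bins.
    time_bin = torch.bucketize(dt, bin_edges).clamp(min=1, max=n_bins)

    logp = F.log_softmax(logits, dim=1)                # per-bin log-probs over K+1 channels

    # Survival term: complement channel over bins 1..j-1.
    bins = torch.arange(n_bins + 1, device=logits.device)
    prev_mask = (bins >= 1) & (bins.unsqueeze(0) < time_bin.unsqueeze(1))
    loss_prev = -(logp[:, k, :] * prev_mask.to(logp.dtype)).sum(dim=1)

    # Event term: observed cause at bin j.
    rows = torch.arange(m, device=logits.device)
    loss_event = -logp[rows, target_events, time_bin]

    nll = (loss_prev + loss_event).mean()

    # Optional regularizer: cause classification restricted to the event bin.
    reg = logits.new_zeros(())
    if lambda_reg > 0:
        idx = time_bin.view(m, 1, 1).expand(-1, k, 1)
        logp_causes = logp[:, :k, :].gather(2, idx).squeeze(2)   # (M, K)
        reg = lambda_reg * F.nll_loss(logp_causes, target_events)
    return nll, reg

Because each bin's softmax runs over the K causes plus the complement channel, the channel probabilities within a bin sum to one, so the product in the docstring (survive bins 1..j-1, then cause k at bin j) is a proper discrete-time likelihood without a separate hazard integral.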