Add PiecewiseExponentialLoss class and update TrainConfig for new loss type

This commit is contained in:
2026-01-08 12:45:31 +08:00
parent 7c36f7a007
commit 06a01d2893
2 changed files with 144 additions and 10 deletions

131
losses.py
View File

@@ -132,6 +132,135 @@ class ExponentialNLLLoss(nn.Module):
return nll, reg
class PiecewiseExponentialLoss(nn.Module):
    r"""Piecewise-constant competing-risks exponential negative log-likelihood.

    Time is split into B bins by ``bin_edges`` (length B+1, strictly
    increasing, starting at 0). Within bin b every cause-specific hazard is
    constant and parameterized as::

        hazards = softplus(logits) + eps        # logits shape (M, K, B)

    For sample i with event type k* whose ``dt`` falls into bin b*::

        nll_i = -log(hazard_{k*}(b*)) + \int_0^{dt} sum_k hazard_k(u) du

    The integral is evaluated in closed form: for every bin, the total hazard
    is multiplied by the time the sample spent inside that bin (the full
    width for bins before b*, ``dt - left_edge`` for b*, zero afterwards).

    Notes:
        * ``dt`` is clamped to ``[eps, just below bin_edges[-1]]``, so times
          past the last edge are scored as if they occurred at the last edge.
        * A ``dt`` exactly equal to an interior edge is assigned to the bin
          that ends at that edge.
        * Samples whose logits, ``dt`` or target are NaN/Inf are excluded from
          the loss instead of propagating non-finite values.

    Args:
        bin_edges: Bin boundaries; at least two values, first one 0, strictly
            increasing.
        eps: Positive offset added to the hazards and lower bound for ``dt``.
        lambda_reg: Weight of the mean-squared-logit regularizer; 0 disables it.

    Raises:
        ValueError: If ``bin_edges`` is too short, does not start at 0, or is
            not strictly increasing.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(
        self,
        bin_edges: Sequence[float],
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()
        edge_list = list(bin_edges)
        if len(edge_list) < 2:
            raise ValueError("bin_edges must have length >= 2")
        if edge_list[0] != 0:
            raise ValueError("bin_edges must start at 0")
        # `not (hi > lo)` (rather than `hi <= lo`) also rejects NaN edges.
        if any(not (hi > lo) for lo, hi in zip(edge_list, edge_list[1:])):
            raise ValueError("bin_edges must be strictly increasing")

        self.eps = float(eps)
        self.lambda_reg = float(lambda_reg)
        # Non-persistent buffer: follows the module across devices but is not
        # written to the state dict.
        self.register_buffer(
            "bin_edges",
            torch.tensor(edge_list, dtype=torch.float32),
            persistent=False,
        )

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the negative log-likelihood and the regularization term.

        Args:
            logits: Tensor of shape (M, K, B): samples x causes x bins.
            target_events: Shape (M,), index of the observed cause in [0, K).
            dt: Shape (M,), elapsed time of each sample.
            reduction: ``"mean"`` / ``"sum"`` aggregate over the valid samples
                only; ``"none"`` returns a length-M tensor with zeros at the
                positions of invalid samples.

        Returns:
            ``(nll, reg)``. ``reg`` is a scalar equal to
            ``lambda_reg * mean(logits_valid ** 2)`` (zero when disabled).
            If no sample is valid, both are zeros.

        Raises:
            ValueError: On a wrong ``logits`` rank, a bin-count mismatch with
                ``bin_edges``, or an unknown ``reduction``.
            IndexError: If a valid sample's target is outside ``[0, K)``.
        """
        if logits.dim() != 3:
            raise ValueError("logits must have shape (M, K, B)")
        M, K, B = logits.shape
        if self.bin_edges.numel() != B + 1:
            raise ValueError(
                f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
            )
        # Validate up front so a bad value is rejected on every code path,
        # including the early return below.
        if reduction not in self._REDUCTIONS:
            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")

        device = logits.device
        dt = dt.to(device=device)
        target_events = target_events.to(device=device)

        # Per-sample validity mask. `flatten` (unlike `view(M, -1)`) also
        # works for empty batches and non-contiguous logits.
        valid = (
            torch.isfinite(logits).flatten(start_dim=1).all(dim=1)
            & torch.isfinite(dt)
            & torch.isfinite(target_events)
        )
        if not bool(valid.any()):
            shape = (M,) if reduction == "none" else ()
            return logits.new_zeros(shape), logits.new_zeros(())

        idx = valid.nonzero(as_tuple=False).squeeze(1)
        logits_v = logits[idx]
        target_v = target_events[idx].to(torch.long)
        # Negative indices would silently wrap around to another cause.
        if bool(((target_v < 0) | (target_v >= K)).any()):
            raise IndexError(
                f"target_events must be in [0, {K}) for every valid sample"
            )

        edges = self.bin_edges.to(device=device, dtype=torch.float32)
        widths = edges[1:] - edges[:-1]  # (B,)

        # Clamp dt into [eps, largest float below the last edge] so that the
        # bucket index is always a real bin.
        upper = torch.nextafter(edges[-1], edges.new_zeros(()))
        dt_v = dt[idx].to(torch.float32).clamp(min=self.eps)
        dt_v = torch.minimum(dt_v, upper)

        # Evaluate hazards in at least float32 so half-precision logits do not
        # lose `eps` or clash with the float32 time tensors.
        compute_dtype = torch.promote_types(logits.dtype, torch.float32)
        hazards = F.softplus(logits_v.to(compute_dtype)) + self.eps  # (Mv, K, B)
        total_hazard = hazards.sum(dim=1)  # (Mv, B)

        # Bin index b* in [0, B-1]; the search boundaries are the right edges.
        b_star = torch.searchsorted(edges[1:], dt_v, right=False)
        b_star = b_star.clamp(min=0, max=B - 1)
        rows = torch.arange(logits_v.size(0), device=device)
        hazard_event = hazards[rows, target_v, b_star]  # (Mv,)

        # Time spent in each bin: full width before b*, dt - left edge in b*,
        # zero after. Only finite exposures multiply the hazards, so an
        # infinite last edge cannot inject 0 * inf = NaN into the gradients.
        exposure = torch.minimum(
            (dt_v.unsqueeze(1) - edges[:-1]).clamp(min=0), widths
        )  # (Mv, B)
        integral = (total_hazard * exposure).sum(dim=1)  # (Mv,)

        nll_v = integral - torch.log(hazard_event)

        if reduction == "none":
            nll_out = logits.new_zeros((M,))
            # Cast back: indexed assignment requires matching dtypes.
            nll_out[idx] = nll_v.to(nll_out.dtype)
        elif reduction == "sum":
            nll_out = nll_v.sum()
        else:  # "mean"
            nll_out = nll_v.mean()

        reg = logits.new_zeros(())
        if self.lambda_reg != 0.0:
            reg = reg + (self.lambda_reg * logits_v.pow(2).mean())
        return nll_out, reg
class WeibullNLLLoss(nn.Module):
"""
Weibull hazard in t.
@@ -207,4 +336,4 @@ class WeibullNLLLoss(nn.Module):
(torch.log(scales + eps) ** 2).mean() +
(torch.log(shapes + eps) ** 2).mean()
)
return nll, reg
return nll, reg