Refactor PiecewiseExponentialLoss and WeibullNLLLoss for clarity and numerical stability
losses.py
@@ -133,18 +133,13 @@ class ExponentialNLLLoss(nn.Module):
 class PiecewiseExponentialLoss(nn.Module):
-    """
-    Piecewise-constant competing risks exponential likelihood.
+    """Piecewise-constant competing risks exponential likelihood.
 
     Uses B time bins defined by `bin_edges` (length B+1, strictly increasing, starting at 0).
     Within each bin b, hazards are constant and parameterized as:
 
         hazards = softplus(logits) + eps    with logits shape (M, K, B)
 
     For each sample i, dt is bucketized to bin b* and the NLL is:
 
         nll_i = -log(hazard_{k*}(b*)) + \int_0^{dt} sum_k hazard_k(u) du
 
     The integral is computed in closed form by summing full bins plus the partial bin b*.
+
+    Lightweight numerical protections:
+    - Does NOT mask/skip any samples.
+    - Uses nan_to_num for dt/logits/targets to avoid NaN/Inf propagation.
+    - Clamps logits and dt to keep softplus/log operations finite.
     """
 
     def __init__(
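Note: the closed-form integral from the docstring is easy to check by hand. A minimal sketch for a single sample, on assumed toy values (three bins, made-up hazards; not part of the commit):

```python
import torch

# Toy values (assumed): B = 3 bins, one sample.
edges = torch.tensor([0.0, 1.0, 2.0, 4.0])      # bin_edges, length B+1
total_hazard = torch.tensor([0.5, 0.2, 0.1])    # sum_k hazard_k(b) per bin
dt = torch.tensor([2.5])                        # event time, lands in bin b* = 2

b_star = torch.searchsorted(edges[1:], dt, right=False)  # tensor([2])
b = int(b_star)

widths = edges[1:] - edges[:-1]
full_bins = (total_hazard[:b] * widths[:b]).sum()    # 0.5*1.0 + 0.2*1.0 = 0.7
partial = total_hazard[b] * (dt[0] - edges[b])       # 0.1 * 0.5 = 0.05
integral = full_bins + partial                       # 0.75
```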
@@ -152,6 +147,7 @@ class PiecewiseExponentialLoss(nn.Module):
         bin_edges: Sequence[float],
         eps: float = 1e-6,
         lambda_reg: float = 0.0,
+        logit_clip: float = 30.0,
     ):
         super().__init__()
@@ -165,6 +161,7 @@ class PiecewiseExponentialLoss(nn.Module):
         self.eps = float(eps)
         self.lambda_reg = float(lambda_reg)
+        self.logit_clip = float(logit_clip)
 
         edges = torch.tensor(list(bin_edges), dtype=torch.float32)
         self.register_buffer("bin_edges", edges, persistent=False)
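Note: `register_buffer(..., persistent=False)` keeps `bin_edges` on the module so it follows `module.to(device)`, but excludes it from the state_dict (it is rebuilt from the constructor argument anyway). A quick check of that PyTorch behavior, using an illustrative `Demo` class not from the commit:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer: moves with .to()/.cuda(), not serialized.
        self.register_buffer("edges", torch.tensor([0.0, 1.0]), persistent=False)

m = Demo()
assert "edges" not in m.state_dict()  # excluded from checkpoints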
@@ -186,75 +183,87 @@ class PiecewiseExponentialLoss(nn.Module):
         )
 
         device = logits.device
-        dt = dt.to(device=device)
+        dt = dt.to(device=device, dtype=torch.float32)
         target_events = target_events.to(device=device)
 
-        # Build a per-sample finite mask to avoid NaN/Inf propagation.
-        logits_finite = torch.isfinite(logits).view(M, -1).all(dim=1)
-        dt_finite = torch.isfinite(dt)
-        target_finite = torch.isfinite(target_events)
-        finite_mask = logits_finite & dt_finite & target_finite
-
-        nll_full = logits.new_zeros((M,))
-
-        if not finite_mask.any():
-            nll_out = nll_full if reduction == "none" else logits.new_zeros(())
-            reg_out = logits.new_zeros(())
-            return nll_out, reg_out
-
-        idx = finite_mask.nonzero(as_tuple=False).squeeze(1)
-        logits_v = logits[idx]
-        target_v = target_events[idx].to(torch.long)
-        dt_v = dt[idx].to(torch.float32)
+        # No masking/skipping: coerce invalid values to safe defaults.
+        logits_v = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
+        logits_v = torch.clamp(
+            logits_v, min=-self.logit_clip, max=self.logit_clip)
+
+        dt_v = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
+        target_v = torch.nan_to_num(
+            target_events, nan=0.0, posinf=0.0, neginf=0.0)
+        target_v = target_v.to(dtype=torch.long)
+        target_v = torch.clamp(target_v, min=0, max=K - 1)
 
-        # Clamp dt into [eps, max_edge) to keep bucket indices valid.
+        # Keep structural clamping to prevent index-out-of-bounds errors
+        # (necessary for searchsorted/gather to work at all).
         eps = self.eps
         max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype)
         dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))
         dt_v = torch.clamp(dt_v, min=eps, max=dt_max)
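Note: `torch.nextafter(max_edge, 0)` is the largest representable float strictly below `max_edge`, so clamped samples stay strictly inside the last bin. A quick sketch of that behavior, on assumed values:

```python
import torch

max_edge = torch.tensor(4.0)
dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))  # largest float < 4.0
assert dt_max < max_edge

# Clamping dt to dt_max keeps every sample strictly inside the last bin,
# so searchsorted over edges[1:] always returns an index in [0, B-1].
edges = torch.tensor([0.0, 1.0, 2.0, 4.0])
dt = torch.tensor([9.9]).clamp(max=dt_max)
print(torch.searchsorted(edges[1:], dt, right=False))  # tensor([2])
```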
 
-        hazards = F.softplus(logits_v) + eps  # (Mv, K, B)
-        total_hazard = hazards.sum(dim=1)  # (Mv, B)
+        # Hazards: (M, K, B)
+        hazards = F.softplus(logits_v) + eps
+        hazards = torch.clamp(hazards, min=eps)
+        total_hazard = hazards.sum(dim=1)  # (M, B)
 
         edges = self.bin_edges.to(device=device, dtype=dt_v.dtype)
         widths = edges[1:] - edges[:-1]  # (B,)
 
-        # Bin index b* in [0, B-1]. Boundaries are edges[1:] (length B).
-        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (Mv,)
+        # Bin index b* in [0, B-1].
+        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (M,)
         b_star = torch.clamp(b_star, min=0, max=B - 1)
 
-        ar = torch.arange(logits_v.size(0), device=device)
-        hazard_event = hazards[ar, target_v, b_star]  # (Mv,)
+        # 1. Hazard at event (M,)
+        # gather would need matching dims:
+        # hazards: (M, K, B) -> select target_event -> (M, B) -> select b_star -> (M,)
+        # Simpler alternative used here: advanced indexing hazards[m, k, b].
+        ar = torch.arange(M, device=device)
+        hazard_event = hazards[ar, target_v, b_star]  # (M,)
+        hazard_event = torch.clamp(hazard_event, min=eps)
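Note: the per-sample selection can be written either with integer-array indexing (as in the new code) or with two chained gathers, as the comment hints. A small equivalence check, on assumed shapes:

```python
import torch

M, K, B = 4, 3, 5
hazards = torch.rand(M, K, B)
target_v = torch.randint(0, K, (M,))
b_star = torch.randint(0, B, (M,))

# Integer-array indexing, as in the commit:
ar = torch.arange(M)
picked = hazards[ar, target_v, b_star]  # (M,)

# Equivalent chained gathers (index dims must match input dims):
by_event = hazards.gather(1, target_v.view(M, 1, 1).expand(M, 1, B)).squeeze(1)  # (M, B)
by_bin = by_event.gather(1, b_star.unsqueeze(1)).squeeze(1)                      # (M,)
assert torch.equal(picked, by_bin)
```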
 
-        # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
-        weighted = total_hazard * widths.unsqueeze(0)  # (Mv, B)
-        cum = weighted.cumsum(dim=1)  # (Mv, B)
-        full_bins_int = torch.zeros_like(dt_v)
-        has_full = b_star > 0
-        if has_full.any():
-            full_bins_int[has_full] = cum.gather(
-                1, (b_star[has_full] - 1).unsqueeze(1)
-            ).squeeze(1)
-
-        edge_left = edges[b_star]  # (Mv,)
-        partial = total_hazard.gather(
-            1, b_star.unsqueeze(1)).squeeze(1) * (dt_v - edge_left)
+        # 2. Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
+        # Full bins accumulation
+        weighted = total_hazard * widths.unsqueeze(0)  # (M, B)
+        cum = weighted.cumsum(dim=1)  # (M, B)
+
+        full_bins_int = torch.zeros_like(dt_v)
+
+        # If b_star == 0 there are no full bins, and gather on index -1 would
+        # fail or wrap. Use a safe index (0) for those rows, then mask the result.
+        has_full = b_star > 0
+        safe_indices = (b_star - 1).clamp(min=0)
+        gathered_cum = cum.gather(1, safe_indices.unsqueeze(1)).squeeze(1)
+        full_bins_int = torch.where(has_full, gathered_cum, full_bins_int)
+
+        # Partial bin accumulation
+        edge_left = edges[b_star]  # (M,)
+        partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1)
+        partial = partial_hazard * (dt_v - edge_left)
 
         integral = full_bins_int + partial
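Note: the safe-index-then-mask pattern replaces the old data-dependent `if has_full.any():` branch (a host sync on GPU) with a branch-free equivalent. A standalone sketch on assumed toy values:

```python
import torch

cum = torch.tensor([[1.0, 3.0, 6.0],
                    [2.0, 5.0, 9.0]])
b_star = torch.tensor([0, 2])  # row 0 has no full bins

has_full = b_star > 0
safe_idx = (b_star - 1).clamp(min=0)                        # [0, 1]: always valid
gathered = cum.gather(1, safe_idx.unsqueeze(1)).squeeze(1)  # [1.0, 5.0]
full = torch.where(has_full, gathered, torch.zeros_like(gathered))
print(full)  # tensor([0., 5.])
```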
 
-        nll_v = -torch.log(hazard_event) + integral
-        nll_full[idx] = nll_v
+        # Final NLL
+        nll = -torch.log(hazard_event) + integral
 
+        # Reduction
         if reduction == "none":
-            nll_out = nll_full
+            nll_out = nll
         elif reduction == "sum":
-            nll_out = nll_v.sum()
+            nll_out = nll.sum()
         elif reduction == "mean":
-            nll_out = nll_v.mean() if nll_v.numel() > 0 else logits.new_zeros(())
+            nll_out = nll.mean()
         else:
             raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
 
         reg = logits.new_zeros(())
         if self.lambda_reg != 0.0:
             reg = reg + (self.lambda_reg * logits_v.pow(2).mean())
 
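Note: for reference, a minimal usage sketch of the refactored loss. It assumes the class is importable from losses.py and that forward takes (logits, target_events, dt, reduction) and returns (nll, reg), consistent with the return path visible in the diff:

```python
import torch
from losses import PiecewiseExponentialLoss  # assumed import path

M, K, B = 8, 3, 4
loss_fn = PiecewiseExponentialLoss(
    bin_edges=[0.0, 0.5, 1.0, 2.0, 5.0],  # B+1 edges, starting at 0
    eps=1e-6, lambda_reg=1e-4, logit_clip=30.0,
)
logits = torch.randn(M, K, B, requires_grad=True)
target_events = torch.randint(0, K, (M,))
dt = torch.rand(M) * 5.0

nll, reg = loss_fn(logits, target_events, dt, reduction="mean")
(nll + reg).backward()
```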
@@ -263,77 +272,83 @@ class PiecewiseExponentialLoss(nn.Module):
 
 class WeibullNLLLoss(nn.Module):
     """
-    Weibull hazard in t.
+    Weibull hazard in t with lightweight numerical protections.
 
     .. math::
         \\Lambda_k(t) = \\text{scale}_k \\cdot t^{\\text{shape}_k}
 
         \\lambda_k(t) = \\text{shape}_k \\cdot \\text{scale}_k \\cdot t^{\\text{shape}_k - 1}
 
-    Args:
-        eps (float): Small epsilon for numerical stability.
-        lambda_reg (float): Regularization weight.
-        use_interval_near_integer (bool): If True, use interval likelihood for near-integer-year samples.
-        near_integer_eps_years (float): Near-integer threshold in years.
-        interval_half_width_years (float): Half-width Δ for interval [t-Δ, t+Δ] in years.
-        min_integer_year (float): Only apply near-integer logic when round(t) >= min_integer_year.
+    Does NOT mask/skip any samples. Instead:
+    - nan_to_num for logits/dt/targets
+    - clamps logits to keep softplus outputs reasonable
+    - computes t^shape in log-space with clamped exponent to prevent overflow
     """
 
     def __init__(
         self,
         eps: float = 1e-6,
         lambda_reg: float = 0.0,
+        logit_clip: float = 30.0,
+        max_shape: float = 30.0,
+        max_dt: float = 1.0e3,
+        max_exp: float = 80.0,
     ):
         super().__init__()
         self.eps = eps
         self.lambda_reg = lambda_reg
+        self.logit_clip = float(logit_clip)
+        self.max_shape = float(max_shape)
+        self.max_dt = float(max_dt)
+        self.max_exp = float(max_exp)
 
-    def forward(
-        self,
-        logits: torch.Tensor,
-        target_events: torch.Tensor,
-        dt: torch.Tensor,
-        reduction: str = "mean",
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Forward pass.
-
-        Args:
-            logits (torch.Tensor): (M, K, 2) tensor of logits.
-            target_events (torch.Tensor): (M,) tensor of target events.
-            dt (torch.Tensor): (M,) tensor of time intervals.
-            reduction (str): 'mean', 'sum', or 'none'.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization).
-        """
-        shapes = F.softplus(logits[..., 0]) + self.eps  # (M,K)
-        scales = F.softplus(logits[..., 1]) + self.eps  # (M,K)
-        eps = self.eps
-        t = torch.clamp(dt, min=eps)
-
-        t_mat = t.unsqueeze(1)  # (M,1)
-
-        # cumulative hazard (M,K)
-        cum_hazard = scales * t_mat.pow(shapes)
-
-        # hazard (M,K)
-        hazard = shapes * scales * t_mat.pow(shapes - 1.0)
+    def forward(self, logits, target_events, dt, reduction="mean"):
+        if logits.dim() != 3 or logits.size(-1) != 2:
+            raise ValueError("logits must have shape (M, K, 2)")
+
+        M, K, _ = logits.shape
+        device = logits.device
+
+        logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
+        logits = torch.clamp(logits, min=-self.logit_clip, max=self.logit_clip)
+
+        dt = dt.to(device=device, dtype=torch.float32)
+        dt = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
+        dt = torch.clamp(dt, min=self.eps, max=self.max_dt)
+
+        target_events = target_events.to(device=device)
+        target_events = torch.nan_to_num(
+            target_events, nan=0.0, posinf=0.0, neginf=0.0)
+        target_events = target_events.to(dtype=torch.long)
+        target_events = torch.clamp(target_events, min=0, max=K - 1)
+
+        shapes = F.softplus(logits[..., 0]) + self.eps
+        scales = F.softplus(logits[..., 1]) + self.eps
+        shapes = torch.clamp(shapes, min=self.eps, max=self.max_shape)
+        scales = torch.clamp(scales, min=self.eps)
+
+        t_mat = dt.unsqueeze(1)  # (M,1)
+        log_t = torch.log(torch.clamp(t_mat, min=self.eps))
+
+        # Compute t^shape and t^(shape-1) in log-space with exponent clamp.
+        pow_shape = torch.exp(torch.clamp(shapes * log_t, max=self.max_exp))
+        pow_shape_minus_1 = torch.exp(
+            torch.clamp((shapes - 1.0) * log_t, max=self.max_exp)
+        )
+
+        cum_hazard = scales * pow_shape
+        hazard = shapes * scales * pow_shape_minus_1
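Note: the log-space trick rewrites t.pow(shape) as exp(shape * log t), which makes the overflow point explicit and clampable. A quick sketch, using max_exp as in the new defaults:

```python
import torch

max_exp = 80.0
t = torch.tensor([1.0e3])
shape = torch.tensor([30.0])

naive = t.pow(shape)  # 10^90 overflows float32 -> inf
log_space = torch.exp(torch.clamp(shape * torch.log(t), max=max_exp))
print(naive, log_space)  # tensor([inf])  tensor([5.5406e+34])
```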
         hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
+        hazard_event = torch.clamp(hazard_event, min=self.eps)
 
         # Point-event likelihood: f_k(t) = \lambda_k(t) * exp(-\Lambda_total(t))
         # NLL_point = -log \lambda_{k*}(t) + \Lambda_total(t)
-        nll = -torch.log(hazard_event + eps) + cum_hazard.sum(dim=1)
+        nll = -torch.log(hazard_event) + cum_hazard.sum(dim=1)
 
         if reduction == "mean":
             nll = nll.mean()
         elif reduction == "sum":
             nll = nll.sum()
         elif reduction != "none":
             raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
 
         reg = shapes.new_zeros(())
         if self.lambda_reg > 0:
             reg = self.lambda_reg * (
-                (torch.log(scales + eps) ** 2).mean() +
-                (torch.log(shapes + eps) ** 2).mean()
+                (torch.log(scales + self.eps) ** 2).mean() +
+                (torch.log(shapes + self.eps) ** 2).mean()
             )
         return nll, reg
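Note: a minimal usage sketch of the Weibull loss, mirroring the new forward signature shown in the diff (the import path is assumed):

```python
import torch
from losses import WeibullNLLLoss  # assumed import path

M, K = 8, 3
loss_fn = WeibullNLLLoss(eps=1e-6, lambda_reg=1e-4)
logits = torch.randn(M, K, 2, requires_grad=True)  # [..., 0]=shape, [..., 1]=scale
target_events = torch.randint(0, K, (M,))
dt = torch.rand(M) * 10.0

nll, reg = loss_fn(logits, target_events, dt, reduction="mean")
(nll + reg).backward()
```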