Refactor loss functions and model architecture: replace PiecewiseExponentialLoss with DiscreteTimeCIFNLLLoss, update Trainer to use SimpleHead, and modify argument parsing for new loss type.
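The Trainer, SimpleHead, and argument-parsing changes named above are not part of the losses.py diff shown below. Purely as a hedged sketch of the head/loss contract implied by the new loss (the class name suffix, layer, and sizes here are assumptions, not code from this commit), a minimal head producing the (M, K+1, n_bins+1) logits that DiscreteTimeCIFNLLLoss expects could look like this:

import torch
import torch.nn as nn


class SimpleHeadSketch(nn.Module):
    """Hypothetical stand-in for the SimpleHead mentioned in the commit message.

    It only illustrates the output contract of the new loss below: logits of
    shape (M, K + 1, n_bins + 1), where channel K is the complement (survival)
    channel.
    """

    def __init__(self, d_model: int, n_causes: int, n_bins: int):
        super().__init__()
        self.n_causes = n_causes
        self.n_bins = n_bins
        self.proj = nn.Linear(d_model, (n_causes + 1) * (n_bins + 1))

    def forward(self, h: torch.Tensor) -> torch.Tensor:  # h: (M, d_model)
        out = self.proj(h)  # (M, (K+1) * (n_bins+1))
        return out.view(-1, self.n_causes + 1, self.n_bins + 1)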
losses.py
@@ -132,9 +132,19 @@ class ExponentialNLLLoss(nn.Module):
         return nll, reg


-class PiecewiseExponentialLoss(nn.Module):
-    """
-    Piecewise-constant competing risks exponential likelihood.
+class DiscreteTimeCIFNLLLoss(nn.Module):
+    """Direct discrete-time CIF negative log-likelihood (no censoring).
+
+    This loss assumes the model outputs per-bin logits over (K causes + 1 complement)
+    channels, where the complement channel (index K) represents survival across bins.
+
+    Per-sample likelihood for observed cause k at time bin j:
+        p = \prod_{u=1}^{j-1} p(comp at u) * p(k at j)
+
+    Args:
+        bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0.
+        eps: Unused; kept for interface compatibility / future numerical tweaks.
+        lambda_reg: Optional regularization strength.
     """

     def __init__(
@@ -146,18 +156,20 @@ class PiecewiseExponentialLoss(nn.Module):
         super().__init__()

         if len(bin_edges) < 2:
-            raise ValueError("bin_edges must have length >= 2")
-        if bin_edges[0] != 0:
-            raise ValueError("bin_edges must start at 0")
+            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
+        if float(bin_edges[0]) != 0.0:
+            raise ValueError("bin_edges[0] must equal 0")
         for i in range(1, len(bin_edges)):
-            if not (bin_edges[i] > bin_edges[i - 1]):
+            if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
                 raise ValueError("bin_edges must be strictly increasing")

         self.eps = float(eps)
         self.lambda_reg = float(lambda_reg)

-        edges = torch.tensor(list(bin_edges), dtype=torch.float32)
-        self.register_buffer("bin_edges", edges, persistent=False)
+        self.register_buffer(
+            "bin_edges",
+            torch.tensor(bin_edges, dtype=torch.float32),
+            persistent=False,
+        )

     def forward(
         self,
@@ -166,145 +178,83 @@ class PiecewiseExponentialLoss(nn.Module):
         dt: torch.Tensor,
         reduction: str = "mean",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if logits.dim() != 3:
-            raise ValueError("logits must have shape (M, K, B)")
-
-        M, K, B = logits.shape
-        if self.bin_edges.numel() != B + 1:
+        if logits.ndim != 3:
             raise ValueError(
-                f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
+                f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
             )
+        if target_events.ndim != 1 or dt.ndim != 1:
+            raise ValueError(
+                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
+            )
+        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
+            raise ValueError(
+                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
+            )
+        if reduction not in {"mean", "sum", "none"}:
+            raise ValueError("reduction must be one of {'mean','sum','none'}")

-        device = logits.device
-        dt = dt.to(device=device, dtype=torch.float32)
-        target_events = target_events.to(device=device)
+        if not torch.all(dt > 0):
+            raise ValueError("dt must be strictly positive")

+        # Infer K and n_bins from logits and bin_edges.
+        m, k_plus_1, n_bins_plus_1 = logits.shape
+        k_comp = k_plus_1 - 1
+        if k_comp < 1:
+            raise ValueError(
+                "logits.shape[1] must be at least 2 (K>=1 plus complement channel)")

+        n_bins = int(self.bin_edges.numel() - 1)
+        if n_bins_plus_1 != n_bins + 1:
+            raise ValueError(
+                f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
+            )

-        if target_events.dtype != torch.long:
-            target_events = target_events.to(dtype=torch.long)
-        if target_events.min().item() < 0 or target_events.max().item() >= K:
-            raise ValueError("target_events must be in [0, K)")
+        target_events = target_events.to(torch.long)

-        # Hazards: (M, K, B)
-        hazards = F.softplus(logits) + self.eps
-        total_hazard = hazards.sum(dim=1)  # (M, B)
+        if (target_events < 0).any() or (target_events >= k_comp).any():
+            raise ValueError(
+                f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}"
+            )

-        edges = self.bin_edges.to(device=device, dtype=dt.dtype)
-        widths = edges[1:] - edges[:-1]  # (B,)
+        # Map continuous dt to discrete bins j in {1..n_bins}.
+        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
+        # (M,), may be n_bins+1 if dt > bin_edges[-1]
+        time_bin = torch.bucketize(dt, bin_edges)
+        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(
+            torch.long)  # ensure valid event bins

-        if dt.min().item() <= 0:
-            raise ValueError("dt must be strictly positive")
-        if dt.max().item() > edges[-1].item():
-            raise ValueError("dt must be <= last bin edge")
+        # Log-probabilities across causes+complement for each bin.
+        logp = F.log_softmax(logits, dim=1)  # (M, K+1, n_bins+1)

-        # Bin index b* in [0, B-1].
-        b_star = torch.searchsorted(edges[1:], dt, right=False)  # (M,)
+        # Previous survival term: sum_{u=1}^{j-1} -log p(comp at u)
+        bins = torch.arange(n_bins + 1, device=logits.device)  # (n_bins+1,)
+        mask = (bins.unsqueeze(0) >= 1) & (bins.unsqueeze(
+            0) < time_bin.unsqueeze(1))  # (M, n_bins+1)
+        logp_comp = logp[:, k_comp, :]  # (M, n_bins+1)
+        loss_prev = -(logp_comp * mask.to(logp_comp.dtype)).sum(dim=1)  # (M,)

-        # 1. Hazard at event (M,)
-        # gather needs matching dims.
-        # hazards: (M, K, B) -> select target_event -> (M, B) -> select b_star -> (M,)
-        # Alternative: hazards[m, k, b]
-        ar = torch.arange(M, device=device)
-        hazard_event = hazards[ar, target_events, b_star]  # (M,)
-        hazard_event = torch.clamp(hazard_event, min=self.eps)
+        # Event term at bin j: -log p(k at j)
+        m_idx = torch.arange(m, device=logits.device)
+        loss_event = -logp[m_idx, target_events, time_bin]  # (M,)

-        # 2. Integral part
-        # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
-
-        # Full bins accumulation
-        weighted = total_hazard * widths.unsqueeze(0)  # (M, B)
-        cum = weighted.cumsum(dim=1)  # (M, B)
-
-        full_bins_int = torch.zeros_like(dt)
-
-        # We process 'has_full' logic generally.
-        # If b_star is 0, gather on index -1 would fail or wrap, so we mask carefully or use conditional
-        has_full = b_star > 0
-
-        # NOTE: Even without protection, we need valid indices for gather.
-        # We use a temporary index that is safe (0) for the 'False' cases, then mask the result.
-        safe_indices = (b_star - 1).clamp(min=0)
-        gathered_cum = cum.gather(1, safe_indices.unsqueeze(1)).squeeze(1)
-        full_bins_int = torch.where(has_full, gathered_cum, full_bins_int)
-
-        # Partial bin accumulation
-        edge_left = edges[b_star]  # (M,)
-        partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1)
-        partial = partial_hazard * (dt - edge_left)
-
-        integral = full_bins_int + partial
-
-        # Final NLL
-        nll = -torch.log(hazard_event) + integral
-
-        # Reduction
-        if reduction == "none":
-            nll_out = nll
-        elif reduction == "sum":
-            nll_out = nll.sum()
-        elif reduction == "mean":
-            nll_out = nll.mean()
-        else:
-            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
-
-        reg = logits.new_zeros(())
-        if self.lambda_reg != 0.0:
-            reg = reg + (self.lambda_reg * logits.pow(2).mean())
-
-        return nll_out, reg
-
-
-class WeibullNLLLoss(nn.Module):
-    """
-    Weibull hazard in t.
-    """
-
-    def __init__(
-        self,
-        eps: float = 1e-6,
-        lambda_reg: float = 0.0,
-    ):
-        super().__init__()
-        self.eps = eps
-        self.lambda_reg = lambda_reg
-
-    def forward(self, logits, target_events, dt, reduction="mean"):
-        if logits.dim() != 3 or logits.size(-1) != 2:
-            raise ValueError("logits must have shape (M, K, 2)")
-
-        M, K, _ = logits.shape
-        device = logits.device
-
-        dt = dt.to(device=device, dtype=torch.float32)
-        if dt.min().item() <= 0:
-            raise ValueError("dt must be strictly positive")
-
-        target_events = target_events.to(device=device)
-        target_events = target_events.to(dtype=torch.long)
-        if target_events.min().item() < 0 or target_events.max().item() >= K:
-            raise ValueError("target_events must be in [0, K)")
-
-        shapes = F.softplus(logits[..., 0]) + self.eps
-        scales = F.softplus(logits[..., 1]) + self.eps
-
-        t_mat = dt.unsqueeze(1)  # (M,1)
-        cum_hazard = scales * torch.pow(t_mat, shapes)
-        hazard = shapes * scales * torch.pow(t_mat, shapes - 1.0)
-        hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
-        hazard_event = torch.clamp(hazard_event, min=self.eps)
-
-        nll = -torch.log(hazard_event) + cum_hazard.sum(dim=1)
+        loss = loss_prev + loss_event

         if reduction == "mean":
-            nll = nll.mean()
+            nll = loss.mean()
         elif reduction == "sum":
-            nll = nll.sum()
-        elif reduction != "none":
-            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
+            nll = loss.sum()
+        else:
+            nll = loss

-        reg = shapes.new_zeros(())
-        if self.lambda_reg > 0:
-            reg = self.lambda_reg * (
-                (torch.log(scales + self.eps) ** 2).mean() +
-                (torch.log(shapes + self.eps) ** 2).mean()
-            )
+        reg = torch.zeros((), device=logits.device, dtype=loss.dtype)
+        if self.lambda_reg > 0.0:
+            # Regularize the cause distribution at the event bin using NLL on log-probs.
+            logp_causes = logp[:, :k_comp, :]  # (M, K, n_bins+1)
+            idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1)
+            logp_at_event_bin = logp_causes.gather(
+                dim=2, index=idx).squeeze(2)  # (M, K)
+            reg = self.lambda_reg * \
+                F.nll_loss(logp_at_event_bin, target_events, reduction="mean")

         return nll, reg
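For readers skimming the diff, here is a minimal usage sketch of the new loss. The import path and all tensor values are illustrative assumptions; only the argument names, the expected shapes, and the (nll, reg) return signature come from the diff above.

import torch

# Assumption: the class is importable from the losses.py module changed above.
from losses import DiscreteTimeCIFNLLLoss

# 2 competing causes (K=2) and 3 time bins defined by 4 edges starting at 0.
loss_fn = DiscreteTimeCIFNLLLoss(bin_edges=[0.0, 1.0, 2.0, 4.0], lambda_reg=0.1)

M, K, n_bins = 5, 2, 3
# Per-bin logits over K causes plus one complement channel: shape (M, K+1, n_bins+1).
logits = torch.randn(M, K + 1, n_bins + 1, requires_grad=True)
target_events = torch.randint(0, K, (M,))  # observed cause index in [0, K-1]
dt = torch.rand(M) * 3.5 + 0.1             # strictly positive event times

nll, reg = loss_fn(logits, target_events, dt, reduction="mean")
(nll + reg).backward()  # both the NLL and the regularizer are differentiable w.r.t. the logits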