# DeepHealth/losses.py
import math
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================
def get_valid_pairs_and_dt(
event_seqs: torch.Tensor,
time_seqs: torch.Tensor,
n_tech_tokens: int
) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Extract valid event pairs (prev -> next) and compute dt in years.
Args:
event_seqs (torch.Tensor): Event sequences.
time_seqs (torch.Tensor): Time sequences.
n_tech_tokens (int): Number of technical tokens.
Returns:
Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
(dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None.
Notes:
- Assumes strict right-padding.
- Filters to next events that are disease tokens: token_id >= n_tech_tokens.
- Filters to strictly positive dt.
"""
    # Real (non-padding) tokens; padding id is 0.
    real_mask = event_seqs >= 1
    idx = real_mask.nonzero(as_tuple=False)  # (N, 2) rows of (batch, time)
    if idx.size(0) <= 1:
        return None
    # Consecutive real tokens form a (prev -> next) pair only within a batch row.
    same_batch = idx[1:, 0] == idx[:-1, 0]
    if not same_batch.any():
        return None
    prev_idx = idx[:-1][same_batch]
    next_idx = idx[1:][same_batch]
b_next, t_next = next_idx[:, 0], next_idx[:, 1]
valid_target = event_seqs[b_next, t_next] >= n_tech_tokens
if not valid_target.any():
return None
prev_idx = prev_idx[valid_target]
next_idx = next_idx[valid_target]
b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
b_next, t_next = next_idx[:, 0], next_idx[:, 1]
dt = (time_seqs[b_next, t_next] -
time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25
valid_dt = dt > 0
if not valid_dt.any():
return None
dt = dt[valid_dt]
b_prev = b_prev[valid_dt]
t_prev = t_prev[valid_dt]
b_next = b_next[valid_dt]
t_next = t_next[valid_dt]
return dt, b_prev, t_prev, b_next, t_next
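

# --- Illustrative usage (not part of the original module) -----------------
# A minimal sketch of get_valid_pairs_and_dt on a toy right-padded batch.
# The token ids, times, and n_tech_tokens below are arbitrary assumptions
# chosen only to exercise the function.
def _example_pair_extraction() -> None:
    # Batch of 2 sequences; 0 is padding, ids >= 5 count as disease tokens here.
    event_seqs = torch.tensor([[3, 6, 7, 0],
                               [2, 5, 0, 0]])
    # Times in days since an arbitrary reference point.
    time_seqs = torch.tensor([[0.0, 100.0, 500.0, 0.0],
                              [0.0, 730.5, 0.0, 0.0]])
    out = get_valid_pairs_and_dt(event_seqs, time_seqs, n_tech_tokens=5)
    if out is not None:
        dt, b_prev, t_prev, b_next, t_next = out
        print(dt)  # gaps in years: tensor([0.2738, 1.0951, 2.0000])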
# ============================================================
# Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization)
# ============================================================
class ExponentialNLLLoss(nn.Module):
"""
Competing risks exponential likelihood.
The negative log-likelihood is given by:
.. math::
\\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt
Args:
eps (float): Small epsilon for numerical stability.
"""
def __init__(
self,
lambda_reg: float = 0.0,
eps: float = 1e-6,
):
super().__init__()
self.eps = eps
self.lambda_reg = lambda_reg
def forward(
self,
logits: torch.Tensor,
target_events: torch.Tensor,
dt: torch.Tensor,
reduction: str = "mean",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass.
Args:
logits (torch.Tensor): (M, K) tensor of logits.
target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
reduction (str): 'mean', 'sum', or 'none'.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization), where the
            regularization is a cause cross-entropy scaled by lambda_reg
            (zero when lambda_reg == 0).
        """
        # Accept (M, K, 1) by squeezing the trailing singleton dimension.
        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
        hazards = F.softplus(logits) + self.eps  # (M, K)
        hazard_event = hazards.gather(
            1, target_events.unsqueeze(1)).squeeze(1)  # (M,)
        total_hazard = hazards.sum(dim=1)  # (M,)
        nll = -torch.log(hazard_event) + total_hazard * dt
        if reduction == "mean":
            nll = nll.mean()
        elif reduction == "sum":
            nll = nll.sum()
        # softmax over log-hazards recovers the normalized cause distribution
        # lambda_k / sum_j lambda_j, so cross-entropy on log-hazards
        # regularizes the predicted cause ranking.
        if self.lambda_reg > 0.0:
            log_hazards = torch.log(hazards)  # (M, K)
            reg = F.cross_entropy(log_hazards, target_events,
                                  reduction="mean") * self.lambda_reg
        else:
            reg = torch.zeros((), device=logits.device, dtype=hazards.dtype)
        return nll, reg
class DiscreteTimeCIFNLLLoss(nn.Module):
"""Direct discrete-time CIF negative log-likelihood (no censoring).
This loss assumes the model outputs per-bin logits over (K causes + 1 complement)
channels, where the complement channel (index K) represents survival across bins.
Per-sample likelihood for observed cause k at time bin j:
p = \\prod_{u=1}^{j-1} p(comp at u) * p(k at j)
Args:
bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0.
eps: Unused; kept for interface compatibility / future numerical tweaks.
lambda_reg: Optional regularization strength.
"""
def __init__(
self,
bin_edges: Sequence[float],
eps: float = 1e-6,
lambda_reg: float = 0.0,
):
super().__init__()
if len(bin_edges) < 2:
raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
if float(bin_edges[0]) != 0.0:
raise ValueError("bin_edges[0] must equal 0")
for i in range(1, len(bin_edges)):
if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
raise ValueError("bin_edges must be strictly increasing")
self.eps = float(eps)
self.lambda_reg = float(lambda_reg)
self.register_buffer(
"bin_edges",
torch.tensor(bin_edges, dtype=torch.float32),
persistent=False,
)
def forward(
self,
logits: torch.Tensor,
target_events: torch.Tensor,
dt: torch.Tensor,
reduction: str = "mean",
) -> Tuple[torch.Tensor, torch.Tensor]:
if logits.ndim != 3:
raise ValueError(
f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
)
if target_events.ndim != 1 or dt.ndim != 1:
raise ValueError(
f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
)
if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
raise ValueError(
"Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
)
if reduction not in {"mean", "sum", "none"}:
raise ValueError("reduction must be one of {'mean','sum','none'}")
if not torch.all(dt > 0):
raise ValueError("dt must be strictly positive")
# Infer K and n_bins from logits and bin_edges.
m, k_plus_1, n_bins_plus_1 = logits.shape
k_comp = k_plus_1 - 1
if k_comp < 1:
raise ValueError(
"logits.shape[1] must be at least 2 (K>=1 plus complement channel)")
n_bins = int(self.bin_edges.numel() - 1)
if n_bins_plus_1 != n_bins + 1:
raise ValueError(
f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
)
if target_events.dtype != torch.long:
target_events = target_events.to(torch.long)
if (target_events < 0).any() or (target_events >= k_comp).any():
raise ValueError(
f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}"
)
        # Map continuous dt to discrete bins j in {1..n_bins}.
        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
        # bucketize returns (M,) indices that may reach n_bins+1 when
        # dt > bin_edges[-1]; clamp to keep event bins valid.
        time_bin = torch.bucketize(dt, bin_edges)
        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(torch.long)
# Log-probabilities across causes+complement for each bin.
logp = F.log_softmax(logits, dim=1) # (M, K+1, n_bins+1)
# Previous survival term: sum_{u=1}^{j-1} -log p(comp at u)
        bins = torch.arange(n_bins + 1, device=logits.device)  # (n_bins+1,)
        mask = (bins.unsqueeze(0) >= 1) & \
            (bins.unsqueeze(0) < time_bin.unsqueeze(1))  # (M, n_bins+1)
        logp_comp = logp[:, k_comp, :]  # (M, n_bins+1)
        loss_prev = -(logp_comp * mask.to(logp_comp.dtype)).sum(dim=1)  # (M,)
# Event term at bin j: -log p(k at j)
m_idx = torch.arange(m, device=logits.device)
loss_event = -logp[m_idx, target_events, time_bin] # (M,)
loss = loss_prev + loss_event
if reduction == "mean":
nll = loss.mean()
elif reduction == "sum":
nll = loss.sum()
else:
nll = loss
reg = torch.zeros((), device=logits.device, dtype=loss.dtype)
if self.lambda_reg > 0.0:
# Regularize the cause distribution at the event bin using NLL on log-probs.
logp_causes = logp[:, :k_comp, :] # (M, K, n_bins+1)
idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1)
logp_at_event_bin = logp_causes.gather(
dim=2, index=idx).squeeze(2) # (M, K)
reg = self.lambda_reg * \
F.nll_loss(logp_at_event_bin, target_events, reduction="mean")
return nll, reg
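

# --- Illustrative usage (not part of the original module) -----------------
# A minimal sketch of DiscreteTimeCIFNLLLoss; the bin edges, K, M, and
# lambda_reg are arbitrary assumptions for demonstration only.
def _example_discrete_time_cif() -> None:
    bin_edges = [0.0, 1.0, 2.0, 3.0]  # yearly bins; n_bins = 3
    k, m = 4, 8
    loss_fn = DiscreteTimeCIFNLLLoss(bin_edges, lambda_reg=0.05)
    # Logits over K causes + 1 complement channel, per bin (incl. bin 0):
    # shape (M, K+1, n_bins+1).
    logits = torch.randn(m, k + 1, len(bin_edges), requires_grad=True)
    target_events = torch.randint(0, k, (m,))
    dt = torch.rand(m) * 3.0 + 0.01  # strictly positive event times
    nll, reg = loss_fn(logits, target_events, dt)
    (nll + reg).backward()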
class PiecewiseExponentialCIFNLLLoss(nn.Module):
"""
Piecewise-Exponential (PWE) cause-specific hazards with discrete-time CIF likelihood.
- No censoring
- No regularization (reg always 0)
- Forward signature matches DiscreteTimeCIFNLLLoss:
forward(logits, target_events, dt, reduction) -> (nll, reg)
Expected shapes:
logits: (M, K, n_bins) # hazard logits per cause per bin
target_events: (M,) long in [0, K-1]
dt: (M,) event times (strictly > 0)
bin_edges:
length n_bins+1, strictly increasing, bin_edges[0]==0,
and MUST be finite at the last edge (no +inf) for PWE.
"""
def __init__(
self,
bin_edges: Sequence[float],
eps: float = 1e-6,
        lambda_reg: float = 0.0,  # weight of the log-hazard smoothness penalty
):
super().__init__()
if len(bin_edges) < 2:
raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
if float(bin_edges[0]) != 0.0:
raise ValueError("bin_edges[0] must equal 0.0")
for i in range(1, len(bin_edges)):
if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
raise ValueError("bin_edges must be strictly increasing")
if math.isinf(float(bin_edges[-1])):
raise ValueError(
"PiecewiseExponentialCIFNLLLoss requires a finite last bin edge (no +inf). "
"Use a finite truncation horizon for PWE."
)
        self.eps = float(eps)
        self.lambda_reg = float(lambda_reg)
self.register_buffer(
"bin_edges",
torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
persistent=False,
)
def forward(
self,
logits: torch.Tensor,
target_events: torch.Tensor,
dt: torch.Tensor,
reduction: str = "mean",
) -> Tuple[torch.Tensor, torch.Tensor]:
if reduction not in {"mean", "sum", "none"}:
raise ValueError("reduction must be one of {'mean','sum','none'}")
if logits.ndim != 3:
raise ValueError(
f"logits must be 3D (M, K, n_bins); got shape={tuple(logits.shape)}")
if target_events.ndim != 1 or dt.ndim != 1:
raise ValueError("target_events and dt must be 1D tensors")
if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
raise ValueError(
"Batch size mismatch among logits, target_events, dt")
if not torch.all(dt > 0):
raise ValueError(
"dt must be strictly positive (no censoring supported here)")
M, K, n_bins = logits.shape
if target_events.dtype != torch.long:
target_events = target_events.to(torch.long)
if (target_events < 0).any() or (target_events >= K).any():
raise ValueError(f"target_events must be in [0, {K-1}]")
# Prepare bin_edges / bin widths
bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
if bin_edges.numel() != n_bins + 1:
raise ValueError(
f"bin_edges length must be n_bins+1={n_bins+1}; got {bin_edges.numel()}"
)
dt_bins = (bin_edges[1:] - bin_edges[:-1]
).to(device=logits.device, dtype=logits.dtype) # (n_bins,)
if not torch.isfinite(dt_bins).all():
raise ValueError("All bin widths must be finite for PWE.")
if not (dt_bins > 0).all():
raise ValueError(
"All bin widths must be strictly positive for PWE.")
        # Map event time -> bin index k* in {1..n_bins}
        # (same clamp-to-[1, n_bins] convention as DiscreteTimeCIFNLLLoss).
        time_bin = torch.bucketize(dt, bin_edges)
        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(torch.long)  # (M,)
k0 = time_bin - 1 # 0..n_bins-1
# Nonnegative hazards per cause per bin
hazards = F.softplus(logits) + self.eps # (M, K, n_bins)
# Integrated hazards H_{j,k} = lambda_{j,k} * Δt_k
H_jk = hazards * dt_bins.view(1, 1, n_bins) # (M, K, n_bins)
H_k = H_jk.sum(dim=1) # (M, n_bins)
# Previous survival term: Σ_{u<k*} H_u
bins = torch.arange(
1, n_bins + 1, device=logits.device).unsqueeze(0) # (1, n_bins)
mask_prev = bins < time_bin.unsqueeze(1) # (M, n_bins)
loss_prev = (H_k * mask_prev.to(H_k.dtype)).sum(dim=1) # (M,)
        # Event term at k*: -log[(1 - exp(-H_total)) * H_cause / H_total],
        # i.e. the probability of an event in bin k* times the cause share.
        m_idx = torch.arange(M, device=logits.device)
        H_event_total = torch.clamp(H_k[m_idx, k0], min=self.eps)  # (M,)
        H_event_cause = torch.clamp(
            H_jk[m_idx, target_events, k0], min=self.eps)  # (M,)
        # log(1 - exp(-H)) computed stably via expm1
        log1mexp = torch.log(-torch.expm1(-H_event_total))  # (M,)
        loss_event = -log1mexp - \
            torch.log(H_event_cause) + torch.log(H_event_total)
loss_vec = loss_prev + loss_event # (M,)
if reduction == "mean":
nll = loss_vec.mean()
elif reduction == "sum":
nll = loss_vec.sum()
else:
nll = loss_vec
        if self.lambda_reg > 0.0 and n_bins >= 3:
            # Second-difference (curvature) penalty on log-hazards across
            # bins, encouraging smooth hazard curves.
            log_h = torch.log(hazards)  # (M, K, n_bins)
            d2 = log_h[:, :, 2:] - 2.0 * log_h[:, :, 1:-1] + \
                log_h[:, :, :-2]  # (M, K, n_bins-2)
            reg = self.lambda_reg * (d2.pow(2).mean())
else:
reg = torch.zeros((), device=logits.device, dtype=loss_vec.dtype)
return nll, reg
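

# --- Illustrative usage (not part of the original module) -----------------
# A minimal sketch of PiecewiseExponentialCIFNLLLoss; the finite horizon,
# K, M, and lambda_reg below are arbitrary assumptions for demonstration.
def _example_piecewise_exponential_cif() -> None:
    bin_edges = [0.0, 0.5, 1.0, 2.0, 5.0]  # finite last edge; n_bins = 4
    k, m = 3, 8
    loss_fn = PiecewiseExponentialCIFNLLLoss(bin_edges, lambda_reg=0.01)
    # Hazard logits per cause per bin: shape (M, K, n_bins).
    logits = torch.randn(m, k, len(bin_edges) - 1, requires_grad=True)
    target_events = torch.randint(0, k, (m,))
    dt = torch.rand(m) * 5.0 + 0.01  # times beyond the horizon are clamped
    nll, reg = loss_fn(logits, target_events, dt)
    (nll + reg).backward()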