# DeepHealth/losses.py

import math
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================
def get_valid_pairs_and_dt(
event_seqs: torch.Tensor,
time_seqs: torch.Tensor,
n_tech_tokens: int
) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Extract valid event pairs (prev -> next) and compute dt in years.
Args:
event_seqs (torch.Tensor): Event sequences.
time_seqs (torch.Tensor): Time sequences.
n_tech_tokens (int): Number of technical tokens.
Returns:
Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
(dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None.
Notes:
- Assumes strict right-padding.
- Filters to next events that are disease tokens: token_id >= n_tech_tokens.
- Filters to strictly positive dt.
"""
real_mask = event_seqs >= 1
idx = real_mask.nonzero(as_tuple=False)
if idx.size(0) <= 1:
return None
same_batch = idx[1:, 0] == idx[:-1, 0]
if not same_batch.any():
return None
prev_idx = idx[:-1][same_batch]
next_idx = idx[1:][same_batch]
b_next, t_next = next_idx[:, 0], next_idx[:, 1]
valid_target = event_seqs[b_next, t_next] >= n_tech_tokens
if not valid_target.any():
return None
prev_idx = prev_idx[valid_target]
next_idx = next_idx[valid_target]
b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
b_next, t_next = next_idx[:, 0], next_idx[:, 1]
    # Time stamps are in days; convert the gap to years.
    dt = (time_seqs[b_next, t_next] -
          time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25
valid_dt = dt > 0
if not valid_dt.any():
return None
dt = dt[valid_dt]
b_prev = b_prev[valid_dt]
t_prev = t_prev[valid_dt]
b_next = b_next[valid_dt]
t_next = t_next[valid_dt]
return dt, b_prev, t_prev, b_next, t_next
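
# Example usage of get_valid_pairs_and_dt (illustrative sketch; the tensors
# below are made-up toy data, not from a real dataset). With n_tech_tokens=4,
# only transitions into token ids >= 4 with strictly positive dt survive:
#
#   event_seqs = torch.tensor([[2, 5, 7, 0],
#                              [3, 6, 0, 0]])
#   time_seqs = torch.tensor([[0., 30., 400., 0.],
#                             [10., 10., 0., 0.]])
#   out = get_valid_pairs_and_dt(event_seqs, time_seqs, n_tech_tokens=4)
#   # Batch 0 keeps (2 -> 5) and (5 -> 7); batch 1 drops (3 -> 6) since dt == 0.
#   dt, b_prev, t_prev, b_next, t_next = out  # dt ~ [30/365.25, 370/365.25]
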
# ============================================================
# Losses (shared interface): loss_fn(logits, target_events, dt) -> (nll, regularization)
# ============================================================
class ExponentialNLLLoss(nn.Module):
"""
Competing risks exponential likelihood.
The negative log-likelihood is given by:
.. math::
\\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt
Args:
eps (float): Small epsilon for numerical stability.
"""
def __init__(
self,
lambda_reg: float = 0.0,
eps: float = 1e-6,
):
super().__init__()
self.eps = eps
self.lambda_reg = lambda_reg
def forward(
self,
logits: torch.Tensor,
target_events: torch.Tensor,
dt: torch.Tensor,
reduction: str = "mean",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass.
Args:
logits (torch.Tensor): (M, K) tensor of logits.
target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
reduction (str): 'mean', 'sum', or 'none'.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization), where
                regularization is ``lambda_reg`` times a cross-entropy term over
                the normalized hazards (zero when ``lambda_reg == 0``).
"""
        # Accept (M, K, 1) logits by squeezing the trailing singleton dim.
        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
hazards = F.softplus(logits) + self.eps # (M,K)
hazard_event = hazards.gather(
1, target_events.unsqueeze(1)).squeeze(1) # (M,)
total_hazard = hazards.sum(dim=1) # (M,)
log_hazards = torch.log(hazards) # (M, K)
nll = -torch.log(hazard_event) + total_hazard * dt
if reduction == "mean":
nll = nll.mean()
elif reduction == "sum":
nll = nll.sum()
reg = F.cross_entropy(log_hazards, target_events,
reduction="mean") * self.lambda_reg
return nll, reg
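
# Minimal usage sketch for ExponentialNLLLoss (shapes below are assumptions
# for illustration): one hazard logit per event type.
#
#   loss_fn = ExponentialNLLLoss(lambda_reg=0.01)
#   logits = torch.randn(8, 16)             # (M, K)
#   targets = torch.randint(0, 16, (8,))    # (M,) event type indices
#   dt = torch.rand(8) + 0.1                # (M,) positive intervals in years
#   nll, reg = loss_fn(logits, targets, dt)
#   loss = nll + reg
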
class PiecewiseExponentialLoss(nn.Module):
"""
Piecewise-constant competing risks exponential likelihood.
Lightweight numerical protections:
- Does NOT mask/skip any samples.
- Uses nan_to_num for dt/logits/targets to avoid NaN/Inf propagation.
- Clamps logits and dt to keep softplus/log operations finite.
"""
def __init__(
self,
bin_edges: Sequence[float],
eps: float = 1e-6,
lambda_reg: float = 0.0,
logit_clip: float = 30.0,
):
super().__init__()
if len(bin_edges) < 2:
raise ValueError("bin_edges must have length >= 2")
if bin_edges[0] != 0:
raise ValueError("bin_edges must start at 0")
for i in range(1, len(bin_edges)):
if not (bin_edges[i] > bin_edges[i - 1]):
raise ValueError("bin_edges must be strictly increasing")
self.eps = float(eps)
self.lambda_reg = float(lambda_reg)
self.logit_clip = float(logit_clip)
edges = torch.tensor(list(bin_edges), dtype=torch.float32)
self.register_buffer("bin_edges", edges, persistent=False)
def forward(
self,
logits: torch.Tensor,
target_events: torch.Tensor,
dt: torch.Tensor,
reduction: str = "mean",
) -> Tuple[torch.Tensor, torch.Tensor]:
if logits.dim() != 3:
raise ValueError("logits must have shape (M, K, B)")
M, K, B = logits.shape
if self.bin_edges.numel() != B + 1:
raise ValueError(
f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
)
device = logits.device
dt = dt.to(device=device, dtype=torch.float32)
target_events = target_events.to(device=device)
# No masking/skipping: coerce invalid values to safe defaults.
logits_v = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
logits_v = torch.clamp(
logits_v, min=-self.logit_clip, max=self.logit_clip)
dt_v = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
target_v = torch.nan_to_num(
target_events, nan=0.0, posinf=0.0, neginf=0.0)
target_v = target_v.to(dtype=torch.long)
target_v = torch.clamp(target_v, min=0, max=K - 1)
# Keep structural clamping to prevent index-out-of-bounds errors
# (Necessary for searchsorted/gather to work at all)
eps = self.eps
max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype)
dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))
dt_v = torch.clamp(dt_v, min=eps, max=dt_max)
# Hazards: (M, K, B)
hazards = F.softplus(logits_v) + eps
hazards = torch.clamp(hazards, min=eps)
total_hazard = hazards.sum(dim=1) # (M, B)
edges = self.bin_edges.to(device=device, dtype=dt_v.dtype)
widths = edges[1:] - edges[:-1] # (B,)
# Bin index b* in [0, B-1].
b_star = torch.searchsorted(edges[1:], dt_v, right=False) # (M,)
b_star = torch.clamp(b_star, min=0, max=B - 1)
        # 1. Hazard at the observed event in its time bin: (M,)
        # Advanced indexing selects hazards[m, target_v[m], b_star[m]] per
        # sample, avoiding two chained gathers over (M, K, B) -> (M, B) -> (M,).
ar = torch.arange(M, device=device)
hazard_event = hazards[ar, target_v, b_star] # (M,)
hazard_event = torch.clamp(hazard_event, min=eps)
# 2. Integral part
# Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
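        # Worked example (illustrative numbers): with edges = [0, 1, 2, 5]
        # (B = 3) and dt = 2.7, searchsorted yields b_star = 2, so
        #   integral = total_hazard[:, 0] * 1.0
        #            + total_hazard[:, 1] * 1.0
        #            + total_hazard[:, 2] * (2.7 - 2.0)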
# Full bins accumulation
weighted = total_hazard * widths.unsqueeze(0) # (M, B)
cum = weighted.cumsum(dim=1) # (M, B)
full_bins_int = torch.zeros_like(dt_v)
        # Samples with b_star == 0 have no fully covered bins, and gather on
        # index -1 would wrap around. Gather with a safe index (0) for those
        # rows, then mask the result out via torch.where.
        has_full = b_star > 0
safe_indices = (b_star - 1).clamp(min=0)
gathered_cum = cum.gather(1, safe_indices.unsqueeze(1)).squeeze(1)
full_bins_int = torch.where(has_full, gathered_cum, full_bins_int)
# Partial bin accumulation
edge_left = edges[b_star] # (M,)
partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1)
partial = partial_hazard * (dt_v - edge_left)
integral = full_bins_int + partial
# Final NLL
nll = -torch.log(hazard_event) + integral
# Reduction
if reduction == "none":
nll_out = nll
elif reduction == "sum":
nll_out = nll.sum()
elif reduction == "mean":
nll_out = nll.mean()
else:
raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
reg = logits.new_zeros(())
if self.lambda_reg != 0.0:
reg = reg + (self.lambda_reg * logits_v.pow(2).mean())
return nll_out, reg
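
# Minimal usage sketch for PiecewiseExponentialLoss (bin edges and shapes are
# assumptions for illustration): B = 3 bins covering (0, 5] years.
#
#   loss_fn = PiecewiseExponentialLoss(bin_edges=[0.0, 1.0, 2.0, 5.0])
#   logits = torch.randn(8, 16, 3)          # (M, K, B)
#   targets = torch.randint(0, 16, (8,))    # (M,)
#   dt = torch.rand(8) * 4.9 + 0.05         # within (0, 5)
#   nll, reg = loss_fn(logits, targets, dt)
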
class WeibullNLLLoss(nn.Module):
"""
Weibull hazard in t with lightweight numerical protections.
Does NOT mask/skip any samples. Instead:
- nan_to_num for logits/dt/targets
- clamps logits to keep softplus outputs reasonable
- computes t^shape in log-space with clamped exponent to prevent overflow
"""
def __init__(
self,
eps: float = 1e-6,
lambda_reg: float = 0.0,
logit_clip: float = 30.0,
max_shape: float = 30.0,
max_dt: float = 1.0e3,
max_exp: float = 80.0,
):
super().__init__()
self.eps = eps
self.lambda_reg = lambda_reg
self.logit_clip = float(logit_clip)
self.max_shape = float(max_shape)
self.max_dt = float(max_dt)
self.max_exp = float(max_exp)
    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
if logits.dim() != 3 or logits.size(-1) != 2:
raise ValueError("logits must have shape (M, K, 2)")
M, K, _ = logits.shape
device = logits.device
logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
logits = torch.clamp(logits, min=-self.logit_clip, max=self.logit_clip)
dt = dt.to(device=device, dtype=torch.float32)
dt = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
dt = torch.clamp(dt, min=self.eps, max=self.max_dt)
target_events = target_events.to(device=device)
target_events = torch.nan_to_num(
target_events, nan=0.0, posinf=0.0, neginf=0.0)
target_events = target_events.to(dtype=torch.long)
target_events = torch.clamp(target_events, min=0, max=K - 1)
        # Positive Weibull parameters: channel 0 -> shape, channel 1 -> scale.
        shapes = F.softplus(logits[..., 0]) + self.eps
        scales = F.softplus(logits[..., 1]) + self.eps
shapes = torch.clamp(shapes, min=self.eps, max=self.max_shape)
scales = torch.clamp(scales, min=self.eps)
t_mat = dt.unsqueeze(1) # (M,1)
log_t = torch.log(torch.clamp(t_mat, min=self.eps))
# Compute t^shape and t^(shape-1) in log-space with exponent clamp.
pow_shape = torch.exp(torch.clamp(shapes * log_t, max=self.max_exp))
pow_shape_minus_1 = torch.exp(
torch.clamp((shapes - 1.0) * log_t, max=self.max_exp)
)
cum_hazard = scales * pow_shape
hazard = shapes * scales * pow_shape_minus_1
hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
hazard_event = torch.clamp(hazard_event, min=self.eps)
nll = -torch.log(hazard_event) + cum_hazard.sum(dim=1)
if reduction == "mean":
nll = nll.mean()
elif reduction == "sum":
nll = nll.sum()
elif reduction != "none":
raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
reg = shapes.new_zeros(())
if self.lambda_reg > 0:
reg = self.lambda_reg * (
(torch.log(scales + self.eps) ** 2).mean() +
(torch.log(shapes + self.eps) ** 2).mean()
)
return nll, reg
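
# Minimal usage sketch for WeibullNLLLoss (shapes are assumptions for
# illustration): the last dim packs (shape, scale) logits per event type.
#
#   loss_fn = WeibullNLLLoss(lambda_reg=0.01)
#   logits = torch.randn(8, 16, 2)          # (M, K, 2)
#   targets = torch.randint(0, 16, (8,))    # (M,)
#   dt = torch.rand(8) + 0.1                # (M,) positive intervals in years
#   nll, reg = loss_fn(logits, targets, dt)
#   loss = nll + reg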