# DeepHealth/losses.py

import math
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================
def get_valid_pairs_and_dt(
    event_seqs: torch.Tensor,
    time_seqs: torch.Tensor,
    n_tech_tokens: int,
) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Extract valid event pairs (prev -> next) and compute dt in years.

    Args:
        event_seqs (torch.Tensor): (B, T) int tensor of event token ids; 0 is padding.
        time_seqs (torch.Tensor): (B, T) tensor of event times in days.
        n_tech_tokens (int): Number of technical (non-disease) tokens.

    Returns:
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
            (dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None.

    Notes:
        - Assumes strict right-padding.
        - Filters to next events that are disease tokens: token_id >= n_tech_tokens.
        - Filters to strictly positive dt.
    """
    real_mask = event_seqs >= 1
    idx = real_mask.nonzero(as_tuple=False)  # (N, 2) rows of (batch, time)
    if idx.size(0) <= 1:
        return None
    # Consecutive rows of idx that share a batch index form (prev, next) pairs,
    # since nonzero() returns indices in row-major order.
    same_batch = idx[1:, 0] == idx[:-1, 0]
    if not same_batch.any():
        return None
    prev_idx = idx[:-1][same_batch]
    next_idx = idx[1:][same_batch]
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]
    # Keep only pairs whose next event is a disease token.
    valid_target = event_seqs[b_next, t_next] >= n_tech_tokens
    if not valid_target.any():
        return None
    prev_idx = prev_idx[valid_target]
    next_idx = next_idx[valid_target]
    b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]
    # Convert the day-valued gap to years.
    dt = (time_seqs[b_next, t_next] -
          time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25
    valid_dt = dt > 0
    if not valid_dt.any():
        return None
    dt = dt[valid_dt]
    b_prev = b_prev[valid_dt]
    t_prev = t_prev[valid_dt]
    b_next = b_next[valid_dt]
    t_next = t_next[valid_dt]
    return dt, b_prev, t_prev, b_next, t_next
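
# A minimal usage sketch for the pair extractor (illustrative only; the toy
# vocabulary below, with token 0 as padding, token 1 technical, and disease
# tokens >= 2, i.e. n_tech_tokens == 2, is an assumption, not part of the
# library):
#
#   events = torch.tensor([[1, 3, 2, 0],
#                          [2, 4, 0, 0]])
#   times = torch.tensor([[0., 30., 400., 0.],
#                         [0., 730., 0., 0.]])
#   pairs = get_valid_pairs_and_dt(events, times, n_tech_tokens=2)
#   if pairs is not None:
#       dt, b_prev, t_prev, b_next, t_next = pairs
#       # dt is [30, 370, 730] days converted to years: ~[0.082, 1.013, 1.999]
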
# ============================================================
# Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization)
# ============================================================
class ExponentialNLLLoss(nn.Module):
    """
    Competing risks exponential likelihood.

    The negative log-likelihood is given by:

    .. math::
        \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt

    Args:
        lambda_reg (float): Weight of the cross-entropy regularization term.
        eps (float): Small epsilon for numerical stability.
    """

    def __init__(
        self,
        lambda_reg: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K) or (M, K, 1) tensor of logits.
            target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
            dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
            reduction (str): 'mean', 'sum', or 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization), where the
            regularization is a cross-entropy term over log-hazards scaled by
            lambda_reg (zero when lambda_reg == 0).
        """
        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
        hazards = F.softplus(logits) + self.eps  # (M, K)
        hazard_event = hazards.gather(
            1, target_events.unsqueeze(1)).squeeze(1)  # (M,)
        total_hazard = hazards.sum(dim=1)  # (M,)
        nll = -torch.log(hazard_event) + total_hazard * dt
        if reduction == "mean":
            nll = nll.mean()
        elif reduction == "sum":
            nll = nll.sum()
        reg = logits.new_zeros(())
        if self.lambda_reg > 0:
            # Treat log-hazards as class logits: encourages the observed
            # event to carry the largest hazard.
            reg = self.lambda_reg * F.cross_entropy(
                torch.log(hazards), target_events, reduction="mean")
        return nll, reg
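
# A minimal usage sketch (illustrative; M = 3 pairs and K = 5 competing
# causes are assumptions). The dt vector would typically come from
# get_valid_pairs_and_dt above:
#
#   loss_fn = ExponentialNLLLoss(lambda_reg=0.1)
#   logits = torch.randn(3, 5)                  # (M, K)
#   targets = torch.tensor([0, 2, 4])           # observed cause per pair
#   dt = torch.tensor([0.5, 1.0, 2.3])          # years, strictly positive
#   nll, reg = loss_fn(logits, targets, dt)
#   loss = nll + reg                            # scalar, ready for backward()
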
class WeibullNLLLoss(nn.Module):
    """
    Weibull hazard in t.

    .. math::
        \\Lambda_k(t) = \\text{scale}_k \\cdot t^{\\text{shape}_k}

        \\lambda_k(t) = \\text{shape}_k \\cdot \\text{scale}_k \\cdot t^{\\text{shape}_k - 1}

    Args:
        eps (float): Small epsilon for numerical stability.
        lambda_reg (float): Weight of the log-parameter regularization term.
    """

    def __init__(
        self,
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K, 2) tensor of logits; the last axis
                holds the (shape, scale) pre-activations.
            target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
            dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
            reduction (str): 'mean', 'sum', or 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization).
        """
        shapes = F.softplus(logits[..., 0]) + self.eps  # (M, K)
        scales = F.softplus(logits[..., 1]) + self.eps  # (M, K)
        eps = self.eps
        t = torch.clamp(dt, min=eps)
        t_mat = t.unsqueeze(1)  # (M, 1)
        # Cumulative hazard Lambda_k(t) = scale_k * t^shape_k, shape (M, K).
        cum_hazard = scales * t_mat.pow(shapes)
        # Hazard lambda_k(t) = shape_k * scale_k * t^(shape_k - 1), shape (M, K).
        hazard = shapes * scales * t_mat.pow(shapes - 1.0)
        hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
        # Point-event likelihood: f_k(t) = lambda_k(t) * exp(-Lambda_total(t)),
        # so NLL = -log lambda_{k*}(t) + Lambda_total(t).
        nll = -torch.log(hazard_event + eps) + cum_hazard.sum(dim=1)
        if reduction == "mean":
            nll = nll.mean()
        elif reduction == "sum":
            nll = nll.sum()
        reg = shapes.new_zeros(())
        if self.lambda_reg > 0:
            # Penalize squared log-shape and log-scale, pulling both toward 1.
            reg = self.lambda_reg * (
                (torch.log(scales + eps) ** 2).mean() +
                (torch.log(shapes + eps) ** 2).mean()
            )
        return nll, reg
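
# Smoke test / usage sketch (illustrative only; the shapes, seed, and K = 5
# cause vocabulary below are assumptions, not part of the library API).
if __name__ == "__main__":
    torch.manual_seed(0)
    M, K = 8, 5
    targets = torch.randint(0, K, (M,))
    dt = torch.rand(M) + 0.1  # strictly positive intervals in years

    exp_loss = ExponentialNLLLoss(lambda_reg=0.1)
    nll, reg = exp_loss(torch.randn(M, K), targets, dt)
    print(f"exponential: nll={nll.item():.4f} reg={reg.item():.4f}")

    wei_loss = WeibullNLLLoss(lambda_reg=0.1)
    nll, reg = wei_loss(torch.randn(M, K, 2), targets, dt)
    print(f"weibull:     nll={nll.item():.4f} reg={reg.item():.4f}")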