From 06a01d2893f01eb80b2a5b72133d372194e0d4ed Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Thu, 8 Jan 2026 12:45:31 +0800
Subject: [PATCH] Add PiecewiseExponentialLoss class and update TrainConfig for new loss type

---
 losses.py | 131 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 train.py  |  23 ++++++----
 2 files changed, 144 insertions(+), 10 deletions(-)

diff --git a/losses.py b/losses.py
index 94daadc..3b54c49 100644
--- a/losses.py
+++ b/losses.py
@@ -132,6 +132,135 @@ class ExponentialNLLLoss(nn.Module):
         return nll, reg
 
 
+
+class PiecewiseExponentialLoss(nn.Module):
+    """Piecewise-constant competing risks exponential likelihood.
+
+    Uses B time bins defined by `bin_edges` (length B+1, strictly increasing, starting at 0).
+    Within each bin b, hazards are constant and parameterized as:
+
+        hazards = softplus(logits) + eps    with logits shape (M, K, B)
+
+    For each sample i, dt is bucketized to bin b* and the NLL is:
+
+        nll_i = -log(hazard_{k*}(b*)) + \int_0^{dt} sum_k hazard_k(u) du
+
+    The integral is computed in closed form by summing full bins plus the partial bin b*.
+    """
+
+    def __init__(
+        self,
+        bin_edges: Sequence[float],
+        eps: float = 1e-6,
+        lambda_reg: float = 0.0,
+    ):
+        super().__init__()
+
+        if len(bin_edges) < 2:
+            raise ValueError("bin_edges must have length >= 2")
+        if bin_edges[0] != 0:
+            raise ValueError("bin_edges must start at 0")
+        for i in range(1, len(bin_edges)):
+            if not (bin_edges[i] > bin_edges[i - 1]):
+                raise ValueError("bin_edges must be strictly increasing")
+
+        self.eps = float(eps)
+        self.lambda_reg = float(lambda_reg)
+
+        edges = torch.tensor(list(bin_edges), dtype=torch.float32)
+        self.register_buffer("bin_edges", edges, persistent=False)
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        target_events: torch.Tensor,
+        dt: torch.Tensor,
+        reduction: str = "mean",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if logits.dim() != 3:
+            raise ValueError("logits must have shape (M, K, B)")
+
+        M, K, B = logits.shape
+        if self.bin_edges.numel() != B + 1:
+            raise ValueError(
+                f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
+            )
+
+        device = logits.device
+        dt = dt.to(device=device)
+        target_events = target_events.to(device=device)
+
+        # Build a per-sample finite mask to avoid NaN/Inf propagation.
+        logits_finite = torch.isfinite(logits).view(M, -1).all(dim=1)
+        dt_finite = torch.isfinite(dt)
+        target_finite = torch.isfinite(target_events)
+        finite_mask = logits_finite & dt_finite & target_finite
+
+        nll_full = logits.new_zeros((M,))
+
+        if not finite_mask.any():
+            nll_out = nll_full if reduction == "none" else logits.new_zeros(())
+            reg_out = logits.new_zeros(())
+            return nll_out, reg_out
+
+        idx = finite_mask.nonzero(as_tuple=False).squeeze(1)
+        logits_v = logits[idx]
+        target_v = target_events[idx].to(torch.long)
+        dt_v = dt[idx].to(torch.float32)
+
+        # Clamp dt into [eps, max_edge) to keep bucket indices valid.
+        eps = self.eps
+        max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype)
+        dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))
+        dt_v = torch.clamp(dt_v, min=eps, max=dt_max)
+
+        hazards = F.softplus(logits_v) + eps  # (Mv, K, B)
+        total_hazard = hazards.sum(dim=1)  # (Mv, B)
+
+        edges = self.bin_edges.to(device=device, dtype=dt_v.dtype)
+        widths = edges[1:] - edges[:-1]  # (B,)
+
+        # Bin index b* in [0, B-1]. Boundaries are edges[1:] (length B).
+        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (Mv,)
+        b_star = torch.clamp(b_star, min=0, max=B - 1)
+
+        ar = torch.arange(logits_v.size(0), device=device)
+        hazard_event = hazards[ar, target_v, b_star]  # (Mv,)
+
+        # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
+        weighted = total_hazard * widths.unsqueeze(0)  # (Mv, B)
+        cum = weighted.cumsum(dim=1)  # (Mv, B)
+        full_bins_int = torch.zeros_like(dt_v)
+        has_full = b_star > 0
+        if has_full.any():
+            full_bins_int[has_full] = cum.gather(
+                1, (b_star[has_full] - 1).unsqueeze(1)
+            ).squeeze(1)
+
+        edge_left = edges[b_star]  # (Mv,)
+        partial = total_hazard.gather(
+            1, b_star.unsqueeze(1)).squeeze(1) * (dt_v - edge_left)
+        integral = full_bins_int + partial
+
+        nll_v = -torch.log(hazard_event) + integral
+        nll_full[idx] = nll_v
+
+        if reduction == "none":
+            nll_out = nll_full
+        elif reduction == "sum":
+            nll_out = nll_v.sum()
+        elif reduction == "mean":
+            nll_out = nll_v.mean() if nll_v.numel() > 0 else logits.new_zeros(())
+        else:
+            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
+
+        reg = logits.new_zeros(())
+
+        if self.lambda_reg != 0.0:
+            reg = reg + (self.lambda_reg * logits_v.pow(2).mean())
+
+        return nll_out, reg
+
+
 class WeibullNLLLoss(nn.Module):
     """
     Weibull hazard in t.
@@ -207,4 +336,4 @@ class WeibullNLLLoss(nn.Module):
             (torch.log(scales + eps) ** 2).mean()
             + (torch.log(shapes + eps) ** 2).mean()
         )
-        return nll, reg
\ No newline at end of file
+        return nll, reg
diff --git a/train.py b/train.py
index e3d1343..fedf03b 100644
--- a/train.py
+++ b/train.py
@@ -3,7 +3,7 @@ import os
 import time
 import argparse
 import math
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from typing import Literal, Sequence
 from pathlib import Path
 
@@ -17,14 +17,15 @@ from tqdm import tqdm
 
 from dataset import HealthDataset, health_collate_fn
 from model import DelphiFork, SapDelphi
-from losses import ExponentialNLLLoss, WeibullNLLLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
 
 
 @dataclass
 class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-    loss_type: Literal['exponential', 'weibull'] = 'weibull'
+    loss_type: Literal['exponential', 'weibull',
+                       'piecewise_exponential'] = 'weibull'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -32,6 +33,9 @@ class TrainConfig:
     n_layer: int = 12
    pdrop: float = 0.1
     lambda_reg: float = 1e-4
+    bin_edges: Sequence[float] = field(
+        default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
+    )
     # SapDelphi specific
     pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
     # Data Parameters
@@ -58,7 +62,7 @@ def parse_args() -> TrainConfig:
     parser.add_argument("--model_type", type=str, choices=[
                         'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
     parser.add_argument("--loss_type", type=str, choices=[
-                        'exponential', 'weibull'], default='weibull', help="Type of loss function to use.")
+                        'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.")
     parser.add_argument("--age_encoder", type=str, choices=[
                         'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
     parser.add_argument("--n_embd", type=int, default=120,
@@ -163,6 +167,12 @@ class Trainer:
                 lambda_reg=cfg.lambda_reg,
             ).to(self.device)
             n_dim = 1
+        elif cfg.loss_type == "piecewise_exponential":
+            self.criterion = PiecewiseExponentialLoss(
+                bin_edges=cfg.bin_edges,
+                lambda_reg=cfg.lambda_reg,
+            ).to(self.device)
+            n_dim = len(cfg.bin_edges) - 1
         elif cfg.loss_type == "weibull":
             self.criterion = WeibullNLLLoss(
                 lambda_reg=cfg.lambda_reg,
@@ -348,12 +358,7 @@ class Trainer:
                 dt,
                 reduction="none",
             )
-            finite_mask = torch.isfinite(nll_vec)
-            if not finite_mask.any():
-                continue
-            nll_vec = nll_vec[finite_mask]
             nll = nll_vec.mean()
-            loss = nll + reg
             batch_count += 1
             running_nll += nll.item()
 
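
Not part of the diff: a minimal usage sketch of the new loss, assuming the patched
losses.py and the default bin_edges from TrainConfig above. The tensors are dummy
placeholders in the shapes the docstring describes (logits is (M, K, B), one observed
event index and one dt per sample). The loop at the end re-derives the cumulative
hazard of sample 0 the slow way, full bins plus the partial last bin, to illustrate
the closed-form integral that forward() computes with cumsum and gather.

    import torch
    import torch.nn.functional as F

    from losses import PiecewiseExponentialLoss

    torch.manual_seed(0)

    # Same default edges as TrainConfig.bin_edges (B = 6 bins).
    bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
    criterion = PiecewiseExponentialLoss(bin_edges=bin_edges, lambda_reg=1e-4)

    M, K, B = 8, 5, len(bin_edges) - 1          # samples, competing event types, time bins
    logits = torch.randn(M, K, B)               # per-bin hazard logits from the model head
    target_events = torch.randint(0, K, (M,))   # observed event index k* per sample
    dt = torch.rand(M) * 5.0                    # time-to-event, same units as bin_edges

    nll, reg = criterion(logits, target_events, dt, reduction="mean")
    loss = nll + reg

    # Re-derive the NLL of sample 0 with an explicit loop over bins:
    # cumulative hazard = sum of (total hazard * width) over full bins + partial last bin.
    nll_vec, _ = criterion(logits, target_events, dt, reduction="none")
    hazards0 = F.softplus(logits[0]) + criterion.eps            # (K, B)
    total0 = hazards0.sum(dim=0)                                # (B,)
    t0, k0 = float(dt[0]), int(target_events[0])
    b0 = next(b for b in range(B) if t0 <= bin_edges[b + 1])    # bin containing t0
    cum = sum(total0[b].item() * (min(t0, bin_edges[b + 1]) - bin_edges[b])
              for b in range(b0 + 1))
    manual_nll0 = -torch.log(hazards0[k0, b0]).item() + cum
    assert abs(manual_nll0 - nll_vec[0].item()) < 1e-4

The assert passes because the loop reproduces, term by term, the same piecewise-constant
cumulative hazard that forward() assembles from full-bin widths and the partial bin b*.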