Add PiecewiseExponentialLoss class and update TrainConfig for new loss type

2026-01-08 12:45:31 +08:00
parent 7c36f7a007
commit 06a01d2893
2 changed files with 144 additions and 10 deletions

losses.py

@@ -132,6 +132,135 @@ class ExponentialNLLLoss(nn.Module):
         return nll, reg
 
 
+class PiecewiseExponentialLoss(nn.Module):
+    """Piecewise-constant competing risks exponential likelihood.
+
+    Uses B time bins defined by `bin_edges` (length B+1, strictly increasing,
+    starting at 0). Within each bin b, hazards are constant and parameterized as:
+        hazards = softplus(logits) + eps    with logits shape (M, K, B)
+    For each sample i, dt is bucketized to bin b* and the NLL is:
+        nll_i = -log(hazard_{k*}(b*)) + \int_0^{dt} sum_k hazard_k(u) du
+    The integral is computed in closed form by summing full bins plus the
+    partial bin b*.
+    """
+
+    def __init__(
+        self,
+        bin_edges: Sequence[float],
+        eps: float = 1e-6,
+        lambda_reg: float = 0.0,
+    ):
+        super().__init__()
+        if len(bin_edges) < 2:
+            raise ValueError("bin_edges must have length >= 2")
+        if bin_edges[0] != 0:
+            raise ValueError("bin_edges must start at 0")
+        for i in range(1, len(bin_edges)):
+            if not (bin_edges[i] > bin_edges[i - 1]):
+                raise ValueError("bin_edges must be strictly increasing")
+        self.eps = float(eps)
+        self.lambda_reg = float(lambda_reg)
+        edges = torch.tensor(list(bin_edges), dtype=torch.float32)
+        self.register_buffer("bin_edges", edges, persistent=False)
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        target_events: torch.Tensor,
+        dt: torch.Tensor,
+        reduction: str = "mean",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if logits.dim() != 3:
+            raise ValueError("logits must have shape (M, K, B)")
+        M, K, B = logits.shape
+        if self.bin_edges.numel() != B + 1:
+            raise ValueError(
+                f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
+            )
+        device = logits.device
+        dt = dt.to(device=device)
+        target_events = target_events.to(device=device)
+
+        # Build a per-sample finite mask to avoid NaN/Inf propagation.
+        logits_finite = torch.isfinite(logits).view(M, -1).all(dim=1)
+        dt_finite = torch.isfinite(dt)
+        target_finite = torch.isfinite(target_events)
+        finite_mask = logits_finite & dt_finite & target_finite
+
+        nll_full = logits.new_zeros((M,))
+        if not finite_mask.any():
+            nll_out = nll_full if reduction == "none" else logits.new_zeros(())
+            reg_out = logits.new_zeros(())
+            return nll_out, reg_out
+
+        idx = finite_mask.nonzero(as_tuple=False).squeeze(1)
+        logits_v = logits[idx]
+        target_v = target_events[idx].to(torch.long)
+        dt_v = dt[idx].to(torch.float32)
+
+        # Clamp dt into [eps, max_edge) to keep bucket indices valid.
+        eps = self.eps
+        max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype)
+        dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))
+        dt_v = torch.clamp(dt_v, min=eps, max=dt_max)
+
+        hazards = F.softplus(logits_v) + eps  # (Mv, K, B)
+        total_hazard = hazards.sum(dim=1)  # (Mv, B)
+
+        edges = self.bin_edges.to(device=device, dtype=dt_v.dtype)
+        widths = edges[1:] - edges[:-1]  # (B,)
+
+        # Bin index b* in [0, B-1]. boundaries are edges[1:] (length B).
+        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (Mv,)
+        b_star = torch.clamp(b_star, min=0, max=B - 1)
+
+        ar = torch.arange(logits_v.size(0), device=device)
+        hazard_event = hazards[ar, target_v, b_star]  # (Mv,)
+
+        # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
+        weighted = total_hazard * widths.unsqueeze(0)  # (Mv, B)
+        cum = weighted.cumsum(dim=1)  # (Mv, B)
+        full_bins_int = torch.zeros_like(dt_v)
+        has_full = b_star > 0
+        if has_full.any():
+            full_bins_int[has_full] = cum.gather(
+                1, (b_star[has_full] - 1).unsqueeze(1)
+            ).squeeze(1)
+        edge_left = edges[b_star]  # (Mv,)
+        partial = total_hazard.gather(
+            1, b_star.unsqueeze(1)).squeeze(1) * (dt_v - edge_left)
+        integral = full_bins_int + partial
+
+        nll_v = -torch.log(hazard_event) + integral
+        nll_full[idx] = nll_v
+
+        if reduction == "none":
+            nll_out = nll_full
+        elif reduction == "sum":
+            nll_out = nll_v.sum()
+        elif reduction == "mean":
+            nll_out = nll_v.mean() if nll_v.numel() > 0 else logits.new_zeros(())
+        else:
+            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
+
+        reg = logits.new_zeros(())
+        if self.lambda_reg != 0.0:
+            reg = reg + (self.lambda_reg * logits_v.pow(2).mean())
+        return nll_out, reg
+
+
 class WeibullNLLLoss(nn.Module):
     """
     Weibull hazard in t.
@@ -207,4 +336,4 @@ class WeibullNLLLoss(nn.Module):
             (torch.log(scales + eps) ** 2).mean() +
             (torch.log(shapes + eps) ** 2).mean()
         )
         return nll, reg
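
A quick sanity check of the closed-form integral in PiecewiseExponentialLoss above (illustrative only, not part of the commit; assumes losses.py is on the import path): with reduction="none", the returned per-sample NLL should equal a manual per-bin sum.

import torch
import torch.nn.functional as F
from losses import PiecewiseExponentialLoss

edges = [0.0, 1.0, 2.0, 4.0]                      # B = 3 bins
loss_fn = PiecewiseExponentialLoss(bin_edges=edges)

logits = torch.zeros(1, 2, 3)                     # M = 1 sample, K = 2 risks, B = 3 bins
dt = torch.tensor([3.0])                          # lands in bin 2, i.e. [2.0, 4.0)
events = torch.tensor([1])                        # observed risk k* = 1

nll, _ = loss_fn(logits, events, dt, reduction="none")

hazards = F.softplus(logits) + loss_fn.eps        # constant hazard per (risk, bin)
total = hazards.sum(dim=1)[0]                     # total hazard in each bin
manual = (-torch.log(hazards[0, 1, 2])            # -log hazard of observed risk in bin b*
          + total[0] * 1.0 + total[1] * 1.0       # full bins [0, 1) and [1, 2)
          + total[2] * (3.0 - 2.0))               # partial bin [2.0, dt)
assert torch.allclose(nll[0], manual)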


@@ -3,7 +3,7 @@ import os
 import time
 import argparse
 import math
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from typing import Literal, Sequence
 from pathlib import Path
@@ -17,14 +17,15 @@ from tqdm import tqdm
 from dataset import HealthDataset, health_collate_fn
 from model import DelphiFork, SapDelphi
-from losses import ExponentialNLLLoss, WeibullNLLLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
 
 
 @dataclass
 class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-    loss_type: Literal['exponential', 'weibull'] = 'weibull'
+    loss_type: Literal['exponential', 'weibull',
+                       'piecewise_exponential'] = 'weibull'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -32,6 +33,9 @@ class TrainConfig:
     n_layer: int = 12
     pdrop: float = 0.1
     lambda_reg: float = 1e-4
+    bin_edges: Sequence[float] = field(
+        default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
+    )
     # SapDelphi specific
     pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
     # Data Parameters
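
For reference, the default bin_edges above define B = 6 piecewise-constant bins, so the criterion expects logits of shape (M, K, 6); the Trainer wiring further below derives this head size as len(cfg.bin_edges) - 1. A small illustrative snippet (not part of the diff):

bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]   # TrainConfig default above
n_dim = len(bin_edges) - 1                              # 6 bins -> logits of shape (M, K, 6)
widths = [hi - lo for lo, hi in zip(bin_edges[:-1], bin_edges[1:])]
# widths ~= [0.24, 0.48, 0.89, 2.23, 6.16, 21.0]: narrow bins for short
# follow-up intervals, progressively wider bins for long horizons.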
@@ -58,7 +62,7 @@ def parse_args() -> TrainConfig:
parser.add_argument("--model_type", type=str, choices=[ parser.add_argument("--model_type", type=str, choices=[
'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.") 'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
parser.add_argument("--loss_type", type=str, choices=[ parser.add_argument("--loss_type", type=str, choices=[
'exponential', 'weibull'], default='weibull', help="Type of loss function to use.") 'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.")
parser.add_argument("--age_encoder", type=str, choices=[ parser.add_argument("--age_encoder", type=str, choices=[
'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.") 'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
parser.add_argument("--n_embd", type=int, default=120, parser.add_argument("--n_embd", type=int, default=120,
@@ -163,6 +167,12 @@ class Trainer:
                 lambda_reg=cfg.lambda_reg,
             ).to(self.device)
             n_dim = 1
+        elif cfg.loss_type == "piecewise_exponential":
+            self.criterion = PiecewiseExponentialLoss(
+                bin_edges=cfg.bin_edges,
+                lambda_reg=cfg.lambda_reg,
+            ).to(self.device)
+            n_dim = len(cfg.bin_edges) - 1
         elif cfg.loss_type == "weibull":
             self.criterion = WeibullNLLLoss(
                 lambda_reg=cfg.lambda_reg,
@@ -348,12 +358,7 @@ class Trainer:
                 dt,
                 reduction="none",
             )
-            finite_mask = torch.isfinite(nll_vec)
-            if not finite_mask.any():
-                continue
-            nll_vec = nll_vec[finite_mask]
             nll = nll_vec.mean()
             loss = nll + reg
             batch_count += 1
             running_nll += nll.item()
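
The new loss performs its own finite-sample masking (see losses.py above): with reduction="none" it returns 0 for any sample whose logits, dt, or target are non-finite, so the returned vector itself stays finite. A small illustrative check (not part of the commit; values are arbitrary):

import torch
from losses import PiecewiseExponentialLoss

loss_fn = PiecewiseExponentialLoss(bin_edges=[0.0, 1.0, 2.0])   # B = 2 bins
logits = torch.zeros(2, 3, 2)                  # M = 2 samples, K = 3 risks
logits[1] = float("nan")                       # second sample has invalid inputs
dt = torch.tensor([0.5, 0.5])
events = torch.tensor([0, 0])

nll_vec, reg = loss_fn(logits, events, dt, reduction="none")
# nll_vec[0] is a finite NLL; nll_vec[1] is exactly 0.0 (masked sample).
# A plain nll_vec.mean(), as in the evaluation loop above, therefore stays
# finite, with masked entries contributing zeros to the average.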