Add PiecewiseExponentialLoss class and update TrainConfig for new loss type
losses.py (+129 lines)
@@ -132,6 +132,135 @@ class ExponentialNLLLoss(nn.Module):
         return nll, reg
 
 
+class PiecewiseExponentialLoss(nn.Module):
+    """Piecewise-constant competing risks exponential likelihood.
+
+    Uses B time bins defined by `bin_edges` (length B+1, strictly increasing, starting at 0).
+    Within each bin b, hazards are constant and parameterized as:
+
+        hazards = softplus(logits) + eps with logits shape (M, K, B)
+
+    For each sample i, dt is bucketized to bin b* and the NLL is:
+
+        nll_i = -log(hazard_{k*}(b*)) + \int_0^{dt} sum_k hazard_k(u) du
+
+    The integral is computed in closed form by summing full bins plus the partial bin b*.
+    """
+
+    def __init__(
+        self,
+        bin_edges: Sequence[float],
+        eps: float = 1e-6,
+        lambda_reg: float = 0.0,
+    ):
+        super().__init__()
+
+        if len(bin_edges) < 2:
+            raise ValueError("bin_edges must have length >= 2")
+        if bin_edges[0] != 0:
+            raise ValueError("bin_edges must start at 0")
+        for i in range(1, len(bin_edges)):
+            if not (bin_edges[i] > bin_edges[i - 1]):
+                raise ValueError("bin_edges must be strictly increasing")
+
+        self.eps = float(eps)
+        self.lambda_reg = float(lambda_reg)
+
+        edges = torch.tensor(list(bin_edges), dtype=torch.float32)
+        self.register_buffer("bin_edges", edges, persistent=False)
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        target_events: torch.Tensor,
+        dt: torch.Tensor,
+        reduction: str = "mean",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if logits.dim() != 3:
+            raise ValueError("logits must have shape (M, K, B)")
+
+        M, K, B = logits.shape
+        if self.bin_edges.numel() != B + 1:
+            raise ValueError(
+                f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
+            )
+
+        device = logits.device
+        dt = dt.to(device=device)
+        target_events = target_events.to(device=device)
+
+        # Build a per-sample finite mask to avoid NaN/Inf propagation.
+        logits_finite = torch.isfinite(logits).view(M, -1).all(dim=1)
+        dt_finite = torch.isfinite(dt)
+        target_finite = torch.isfinite(target_events)
+        finite_mask = logits_finite & dt_finite & target_finite
+
+        nll_full = logits.new_zeros((M,))
+
+        if not finite_mask.any():
+            nll_out = nll_full if reduction == "none" else logits.new_zeros(())
+            reg_out = logits.new_zeros(())
+            return nll_out, reg_out
+
+        idx = finite_mask.nonzero(as_tuple=False).squeeze(1)
+        logits_v = logits[idx]
+        target_v = target_events[idx].to(torch.long)
+        dt_v = dt[idx].to(torch.float32)
+
+        # Clamp dt into [eps, max_edge) to keep bucket indices valid.
+        eps = self.eps
+        max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype)
+        dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))
+        dt_v = torch.clamp(dt_v, min=eps, max=dt_max)
+
+        hazards = F.softplus(logits_v) + eps  # (Mv, K, B)
+        total_hazard = hazards.sum(dim=1)  # (Mv, B)
+
+        edges = self.bin_edges.to(device=device, dtype=dt_v.dtype)
+        widths = edges[1:] - edges[:-1]  # (B,)
+
+        # Bin index b* in [0, B-1]. boundaries are edges[1:] (length B).
+        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (Mv,)
+        b_star = torch.clamp(b_star, min=0, max=B - 1)
+
+        ar = torch.arange(logits_v.size(0), device=device)
+        hazard_event = hazards[ar, target_v, b_star]  # (Mv,)
+
+        # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
+        weighted = total_hazard * widths.unsqueeze(0)  # (Mv, B)
+        cum = weighted.cumsum(dim=1)  # (Mv, B)
+        full_bins_int = torch.zeros_like(dt_v)
+        has_full = b_star > 0
+        if has_full.any():
+            full_bins_int[has_full] = cum.gather(
+                1, (b_star[has_full] - 1).unsqueeze(1)
+            ).squeeze(1)
+
+        edge_left = edges[b_star]  # (Mv,)
+        partial = total_hazard.gather(
+            1, b_star.unsqueeze(1)).squeeze(1) * (dt_v - edge_left)
+        integral = full_bins_int + partial
+
+        nll_v = -torch.log(hazard_event) + integral
+        nll_full[idx] = nll_v
+
+        if reduction == "none":
+            nll_out = nll_full
+        elif reduction == "sum":
+            nll_out = nll_v.sum()
+        elif reduction == "mean":
+            nll_out = nll_v.mean() if nll_v.numel() > 0 else logits.new_zeros(())
+        else:
+            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
+
+        reg = logits.new_zeros(())
+
+        if self.lambda_reg != 0.0:
+            reg = reg + (self.lambda_reg * logits_v.pow(2).mean())
+
+        return nll_out, reg
+
+
 class WeibullNLLLoss(nn.Module):
     """
     Weibull hazard in t.
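
For orientation, here is a minimal usage sketch of the class added above (shapes and values are hypothetical; it only assumes the module is importable as losses, as train.py below does): two samples, two competing events, two bins, with zero logits so every hazard is roughly softplus(0) ≈ 0.693.

import torch

from losses import PiecewiseExponentialLoss  # class added in this commit

loss_fn = PiecewiseExponentialLoss(bin_edges=[0.0, 1.0, 3.0], lambda_reg=1e-4)

logits = torch.zeros(2, 2, 2)           # (M, K, B): 2 samples, 2 events, 2 bins
target_events = torch.tensor([0, 1])    # observed event index k* per sample
dt = torch.tensor([0.5, 2.0])           # 0.5 falls in bin 0, 2.0 in bin 1

nll, reg = loss_fn(logits, target_events, dt, reduction="none")

# Hand check for the second sample: total hazard per bin is 2 * 0.693,
# integral = 1.386 * 1.0 (full bin 0) + 1.386 * (2.0 - 1.0) (partial bin 1) ~= 2.77,
# so nll[1] ~= -log(0.693) + 2.77 ~= 3.14; reg is 0 here because the logits are zero.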
train.py (23 changed lines)
@@ -3,7 +3,7 @@ import os
 import time
 import argparse
 import math
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from typing import Literal, Sequence
 from pathlib import Path
 
@@ -17,14 +17,15 @@ from tqdm import tqdm
 
 from dataset import HealthDataset, health_collate_fn
 from model import DelphiFork, SapDelphi
-from losses import ExponentialNLLLoss, WeibullNLLLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
 
 
 @dataclass
 class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-    loss_type: Literal['exponential', 'weibull'] = 'weibull'
+    loss_type: Literal['exponential', 'weibull',
+                       'piecewise_exponential'] = 'weibull'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -32,6 +33,9 @@ class TrainConfig:
     n_layer: int = 12
     pdrop: float = 0.1
     lambda_reg: float = 1e-4
+    bin_edges: Sequence[float] = field(
+        default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
+    )
     # SapDelphi specific
     pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
     # Data Parameters
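
As a quick sanity check on the new default (a sketch, not part of the commit): the seven edges above define B = 6 bins whose widths grow from about 0.24 up to 21 time units.

# Widths implied by the default bin_edges added to TrainConfig above.
bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
widths = [b - a for a, b in zip(bin_edges, bin_edges[1:])]
# widths ~= [0.24, 0.48, 0.89, 2.23, 6.16, 21.0], i.e. B = 6 bins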
@@ -58,7 +62,7 @@ def parse_args() -> TrainConfig:
     parser.add_argument("--model_type", type=str, choices=[
                         'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
     parser.add_argument("--loss_type", type=str, choices=[
-                        'exponential', 'weibull'], default='weibull', help="Type of loss function to use.")
+                        'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.")
     parser.add_argument("--age_encoder", type=str, choices=[
                         'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
     parser.add_argument("--n_embd", type=int, default=120,
@@ -163,6 +167,12 @@ class Trainer:
                 lambda_reg=cfg.lambda_reg,
             ).to(self.device)
             n_dim = 1
+        elif cfg.loss_type == "piecewise_exponential":
+            self.criterion = PiecewiseExponentialLoss(
+                bin_edges=cfg.bin_edges,
+                lambda_reg=cfg.lambda_reg,
+            ).to(self.device)
+            n_dim = len(cfg.bin_edges) - 1
         elif cfg.loss_type == "weibull":
             self.criterion = WeibullNLLLoss(
                 lambda_reg=cfg.lambda_reg,
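
The key coupling in this branch is the head width: with the piecewise loss, the model must emit B hazards per event instead of a single rate. A small standalone sketch of that relationship, using the TrainConfig defaults above (the snippet itself is illustrative, not part of the commit):

from losses import PiecewiseExponentialLoss

bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]   # TrainConfig default
criterion = PiecewiseExponentialLoss(bin_edges=bin_edges, lambda_reg=1e-4)

# The loss expects logits of shape (M, K, B), so the per-event output width
# n_dim becomes len(bin_edges) - 1 = 6 here, versus 1 for the exponential loss.
n_dim = len(bin_edges) - 1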
@@ -348,12 +358,7 @@ class Trainer:
                     dt,
                     reduction="none",
                 )
-                finite_mask = torch.isfinite(nll_vec)
-                if not finite_mask.any():
-                    continue
-                nll_vec = nll_vec[finite_mask]
                 nll = nll_vec.mean()
-
                 loss = nll + reg
                 batch_count += 1
                 running_nll += nll.item()
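
The hunk above drops the external finite filtering, presumably because, as in the PiecewiseExponentialLoss added in losses.py, reduction="none" already zero-fills non-finite samples instead of propagating NaN/Inf. A small sketch of that behavior (hypothetical values, assuming only the semantics shown in the losses.py hunk):

import torch

from losses import PiecewiseExponentialLoss

loss_fn = PiecewiseExponentialLoss(bin_edges=[0.0, 1.0, 3.0])
logits = torch.zeros(2, 2, 2)
dt = torch.tensor([0.5, float("nan")])   # second sample has an invalid dt

nll_vec, _ = loss_fn(logits, torch.tensor([0, 1]), dt, reduction="none")
# nll_vec[1] == 0.0 rather than NaN, so nll_vec.mean() stays finite without
# the external finite_mask that this hunk removes.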