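Add PiecewiseExponentialLoss class, update TrainConfig (new 'piecewise_exponential' loss_type and bin_edges field), and remove the non-finite NLL filtering from the Trainer loop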

This commit is contained in:
2026-01-08 12:45:31 +08:00
parent 7c36f7a007
commit 06a01d2893
2 changed files with 144 additions and 10 deletions

View File

@@ -3,7 +3,7 @@ import os
import time
import argparse
import math
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from typing import Literal, Sequence
from pathlib import Path
@@ -17,14 +17,15 @@ from tqdm import tqdm
from dataset import HealthDataset, health_collate_fn
from model import DelphiFork, SapDelphi
from losses import ExponentialNLLLoss, WeibullNLLLoss, get_valid_pairs_and_dt
from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
@dataclass
class TrainConfig:
# Model Parameters
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
loss_type: Literal['exponential', 'weibull'] = 'weibull'
loss_type: Literal['exponential', 'weibull',
'piecewise_exponential'] = 'weibull'
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
full_cov: bool = False
n_embd: int = 120
@@ -32,6 +33,9 @@ class TrainConfig:
n_layer: int = 12
pdrop: float = 0.1
lambda_reg: float = 1e-4
bin_edges: Sequence[float] = field(
default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
)
# SapDelphi specific
pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
# Data Parameters
@@ -58,7 +62,7 @@ def parse_args() -> TrainConfig:
parser.add_argument("--model_type", type=str, choices=[
'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
parser.add_argument("--loss_type", type=str, choices=[
'exponential', 'weibull'], default='weibull', help="Type of loss function to use.")
'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.")
parser.add_argument("--age_encoder", type=str, choices=[
'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
parser.add_argument("--n_embd", type=int, default=120,
@@ -163,6 +167,12 @@ class Trainer:
lambda_reg=cfg.lambda_reg,
).to(self.device)
n_dim = 1
elif cfg.loss_type == "piecewise_exponential":
self.criterion = PiecewiseExponentialLoss(
bin_edges=cfg.bin_edges,
lambda_reg=cfg.lambda_reg,
).to(self.device)
n_dim = len(cfg.bin_edges) - 1
elif cfg.loss_type == "weibull":
self.criterion = WeibullNLLLoss(
lambda_reg=cfg.lambda_reg,
@@ -348,12 +358,7 @@ class Trainer:
dt,
reduction="none",
)
finite_mask = torch.isfinite(nll_vec)
if not finite_mask.any():
continue
nll_vec = nll_vec[finite_mask]
nll = nll_vec.mean()
loss = nll + reg
batch_count += 1
running_nll += nll.item()