Add PiecewiseExponentialLoss class and update TrainConfig for new loss type
This commit is contained in:
23
train.py
23
train.py
@@ -3,7 +3,7 @@ import os
|
||||
import time
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Literal, Sequence
|
||||
from pathlib import Path
|
||||
|
||||
@@ -17,14 +17,15 @@ from tqdm import tqdm
|
||||
|
||||
from dataset import HealthDataset, health_collate_fn
|
||||
from model import DelphiFork, SapDelphi
|
||||
from losses import ExponentialNLLLoss, WeibullNLLLoss, get_valid_pairs_and_dt
|
||||
from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
# Model Parameters
|
||||
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
|
||||
loss_type: Literal['exponential', 'weibull'] = 'weibull'
|
||||
loss_type: Literal['exponential', 'weibull',
|
||||
'piecewise_exponential'] = 'weibull'
|
||||
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
|
||||
full_cov: bool = False
|
||||
n_embd: int = 120
|
||||
@@ -32,6 +33,9 @@ class TrainConfig:
|
||||
n_layer: int = 12
|
||||
pdrop: float = 0.1
|
||||
lambda_reg: float = 1e-4
|
||||
bin_edges: Sequence[float] = field(
|
||||
default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
|
||||
)
|
||||
# SapDelphi specific
|
||||
pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
|
||||
# Data Parameters
|
||||
@@ -58,7 +62,7 @@ def parse_args() -> TrainConfig:
|
||||
parser.add_argument("--model_type", type=str, choices=[
|
||||
'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
|
||||
parser.add_argument("--loss_type", type=str, choices=[
|
||||
'exponential', 'weibull'], default='weibull', help="Type of loss function to use.")
|
||||
'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.")
|
||||
parser.add_argument("--age_encoder", type=str, choices=[
|
||||
'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
|
||||
parser.add_argument("--n_embd", type=int, default=120,
|
||||
@@ -163,6 +167,12 @@ class Trainer:
|
||||
lambda_reg=cfg.lambda_reg,
|
||||
).to(self.device)
|
||||
n_dim = 1
|
||||
elif cfg.loss_type == "piecewise_exponential":
|
||||
self.criterion = PiecewiseExponentialLoss(
|
||||
bin_edges=cfg.bin_edges,
|
||||
lambda_reg=cfg.lambda_reg,
|
||||
).to(self.device)
|
||||
n_dim = len(cfg.bin_edges) - 1
|
||||
elif cfg.loss_type == "weibull":
|
||||
self.criterion = WeibullNLLLoss(
|
||||
lambda_reg=cfg.lambda_reg,
|
||||
@@ -348,12 +358,7 @@ class Trainer:
|
||||
dt,
|
||||
reduction="none",
|
||||
)
|
||||
finite_mask = torch.isfinite(nll_vec)
|
||||
if not finite_mask.any():
|
||||
continue
|
||||
nll_vec = nll_vec[finite_mask]
|
||||
nll = nll_vec.mean()
|
||||
|
||||
loss = nll + reg
|
||||
batch_count += 1
|
||||
running_nll += nll.item()
|
||||
|
||||
Reference in New Issue
Block a user