Add Piecewise Exponential CIF Loss and update model evaluation for PWE
This commit is contained in:
29
train.py
29
train.py
@@ -1,4 +1,4 @@
|
||||
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
|
||||
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, PiecewiseExponentialCIFNLLLoss, get_valid_pairs_and_dt
|
||||
from model import DelphiFork, SapDelphi, SimpleHead
|
||||
from dataset import HealthDataset, health_collate_fn
|
||||
from tqdm import tqdm
|
||||
@@ -22,7 +22,8 @@ from typing import Literal, Sequence
|
||||
class TrainConfig:
|
||||
# Model Parameters
|
||||
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
|
||||
loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
|
||||
loss_type: Literal['exponential',
|
||||
'discrete_time_cif', 'pwe_cif'] = 'exponential'
|
||||
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
|
||||
full_cov: bool = False
|
||||
n_embd: int = 120
|
||||
@@ -64,7 +65,7 @@ def parse_args() -> TrainConfig:
|
||||
parser.add_argument(
|
||||
"--loss_type",
|
||||
type=str,
|
||||
choices=['exponential', 'discrete_time_cif'],
|
||||
choices=['exponential', 'discrete_time_cif', 'pwe_cif'],
|
||||
default='exponential',
|
||||
help="Type of loss function to use.")
|
||||
parser.add_argument("--age_encoder", type=str, choices=[
|
||||
@@ -205,6 +206,28 @@ class Trainer:
|
||||
).to(self.device)
|
||||
# logits shape (M, K+1, n_bins+1)
|
||||
out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
|
||||
elif cfg.loss_type == "pwe_cif":
|
||||
# Piecewise-exponential (PWE) requires a FINITE last edge.
|
||||
# If cfg.bin_edges ends with +inf (default), drop it and train up to the last finite edge.
|
||||
pwe_edges = [float(x)
|
||||
for x in cfg.bin_edges if math.isfinite(float(x))]
|
||||
if len(pwe_edges) < 2:
|
||||
raise ValueError(
|
||||
"pwe_cif requires at least 2 finite bin edges (including 0). "
|
||||
f"Got bin_edges={list(cfg.bin_edges)}"
|
||||
)
|
||||
if pwe_edges[0] != 0.0:
|
||||
raise ValueError(
|
||||
f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}"
|
||||
)
|
||||
|
||||
self.criterion = PiecewiseExponentialCIFNLLLoss(
|
||||
bin_edges=pwe_edges,
|
||||
lambda_reg=cfg.lambda_reg,
|
||||
).to(self.device)
|
||||
n_bins = len(pwe_edges) - 1
|
||||
# logits shape (M, K, n_bins)
|
||||
out_dims = [dataset.n_disease, n_bins]
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user