Add Piecewise Exponential CIF Loss and update model evaluation for PWE

This commit is contained in:
2026-01-15 11:36:24 +08:00
parent d8b322cbee
commit 2f46acf2bd
3 changed files with 275 additions and 6 deletions

View File

@@ -1,4 +1,4 @@
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, PiecewiseExponentialCIFNLLLoss, get_valid_pairs_and_dt
from model import DelphiFork, SapDelphi, SimpleHead
from dataset import HealthDataset, health_collate_fn
from tqdm import tqdm
@@ -22,7 +22,8 @@ from typing import Literal, Sequence
class TrainConfig:
# Model Parameters
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
loss_type: Literal['exponential',
'discrete_time_cif', 'pwe_cif'] = 'exponential'
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
full_cov: bool = False
n_embd: int = 120
@@ -64,7 +65,7 @@ def parse_args() -> TrainConfig:
parser.add_argument(
"--loss_type",
type=str,
choices=['exponential', 'discrete_time_cif'],
choices=['exponential', 'discrete_time_cif', 'pwe_cif'],
default='exponential',
help="Type of loss function to use.")
parser.add_argument("--age_encoder", type=str, choices=[
@@ -205,6 +206,28 @@ class Trainer:
).to(self.device)
# logits shape (M, K+1, n_bins+1)
out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
elif cfg.loss_type == "pwe_cif":
# Piecewise-exponential (PWE) requires a FINITE last edge.
# If cfg.bin_edges ends with +inf (default), drop it and train up to the last finite edge.
pwe_edges = [float(x)
for x in cfg.bin_edges if math.isfinite(float(x))]
if len(pwe_edges) < 2:
raise ValueError(
"pwe_cif requires at least 2 finite bin edges (including 0). "
f"Got bin_edges={list(cfg.bin_edges)}"
)
if pwe_edges[0] != 0.0:
raise ValueError(
f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}"
)
self.criterion = PiecewiseExponentialCIFNLLLoss(
bin_edges=pwe_edges,
lambda_reg=cfg.lambda_reg,
).to(self.device)
n_bins = len(pwe_edges) - 1
# logits shape (M, K, n_bins)
out_dims = [dataset.n_disease, n_bins]
else:
raise ValueError(f"Unsupported loss type: {cfg.loss_type}")