Add LogNormalBasisHazardLoss implementation and update training configuration

Date: 2026-01-13 15:59:20 +08:00
parent d8b322cbee
commit 1df02d85d7
5 changed files with 431 additions and 15 deletions


@@ -1,4 +1,4 @@
-from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisHazardLoss, get_valid_pairs_and_dt
from model import DelphiFork, SapDelphi, SimpleHead
from dataset import HealthDataset, health_collate_fn
from tqdm import tqdm
@@ -13,16 +13,16 @@ import os
import time
import argparse
import math
import sys
from dataclasses import asdict, dataclass, field
-from typing import Literal, Sequence
+from typing import Literal, Optional, Sequence
@dataclass
class TrainConfig:
# Model Parameters
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
+loss_type: Literal['exponential', 'discrete_time_cif',
+    'lognormal_basis_hazard'] = 'exponential'
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
full_cov: bool = False
n_embd: int = 120
@@ -34,6 +34,15 @@ class TrainConfig:
default_factory=lambda: [0.0, 0.24, 0.72,
1.61, 3.84, 10.0, 31.0, float('inf')]
)
+# LogNormalBasisHazardLoss specific
+lognormal_centers: Optional[Sequence[float]] = field(
+    default_factory=list)  # mu_r in log-time
+loss_eps: float = 1e-8
+bandwidth_init: float = 0.7
+bandwidth_min: float = 1e-3
+bandwidth_max: float = 10.0
+lambda_sigma_reg: float = 1e-4
+sigma_reg_target: Optional[float] = None
+rank: int = 16
# SapDelphi specific
pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
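
The new fields can also be set programmatically; a hypothetical construction (the values shown are illustrative, everything else falls back to the defaults above):

    cfg = TrainConfig(loss_type='lognormal_basis_hazard',
                      lognormal_centers=[-0.88, 0.07, 0.91, 1.82],
                      bandwidth_init=0.7, lambda_sigma_reg=1e-4)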
@@ -64,7 +73,7 @@ def parse_args() -> TrainConfig:
parser.add_argument(
"--loss_type",
type=str,
-    choices=['exponential', 'discrete_time_cif'],
+    choices=['exponential', 'discrete_time_cif', 'lognormal_basis_hazard'],
default='exponential',
help="Type of loss function to use.")
parser.add_argument("--age_encoder", type=str, choices=[
@@ -79,6 +88,24 @@ def parse_args() -> TrainConfig:
help="Dropout probability.")
parser.add_argument("--lambda_reg", type=float,
default=1e-4, help="Regularization weight.")
+parser.add_argument(
+    "--lognormal_centers",
+    type=float,
+    nargs='*',
+    default=None,
+    help="LogNormalBasisHazardLoss centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
+parser.add_argument("--loss_eps", type=float, default=1e-8,
+                    help="Epsilon for LogNormalBasisHazardLoss log clamp.")
+parser.add_argument("--bandwidth_init", type=float, default=0.7,
+                    help="Initial sigma for LogNormalBasisHazardLoss.")
+parser.add_argument("--bandwidth_min", type=float, default=1e-3,
+                    help="Minimum sigma clamp for LogNormalBasisHazardLoss.")
+parser.add_argument("--bandwidth_max", type=float, default=10.0,
+                    help="Maximum sigma clamp for LogNormalBasisHazardLoss.")
+parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
+                    help="Sigma regularization strength for LogNormalBasisHazardLoss.")
+parser.add_argument("--sigma_reg_target", type=float, default=None,
+                    help="Optional sigma target for regularization (otherwise uses bandwidth_init).")
+parser.add_argument("--rank", type=int, default=16,
+                    help="Rank for low-rank parameterization (if applicable).")
parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
@@ -118,7 +145,36 @@ def parse_args() -> TrainConfig:
parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
help="Device to use for training.")
args = parser.parse_args()
-return TrainConfig(**vars(args))
+cfg = TrainConfig(**vars(args))
+# If lognormal_centers are not provided, derive a reasonable default from bin_edges.
+if cfg.lognormal_centers is None or len(cfg.lognormal_centers) == 0:
+    edges = [float(x) for x in cfg.bin_edges]
+    finite = [e for e in edges if math.isfinite(e)]
+    if len(finite) < 2:
+        raise ValueError(
+            "bin_edges must contain at least two finite edges to derive lognormal_centers")
+    e1 = finite[1]
+    t_min = (e1 * 1e-3) if e1 > 0 else 1e-12
+    # Build one center per bin (including the +inf last bin if present).
+    centers: list[float] = []
+    for i in range(1, len(edges)):
+        left = float(edges[i - 1])
+        right = float(edges[i])
+        if i == 1 and left <= 0.0:
+            left_pos = t_min
+        else:
+            left_pos = max(left, t_min)
+        if math.isinf(right):
+            mid = max(left_pos * 2.0, left_pos + 1e-6)
+        else:
+            right_pos = max(right, t_min)
+            mid = math.sqrt(left_pos * right_pos)
+        centers.append(math.log(max(mid, t_min)))
+    cfg.lognormal_centers = centers
+return cfg
def get_num_params(model: nn.Module) -> int:
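
With the default bin_edges [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, inf], the derivation above clamps the zero left edge to t_min = 0.24e-3, places each finite bin's center at the log of the geometric mean of its edges, and puts the last center at log(2 * 31). Working the numbers through as a check (natural logs):

    mu_1 = log(sqrt(2.4e-4 * 0.24)) ≈ -4.88
    mu_2 = log(sqrt(0.24 * 0.72))   ≈ -0.88
    mu_3 = log(sqrt(0.72 * 1.61))   ≈  0.07
    mu_4 = log(sqrt(1.61 * 3.84))   ≈  0.91
    mu_5 = log(sqrt(3.84 * 10.0))   ≈  1.82
    mu_6 = log(sqrt(10.0 * 31.0))   ≈  2.87
    mu_7 = log(62.0)                ≈  4.13

That yields R = 7 centers, one per bin, spaced roughly one log-unit apart except for the first, which is pulled low by the clamped zero edge.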
@@ -205,6 +261,23 @@ class Trainer:
).to(self.device)
# logits shape (M, K+1, n_bins+1)
out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
+elif cfg.loss_type == "lognormal_basis_hazard":
+    r = len(cfg.lognormal_centers)
+    if r <= 0:
+        raise ValueError(
+            "lognormal_centers must be non-empty for lognormal_basis_hazard")
+    self.criterion = LogNormalBasisHazardLoss(
+        bin_edges=cfg.bin_edges,
+        centers=cfg.lognormal_centers,
+        eps=cfg.loss_eps,
+        bandwidth_init=cfg.bandwidth_init,
+        bandwidth_min=cfg.bandwidth_min,
+        bandwidth_max=cfg.bandwidth_max,
+        lambda_sigma_reg=cfg.lambda_sigma_reg,
+        sigma_reg_target=cfg.sigma_reg_target,
+    ).to(self.device)
+    # logits shape (M, 1 + J*R)
+    out_dims = [1 + dataset.n_disease * r]
else:
raise ValueError(f"Unsupported loss type: {cfg.loss_type}")