Add LogNormalBasisHazardLoss implementation and update training configuration

2026-01-13 15:59:20 +08:00
parent d8b322cbee
commit 1df02d85d7
5 changed files with 431 additions and 15 deletions

View File

@@ -8,7 +8,7 @@ import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
class ModelSpec:
name: str
model_type: str # delphi_fork | sap_delphi
-loss_type: str # exponential | discrete_time_cif
+loss_type: str # exponential | discrete_time_cif | lognormal_basis_hazard
full_cov: bool
checkpoint_path: str
@@ -420,6 +420,72 @@ def cifs_from_discrete_time_logits(
return cif, survival
def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
z = torch.clamp(z, -12.0, 12.0)
return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
def cifs_from_lognormal_basis_logits(
logits: torch.Tensor,
centers: Sequence[float],
sigma: torch.Tensor,
taus: Sequence[float],
*,
bin_edges: Optional[Sequence[float]] = None,
return_survival: bool = False,
) -> torch.Tensor:
"""Convert LogNormalBasisHazardLoss logits -> CIFs at taus.
logits: (B, 1 + K*R) where K is the number of diseases (causes) and R is the number of basis functions.
centers: length R, in log-time.
sigma: scalar tensor (already clamped) in log-time units.
taus: horizons in the same units as training bin_edges (years).
"""
if logits.ndim != 2:
raise ValueError("Expected logits shape (B, 1+K*R)")
if sigma.ndim != 0:
raise ValueError("sigma must be a scalar tensor")
device = logits.device
dtype = logits.dtype
centers_t = torch.tensor([float(x)
for x in centers], device=device, dtype=dtype)
r = int(centers_t.numel())
jr = int(logits.shape[1] - 1)
if jr <= 0 or (jr % r) != 0:
raise ValueError("logits.shape[1]-1 must be divisible by R")
k = jr // r
# Stable t_min clamp (aligns with training loss rule).
t_min = 1e-12
if bin_edges is not None:
edges = [float(x) for x in bin_edges]
if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
t_min = edges[1] * 1e-6
t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
taus_t = torch.tensor([float(x) for x in taus], device=device, dtype=dtype)
taus_t = torch.clamp(taus_t, min=t_min_t)
log_tau = torch.log(taus_t) # (H,)
# (H,R)
z = (log_tau.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma
cdf = _normal_cdf_stable(z)
w_all = torch.softmax(logits, dim=-1)
w = w_all[:, 1:].view(logits.size(0), k, r) # (B,K,R)
cif = torch.einsum("bkr,hr->bkh", w, cdf) # (B,K,H)
if not return_survival:
return cif
survival = 1.0 - cif.sum(dim=1) # (B,H)
survival = torch.clamp(survival, min=0.0, max=1.0)
return cif, survival
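A minimal usage sketch of the helper above (the batch size, K=3 causes, R=4 centers, horizon values, and random logits are illustrative, not taken from this repository; bin_edges follow the TrainConfig default):

import torch
logits = torch.randn(8, 1 + 3 * 4)             # (B, 1 + K*R) with K=3 causes, R=4 centers
centers = [-1.5, 0.0, 1.5, 3.0]                # hypothetical mu_r in log-time
sigma = torch.tensor(0.7)                      # scalar (0-dim) sigma tensor
taus = [1.0, 5.0, 10.0]                        # evaluation horizons in years
cif, survival = cifs_from_lognormal_basis_logits(
    logits, centers=centers, sigma=sigma, taus=taus,
    bin_edges=[0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')],
    return_survival=True)
# cif: (8, 3, 3) per-cause CIFs at each horizon; survival: (8, 3) remaining mass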
# ============================================================
# CIF integrity checks
# ============================================================
@@ -1192,15 +1258,31 @@ def instantiate_model_and_head(
dataset: HealthDataset,
device: str,
checkpoint_path: str = "",
-) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
+) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]:
model_type = str(cfg["model_type"])
loss_type = str(cfg["loss_type"])
loss_params: Dict[str, Any] = {}
if loss_type == "exponential":
out_dims = [dataset.n_disease]
elif loss_type == "discrete_time_cif":
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
out_dims = [dataset.n_disease + 1, len(bin_edges)]
elif loss_type == "lognormal_basis_hazard":
centers = cfg.get("lognormal_centers", None)
if centers is None:
centers = cfg.get("centers", None)
if not isinstance(centers, list) or len(centers) == 0:
raise ValueError(
"lognormal_basis_hazard requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
)
out_dims = [1 + dataset.n_disease * len(centers)]
loss_params["centers"] = centers
loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
else:
raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
@@ -1249,7 +1331,7 @@ def instantiate_model_and_head(
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
-return backbone, head, loss_type, bin_edges
+return backbone, head, loss_type, bin_edges, loss_params
@torch.no_grad()
@@ -1258,6 +1340,7 @@ def predict_cifs_for_model(
head: torch.nn.Module,
loss_type: str,
bin_edges: Sequence[float],
loss_params: Dict[str, Any],
loader: DataLoader,
device: str,
offset_years: float,
@@ -1316,6 +1399,20 @@ def predict_cifs_for_model(
cif_full, survival = cifs_from_discrete_time_logits(
# (B,K,H), (B,H)
logits, bin_edges, eval_horizons, return_survival=True)
elif loss_type == "lognormal_basis_hazard":
centers = loss_params.get("centers", None)
sigma = loss_params.get("sigma", None)
if centers is None or sigma is None:
raise ValueError(
"lognormal_basis_hazard requires loss_params['centers'] and loss_params['sigma']")
cif_full, survival = cifs_from_lognormal_basis_logits(
logits,
centers=centers,
sigma=sigma,
taus=eval_horizons,
bin_edges=bin_edges,
return_survival=True,
)
else:
raise ValueError(f"Unsupported loss_type: {loss_type}")
@@ -1823,12 +1920,27 @@ def main() -> int:
collate_fn=health_collate_fn,
)
-backbone, head, loss_type, bin_edges = instantiate_model_and_head(
-cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
+backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head(
+cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
head.load_state_dict(ckpt["head_state_dict"], strict=True)
if loss_type == "lognormal_basis_hazard":
crit_state = ckpt.get("criterion_state_dict", {})
log_sigma = crit_state.get("log_sigma", None)
if isinstance(log_sigma, torch.Tensor):
log_sigma_t = log_sigma.to(device=args.device)
sigma = torch.exp(log_sigma_t)
else:
sigma = torch.tensor(float(loss_params.get(
"bandwidth_init", 0.7)), device=args.device)
bmin = float(loss_params.get("bandwidth_min", 1e-3))
bmax = float(loss_params.get("bandwidth_max", 10.0))
sigma = torch.clamp(sigma, min=bmin, max=bmax)
loss_params["sigma"] = sigma
(
cif_full,
survival,
@@ -1839,6 +1951,7 @@ def main() -> int:
head,
loss_type,
bin_edges,
loss_params,
loader,
args.device,
args.offset_years,

losses.py
View File

@@ -1,5 +1,5 @@
import math
-from typing import Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
@@ -258,3 +258,233 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
F.nll_loss(logp_at_event_bin, target_events, reduction="mean")
return nll, reg
class LogNormalBasisHazardLoss(nn.Module):
"""Event-only competing risks loss using lognormal basis (Gaussian on log-time).
This loss models cause-specific CIF as a mixture of lognormal basis CDFs:
F_j(t) = sum_r w_{j,r} * Phi((log t - mu_r) / sigma)
Training uses *bin probability mass* (Delta CIF / interval mass). There is
**no censoring**: every sample is an observed event with a valid cause.
Logits interface:
logits: (B, 1 + J*R)
logits[:, 0] -> w0 (survival mass / never-event)
logits[:, 1:] -> flattened (j,r) in row-major order: j then r
index = 1 + j*R + r
Forward interface (must match):
forward(logits, target_events, dt, reduction)
"""
def __init__(
self,
bin_edges: Sequence[float],
centers: Sequence[float],
*,
eps: float = 1e-8,
bandwidth_init: float = 0.5,
bandwidth_min: float = 1e-3,
bandwidth_max: float = 10.0,
lambda_sigma_reg: float = 0.0,
sigma_reg_target: Optional[float] = None,
return_dict: bool = False,
):
super().__init__()
if len(bin_edges) < 2:
raise ValueError("bin_edges must have length >= 2")
# allow last edge to be +inf
for i in range(1, len(bin_edges)):
prev = float(bin_edges[i - 1])
cur = float(bin_edges[i])
if math.isinf(prev):
raise ValueError(
"bin_edges cannot have +inf except possibly as the last edge")
if i == len(bin_edges) - 1 and math.isinf(cur):
if not (cur > prev):
raise ValueError("bin_edges must be strictly increasing")
else:
if not (cur > prev):
raise ValueError("bin_edges must be strictly increasing")
if float(bin_edges[0]) < 0.0:
raise ValueError("bin_edges[0] must be >= 0")
if len(centers) < 1:
raise ValueError("centers must have length >= 1")
self.eps = float(eps)
self.bandwidth_min = float(bandwidth_min)
self.bandwidth_max = float(bandwidth_max)
self.lambda_sigma_reg = float(lambda_sigma_reg)
self.sigma_reg_target = None if sigma_reg_target is None else float(
sigma_reg_target)
self.bandwidth_init = float(bandwidth_init)
self.return_dict = bool(return_dict)
self.register_buffer(
"bin_edges",
torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"centers",
torch.tensor([float(x) for x in centers], dtype=torch.float32),
persistent=False,
)
if self.bandwidth_init <= 0:
raise ValueError("bandwidth_init must be > 0")
self.log_sigma = nn.Parameter(torch.tensor(
math.log(self.bandwidth_init), dtype=torch.float32))
@staticmethod
def _normal_cdf(z: torch.Tensor) -> torch.Tensor:
# Stable normal CDF via erf.
z = torch.clamp(z, -12.0, 12.0)
return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
@staticmethod
def _normal_sf(z: torch.Tensor) -> torch.Tensor:
# Stable normal survival function via erfc.
z = torch.clamp(z, -12.0, 12.0)
return 0.5 * torch.erfc(z / math.sqrt(2.0))
def forward(
self,
logits: torch.Tensor,
target_events: torch.Tensor,
dt: torch.Tensor,
reduction: str = "mean",
) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, Any]]:
if logits.ndim != 2:
raise ValueError(
f"logits must be 2D with shape (B, 1+J*R); got {tuple(logits.shape)}")
if target_events.ndim != 1 or dt.ndim != 1:
raise ValueError("target_events and dt must be 1D tensors")
if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
raise ValueError(
"Batch size mismatch among logits, target_events, dt")
if reduction not in {"mean", "sum", "none"}:
raise ValueError("reduction must be one of {'mean','sum','none'}")
device = logits.device
dtype = logits.dtype
bin_edges = self.bin_edges.to(device=device, dtype=dtype)
centers = self.centers.to(device=device, dtype=dtype)
bsz = logits.shape[0]
r = int(centers.numel())
jr = int(logits.shape[1] - 1)
if jr <= 0:
raise ValueError(
"logits.shape[1] must be >= 2 (w0 + at least one (j,r) weight)")
if jr % r != 0:
raise ValueError(
f"(logits.shape[1]-1) must be divisible by R={r}; got {jr}")
j = jr // r
# 1) Stable global weights (includes w0).
w_all = torch.softmax(logits, dim=-1) # (B, 1+J*R)
w0 = w_all[:, 0]
w = w_all[:, 1:].view(bsz, j, r)
# 2) Determine event bin index.
k = int(bin_edges.numel() - 1)
if k < 1:
raise ValueError("bin_edges must define at least one bin")
# v2: dt is always continuous time (float), map to bin via searchsorted.
dt_f = dt.to(device=device, dtype=dtype)
bin_idx = torch.searchsorted(bin_edges, dt_f, right=True) - 1
bin_idx = torch.clamp(bin_idx, 0, k - 1).to(torch.long)
left = bin_edges[bin_idx]
right = bin_edges[bin_idx + 1]
# 3) Stable log(t) clamp.
if float(self.bin_edges[1]) > 0.0:
t_min = float(self.bin_edges[1]) * 1e-6
else:
t_min = 1e-12
t_min_t = torch.tensor(t_min, device=device, dtype=dtype)
left_is_zero = left <= 0
# For log() we still need a positive clamp, but we will treat CDF(left)=0 exactly
# when left<=0 (instead of approximating via t_min).
left_clamped = torch.clamp(left, min=t_min_t)
log_left = torch.log(left_clamped)
is_inf = torch.isinf(right)
# right might be +inf for last bin; avoid log(+inf) by substituting a safe finite value.
right_safe = torch.where(is_inf, left_clamped,
torch.clamp(right, min=t_min_t))
log_right = torch.log(right_safe)
sigma = torch.clamp(self.log_sigma.to(
device=device, dtype=dtype).exp(), self.bandwidth_min, self.bandwidth_max)
z_left = (log_left.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
z_right = (log_right.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
z_left = torch.clamp(z_left, -12.0, 12.0)
z_right = torch.clamp(z_right, -12.0, 12.0)
cdf_left = self._normal_cdf(z_left)
# Treat the first-bin left boundary exactly as 0 in CDF.
if left_is_zero.any():
cdf_left = torch.where(
left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left)
cdf_right = self._normal_cdf(z_right)
delta_finite = cdf_right - cdf_left
delta_last = self._normal_sf(z_left)
# If left<=0, SF(left)=1 exactly.
if left_is_zero.any():
delta_last = torch.where(
left_is_zero.unsqueeze(-1), torch.ones_like(delta_last), delta_last)
delta_basis = torch.where(
is_inf.unsqueeze(-1), delta_last, delta_finite)
delta_basis = torch.clamp(delta_basis, min=0.0)
# 4) Gather per-sample cause weights and compute event mass.
cause = target_events.to(device=device, dtype=torch.long)
if (cause < 0).any() or (cause >= j).any():
raise ValueError(f"target_events must be in [0, J-1] where J={j}")
b_idx = torch.arange(bsz, device=device)
w_cause = w[b_idx, cause, :] # (B, R)
m = (w_cause * delta_basis).sum(dim=-1)
m = torch.clamp(m, min=self.eps)
nll_vec = -torch.log(m)
if reduction == "mean":
nll: torch.Tensor = nll_vec.mean()
elif reduction == "sum":
nll = nll_vec.sum()
else:
nll = nll_vec
sigma_penalty = torch.zeros((), device=device, dtype=dtype)
if self.lambda_sigma_reg > 0.0:
target = self.bandwidth_init if self.sigma_reg_target is None else self.sigma_reg_target
sigma_penalty = (self.log_sigma.to(
device=device, dtype=dtype) - math.log(float(target))) ** 2
reg = sigma_penalty * float(self.lambda_sigma_reg)
if not self.return_dict:
return nll, reg
return {
"nll": nll,
"reg": reg,
"nll_vec": nll_vec,
"sigma": sigma.detach(),
"avg_w0": w0.mean().detach(),
"min_delta_basis": delta_basis.min().detach(),
"mean_m": m.mean().detach(),
"sigma_penalty": sigma_penalty.detach(),
"bin_idx": bin_idx.detach(),
}
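A minimal training-side sketch of the forward interface above, assuming J=3 causes, R=4 hypothetical centers, and the default bin_edges; the random batch is purely illustrative:

import math
import torch
criterion = LogNormalBasisHazardLoss(
    bin_edges=[0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')],
    centers=[math.log(0.1), math.log(0.5), math.log(2.0), math.log(8.0)],
    bandwidth_init=0.7,
    lambda_sigma_reg=1e-4,
)
logits = torch.randn(16, 1 + 3 * 4)            # (B, 1 + J*R)
target_events = torch.randint(0, 3, (16,))     # observed cause index in [0, J-1]
dt = torch.rand(16) * 30.0                     # continuous time-to-event in years
nll, reg = criterion(logits, target_events, dt, reduction="mean")
loss = nll + reg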

View File

@@ -2,7 +2,6 @@ import numpy as np
from typing import Optional, List
from backbones import Block
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
-import torch.nn.functional as F
import torch.nn as nn
import torch

View File

@@ -4,7 +4,8 @@ import numpy as np # Numerical operations
# CSV mapping field IDs to human-readable names
field_map_file = "field_ids_enriched.csv"
-field_dict = {} # Map original field ID -> new column name
+# Map original field ID -> new column name
+field_dict = {}
tabular_fields = [] # List of tabular feature column names
with open(field_map_file, "r", encoding="utf-8") as f: # Open the field mapping file
next(f) # skip header line

View File

@@ -1,4 +1,4 @@
-from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisHazardLoss, get_valid_pairs_and_dt
from model import DelphiFork, SapDelphi, SimpleHead
from dataset import HealthDataset, health_collate_fn
from tqdm import tqdm
@@ -13,16 +13,16 @@ import os
import time
import argparse
import math
-import sys
from dataclasses import asdict, dataclass, field
-from typing import Literal, Sequence
+from typing import Literal, Optional, Sequence
@dataclass
class TrainConfig:
# Model Parameters
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
+loss_type: Literal['exponential', 'discrete_time_cif',
+'lognormal_basis_hazard'] = 'exponential'
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
full_cov: bool = False
n_embd: int = 120
@@ -34,6 +34,15 @@ class TrainConfig:
default_factory=lambda: [0.0, 0.24, 0.72,
1.61, 3.84, 10.0, 31.0, float('inf')]
)
# LogNormalBasisHazardLoss specific
lognormal_centers: Optional[Sequence[float]] = field(
default_factory=list) # mu_r in log-time
loss_eps: float = 1e-8
bandwidth_init: float = 0.7
bandwidth_min: float = 1e-3
bandwidth_max: float = 10.0
lambda_sigma_reg: float = 1e-4
sigma_reg_target: Optional[float] = None
rank: int = 16
# SapDelphi specific
pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
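The evaluation side reads these fields back from the saved training config (train_config.json, per the error message above); a hypothetical fragment, shown here as a Python dict with illustrative values, might look like:

cfg_fragment = {
    "loss_type": "lognormal_basis_hazard",
    "lognormal_centers": [-2.0, -0.5, 1.0, 2.5],   # mu_r in log-time (illustrative)
    "loss_eps": 1e-8,
    "bandwidth_init": 0.7,
    "bandwidth_min": 1e-3,
    "bandwidth_max": 10.0,
    "lambda_sigma_reg": 1e-4,
    "sigma_reg_target": None,
}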
@@ -64,7 +73,7 @@ def parse_args() -> TrainConfig:
parser.add_argument(
"--loss_type",
type=str,
-choices=['exponential', 'discrete_time_cif'],
+choices=['exponential', 'discrete_time_cif', 'lognormal_basis_hazard'],
default='exponential',
help="Type of loss function to use.")
parser.add_argument("--age_encoder", type=str, choices=[
@@ -79,6 +88,24 @@ def parse_args() -> TrainConfig:
help="Dropout probability.") help="Dropout probability.")
parser.add_argument("--lambda_reg", type=float, parser.add_argument("--lambda_reg", type=float,
default=1e-4, help="Regularization weight.") default=1e-4, help="Regularization weight.")
parser.add_argument(
"--lognormal_centers",
type=float,
nargs='*',
default=None,
help="LogNormalBasisHazardLoss centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
parser.add_argument("--loss_eps", type=float, default=1e-8,
help="Epsilon for LogNormalBasisHazardLoss log clamp.")
parser.add_argument("--bandwidth_init", type=float, default=0.7,
help="Initial sigma for LogNormalBasisHazardLoss.")
parser.add_argument("--bandwidth_min", type=float, default=1e-3,
help="Minimum sigma clamp for LogNormalBasisHazardLoss.")
parser.add_argument("--bandwidth_max", type=float, default=10.0,
help="Maximum sigma clamp for LogNormalBasisHazardLoss.")
parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
help="Sigma regularization strength for LogNormalBasisHazardLoss.")
parser.add_argument("--sigma_reg_target", type=float, default=None,
help="Optional sigma target for regularization (otherwise uses bandwidth_init).")
parser.add_argument("--rank", type=int, default=16, parser.add_argument("--rank", type=int, default=16,
help="Rank for low-rank parameterization (if applicable).") help="Rank for low-rank parameterization (if applicable).")
parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy", parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
@@ -118,7 +145,36 @@ def parse_args() -> TrainConfig:
parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu', parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
help="Device to use for training.") help="Device to use for training.")
args = parser.parse_args() args = parser.parse_args()
return TrainConfig(**vars(args))
cfg = TrainConfig(**vars(args))
# If lognormal_centers are not provided, derive a reasonable default from bin_edges.
if cfg.lognormal_centers is None or len(cfg.lognormal_centers) == 0:
edges = [float(x) for x in cfg.bin_edges]
finite = [e for e in edges if math.isfinite(e)]
if len(finite) < 2:
raise ValueError(
"bin_edges must contain at least two finite edges to derive lognormal_centers")
e1 = finite[1]
t_min = (e1 * 1e-3) if e1 > 0 else 1e-12
# Build one center per bin (including the +inf last bin if present).
centers: list[float] = []
for i in range(1, len(edges)):
left = float(edges[i - 1])
right = float(edges[i])
if i == 1 and left <= 0.0:
left_pos = t_min
else:
left_pos = max(left, t_min)
if math.isinf(right):
mid = max(left_pos * 2.0, left_pos + 1e-6)
else:
right_pos = max(right, t_min)
mid = math.sqrt(left_pos * right_pos)
centers.append(math.log(max(mid, t_min)))
cfg.lognormal_centers = centers
return cfg
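As a worked sketch of this fallback (assuming the default bin_edges; the loop below mirrors the rule above and the printed numbers are approximate), the derived centers are the log geometric midpoints of each bin:

import math
edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]
t_min = edges[1] * 1e-3                         # 2.4e-4
centers = []
for i in range(1, len(edges)):
    left = t_min if edges[i - 1] <= 0.0 else max(edges[i - 1], t_min)
    if math.isinf(edges[i]):
        mid = max(left * 2.0, left + 1e-6)      # open-ended last bin
    else:
        mid = math.sqrt(left * max(edges[i], t_min))
    centers.append(math.log(max(mid, t_min)))
print([round(c, 2) for c in centers])
# -> approximately [-4.88, -0.88, 0.07, 0.91, 1.82, 2.87, 4.13]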
def get_num_params(model: nn.Module) -> int:
@@ -205,6 +261,23 @@ class Trainer:
).to(self.device)
# logits shape (M, K+1, n_bins+1)
out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
elif cfg.loss_type == "lognormal_basis_hazard":
r = len(cfg.lognormal_centers)
if r <= 0:
raise ValueError(
"lognormal_centers must be non-empty for lognormal_basis_hazard")
self.criterion = LogNormalBasisHazardLoss(
bin_edges=cfg.bin_edges,
centers=cfg.lognormal_centers,
eps=cfg.loss_eps,
bandwidth_init=cfg.bandwidth_init,
bandwidth_min=cfg.bandwidth_min,
bandwidth_max=cfg.bandwidth_max,
lambda_sigma_reg=cfg.lambda_sigma_reg,
sigma_reg_target=cfg.sigma_reg_target,
).to(self.device)
# logits shape (M, 1 + J*R)
out_dims = [1 + dataset.n_disease * r]
else:
raise ValueError(f"Unsupported loss type: {cfg.loss_type}")