Add LogNormalBasisHazardLoss implementation and update training configuration
This commit is contained in:
@@ -8,7 +8,7 @@ import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
|
||||
class ModelSpec:
|
||||
name: str
|
||||
model_type: str # delphi_fork | sap_delphi
|
||||
loss_type: str # exponential | discrete_time_cif
|
||||
loss_type: str # exponential | discrete_time_cif | lognormal_basis_hazard
|
||||
full_cov: bool
|
||||
checkpoint_path: str
|
||||
|
||||
@@ -420,6 +420,72 @@ def cifs_from_discrete_time_logits(
|
||||
return cif, survival
|
||||
|
||||
|
||||
def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
|
||||
z = torch.clamp(z, -12.0, 12.0)
|
||||
return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def cifs_from_lognormal_basis_logits(
|
||||
logits: torch.Tensor,
|
||||
centers: Sequence[float],
|
||||
sigma: torch.Tensor,
|
||||
taus: Sequence[float],
|
||||
*,
|
||||
bin_edges: Optional[Sequence[float]] = None,
|
||||
return_survival: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Convert LogNormalBasisHazardLoss logits -> CIFs at taus.
|
||||
|
||||
logits: (B, 1 + K*R) where K is number of diseases (causes) and R is number of basis functions.
|
||||
centers: length R, in log-time.
|
||||
sigma: scalar tensor (already clamped) in log-time units.
|
||||
taus: horizons in the same units as training bin_edges (years).
|
||||
"""
|
||||
if logits.ndim != 2:
|
||||
raise ValueError("Expected logits shape (B, 1+K*R)")
|
||||
if sigma.ndim != 0:
|
||||
raise ValueError("sigma must be a scalar tensor")
|
||||
|
||||
device = logits.device
|
||||
dtype = logits.dtype
|
||||
|
||||
centers_t = torch.tensor([float(x)
|
||||
for x in centers], device=device, dtype=dtype)
|
||||
r = int(centers_t.numel())
|
||||
jr = int(logits.shape[1] - 1)
|
||||
if jr <= 0 or (jr % r) != 0:
|
||||
raise ValueError("logits.shape[1]-1 must be divisible by R")
|
||||
k = jr // r
|
||||
|
||||
# Stable t_min clamp (aligns with training loss rule).
|
||||
t_min = 1e-12
|
||||
if bin_edges is not None:
|
||||
edges = [float(x) for x in bin_edges]
|
||||
if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
|
||||
t_min = edges[1] * 1e-6
|
||||
t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
|
||||
|
||||
taus_t = torch.tensor([float(x) for x in taus], device=device, dtype=dtype)
|
||||
taus_t = torch.clamp(taus_t, min=t_min_t)
|
||||
log_tau = torch.log(taus_t) # (H,)
|
||||
|
||||
# (H,R)
|
||||
z = (log_tau.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma
|
||||
cdf = _normal_cdf_stable(z)
|
||||
|
||||
w_all = torch.softmax(logits, dim=-1)
|
||||
w = w_all[:, 1:].view(logits.size(0), k, r) # (B,K,R)
|
||||
|
||||
cif = torch.einsum("bkr,hr->bkh", w, cdf) # (B,K,H)
|
||||
|
||||
if not return_survival:
|
||||
return cif
|
||||
|
||||
survival = 1.0 - cif.sum(dim=1) # (B,H)
|
||||
survival = torch.clamp(survival, min=0.0, max=1.0)
|
||||
return cif, survival
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CIF integrity checks
|
||||
# ============================================================
|
||||
@@ -1192,15 +1258,31 @@ def instantiate_model_and_head(
|
||||
dataset: HealthDataset,
|
||||
device: str,
|
||||
checkpoint_path: str = "",
|
||||
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
|
||||
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]:
|
||||
model_type = str(cfg["model_type"])
|
||||
loss_type = str(cfg["loss_type"])
|
||||
|
||||
loss_params: Dict[str, Any] = {}
|
||||
|
||||
if loss_type == "exponential":
|
||||
out_dims = [dataset.n_disease]
|
||||
elif loss_type == "discrete_time_cif":
|
||||
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
|
||||
out_dims = [dataset.n_disease + 1, len(bin_edges)]
|
||||
elif loss_type == "lognormal_basis_hazard":
|
||||
centers = cfg.get("lognormal_centers", None)
|
||||
if centers is None:
|
||||
centers = cfg.get("centers", None)
|
||||
if not isinstance(centers, list) or len(centers) == 0:
|
||||
raise ValueError(
|
||||
"lognormal_basis_hazard requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
|
||||
)
|
||||
out_dims = [1 + dataset.n_disease * len(centers)]
|
||||
loss_params["centers"] = centers
|
||||
loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
|
||||
loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
|
||||
loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
|
||||
loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
|
||||
|
||||
@@ -1249,7 +1331,7 @@ def instantiate_model_and_head(
|
||||
|
||||
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
|
||||
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
|
||||
return backbone, head, loss_type, bin_edges
|
||||
return backbone, head, loss_type, bin_edges, loss_params
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -1258,6 +1340,7 @@ def predict_cifs_for_model(
|
||||
head: torch.nn.Module,
|
||||
loss_type: str,
|
||||
bin_edges: Sequence[float],
|
||||
loss_params: Dict[str, Any],
|
||||
loader: DataLoader,
|
||||
device: str,
|
||||
offset_years: float,
|
||||
@@ -1316,6 +1399,20 @@ def predict_cifs_for_model(
|
||||
cif_full, survival = cifs_from_discrete_time_logits(
|
||||
# (B,K,H), (B,H)
|
||||
logits, bin_edges, eval_horizons, return_survival=True)
|
||||
elif loss_type == "lognormal_basis_hazard":
|
||||
centers = loss_params.get("centers", None)
|
||||
sigma = loss_params.get("sigma", None)
|
||||
if centers is None or sigma is None:
|
||||
raise ValueError(
|
||||
"lognormal_basis_hazard requires loss_params['centers'] and loss_params['sigma']")
|
||||
cif_full, survival = cifs_from_lognormal_basis_logits(
|
||||
logits,
|
||||
centers=centers,
|
||||
sigma=sigma,
|
||||
taus=eval_horizons,
|
||||
bin_edges=bin_edges,
|
||||
return_survival=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss_type: {loss_type}")
|
||||
|
||||
@@ -1823,12 +1920,27 @@ def main() -> int:
|
||||
collate_fn=health_collate_fn,
|
||||
)
|
||||
|
||||
backbone, head, loss_type, bin_edges = instantiate_model_and_head(
|
||||
cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
|
||||
ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
|
||||
|
||||
backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head(
|
||||
cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
|
||||
backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
|
||||
head.load_state_dict(ckpt["head_state_dict"], strict=True)
|
||||
|
||||
if loss_type == "lognormal_basis_hazard":
|
||||
crit_state = ckpt.get("criterion_state_dict", {})
|
||||
log_sigma = crit_state.get("log_sigma", None)
|
||||
if isinstance(log_sigma, torch.Tensor):
|
||||
log_sigma_t = log_sigma.to(device=args.device)
|
||||
sigma = torch.exp(log_sigma_t)
|
||||
else:
|
||||
sigma = torch.tensor(float(loss_params.get(
|
||||
"bandwidth_init", 0.7)), device=args.device)
|
||||
bmin = float(loss_params.get("bandwidth_min", 1e-3))
|
||||
bmax = float(loss_params.get("bandwidth_max", 10.0))
|
||||
sigma = torch.clamp(sigma, min=bmin, max=bmax)
|
||||
loss_params["sigma"] = sigma
|
||||
|
||||
(
|
||||
cif_full,
|
||||
survival,
|
||||
@@ -1839,6 +1951,7 @@ def main() -> int:
|
||||
head,
|
||||
loss_type,
|
||||
bin_edges,
|
||||
loss_params,
|
||||
loader,
|
||||
args.device,
|
||||
args.offset_years,
|
||||
|
||||
232
losses.py
232
losses.py
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import Optional, Sequence, Tuple
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -258,3 +258,233 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
|
||||
F.nll_loss(logp_at_event_bin, target_events, reduction="mean")
|
||||
|
||||
return nll, reg
|
||||
|
||||
|
||||
class LogNormalBasisHazardLoss(nn.Module):
|
||||
"""Event-only competing risks loss using lognormal basis (Gaussian on log-time).
|
||||
|
||||
This loss models cause-specific CIF as a mixture of lognormal basis CDFs:
|
||||
|
||||
F_j(t) = sum_r w_{j,r} * Phi((log t - mu_r) / sigma)
|
||||
|
||||
Training uses *bin probability mass* (Delta CIF / interval mass). There is
|
||||
**no censoring**: every sample is an observed event with a valid cause.
|
||||
|
||||
Logits interface:
|
||||
logits: (B, 1 + J*R)
|
||||
logits[:, 0] -> w0 (survival mass / never-event)
|
||||
logits[:, 1:] -> flattened (j,r) in row-major order: j then r
|
||||
index = 1 + j*R + r
|
||||
|
||||
Forward interface (must match):
|
||||
forward(logits, target_events, dt, reduction)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bin_edges: Sequence[float],
|
||||
centers: Sequence[float],
|
||||
*,
|
||||
eps: float = 1e-8,
|
||||
bandwidth_init: float = 0.5,
|
||||
bandwidth_min: float = 1e-3,
|
||||
bandwidth_max: float = 10.0,
|
||||
lambda_sigma_reg: float = 0.0,
|
||||
sigma_reg_target: Optional[float] = None,
|
||||
return_dict: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if len(bin_edges) < 2:
|
||||
raise ValueError("bin_edges must have length >= 2")
|
||||
# allow last edge to be +inf
|
||||
for i in range(1, len(bin_edges)):
|
||||
prev = float(bin_edges[i - 1])
|
||||
cur = float(bin_edges[i])
|
||||
if math.isinf(prev):
|
||||
raise ValueError(
|
||||
"bin_edges cannot have +inf except possibly as the last edge")
|
||||
if i == len(bin_edges) - 1 and math.isinf(cur):
|
||||
if not (cur > prev):
|
||||
raise ValueError("bin_edges must be strictly increasing")
|
||||
else:
|
||||
if not (cur > prev):
|
||||
raise ValueError("bin_edges must be strictly increasing")
|
||||
if float(bin_edges[0]) < 0.0:
|
||||
raise ValueError("bin_edges[0] must be >= 0")
|
||||
|
||||
if len(centers) < 1:
|
||||
raise ValueError("centers must have length >= 1")
|
||||
|
||||
self.eps = float(eps)
|
||||
self.bandwidth_min = float(bandwidth_min)
|
||||
self.bandwidth_max = float(bandwidth_max)
|
||||
self.lambda_sigma_reg = float(lambda_sigma_reg)
|
||||
self.sigma_reg_target = None if sigma_reg_target is None else float(
|
||||
sigma_reg_target)
|
||||
self.bandwidth_init = float(bandwidth_init)
|
||||
self.return_dict = bool(return_dict)
|
||||
|
||||
self.register_buffer(
|
||||
"bin_edges",
|
||||
torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"centers",
|
||||
torch.tensor([float(x) for x in centers], dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
if self.bandwidth_init <= 0:
|
||||
raise ValueError("bandwidth_init must be > 0")
|
||||
self.log_sigma = nn.Parameter(torch.tensor(
|
||||
math.log(self.bandwidth_init), dtype=torch.float32))
|
||||
|
||||
@staticmethod
|
||||
def _normal_cdf(z: torch.Tensor) -> torch.Tensor:
|
||||
# Stable normal CDF via erf.
|
||||
z = torch.clamp(z, -12.0, 12.0)
|
||||
return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
|
||||
|
||||
@staticmethod
|
||||
def _normal_sf(z: torch.Tensor) -> torch.Tensor:
|
||||
# Stable normal survival function via erfc.
|
||||
z = torch.clamp(z, -12.0, 12.0)
|
||||
return 0.5 * torch.erfc(z / math.sqrt(2.0))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
target_events: torch.Tensor,
|
||||
dt: torch.Tensor,
|
||||
reduction: str = "mean",
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, Any]]:
|
||||
if logits.ndim != 2:
|
||||
raise ValueError(
|
||||
f"logits must be 2D with shape (B, 1+J*R); got {tuple(logits.shape)}")
|
||||
if target_events.ndim != 1 or dt.ndim != 1:
|
||||
raise ValueError("target_events and dt must be 1D tensors")
|
||||
if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
|
||||
raise ValueError(
|
||||
"Batch size mismatch among logits, target_events, dt")
|
||||
if reduction not in {"mean", "sum", "none"}:
|
||||
raise ValueError("reduction must be one of {'mean','sum','none'}")
|
||||
|
||||
device = logits.device
|
||||
dtype = logits.dtype
|
||||
|
||||
bin_edges = self.bin_edges.to(device=device, dtype=dtype)
|
||||
centers = self.centers.to(device=device, dtype=dtype)
|
||||
bsz = logits.shape[0]
|
||||
r = int(centers.numel())
|
||||
jr = int(logits.shape[1] - 1)
|
||||
if jr <= 0:
|
||||
raise ValueError(
|
||||
"logits.shape[1] must be >= 2 (w0 + at least one (j,r) weight)")
|
||||
if jr % r != 0:
|
||||
raise ValueError(
|
||||
f"(logits.shape[1]-1) must be divisible by R={r}; got {jr}")
|
||||
j = jr // r
|
||||
|
||||
# 1) Stable global weights (includes w0).
|
||||
w_all = torch.softmax(logits, dim=-1) # (B, 1+J*R)
|
||||
w0 = w_all[:, 0]
|
||||
w = w_all[:, 1:].view(bsz, j, r)
|
||||
|
||||
# 2) Determine event bin index.
|
||||
k = int(bin_edges.numel() - 1)
|
||||
if k < 1:
|
||||
raise ValueError("bin_edges must define at least one bin")
|
||||
|
||||
# v2: dt is always continuous time (float), map to bin via searchsorted.
|
||||
dt_f = dt.to(device=device, dtype=dtype)
|
||||
bin_idx = torch.searchsorted(bin_edges, dt_f, right=True) - 1
|
||||
bin_idx = torch.clamp(bin_idx, 0, k - 1).to(torch.long)
|
||||
|
||||
left = bin_edges[bin_idx]
|
||||
right = bin_edges[bin_idx + 1]
|
||||
|
||||
# 3) Stable log(t) clamp.
|
||||
if float(self.bin_edges[1]) > 0.0:
|
||||
t_min = float(self.bin_edges[1]) * 1e-6
|
||||
else:
|
||||
t_min = 1e-12
|
||||
t_min_t = torch.tensor(t_min, device=device, dtype=dtype)
|
||||
|
||||
left_is_zero = left <= 0
|
||||
|
||||
# For log() we still need a positive clamp, but we will treat CDF(left)=0 exactly
|
||||
# when left<=0 (instead of approximating via t_min).
|
||||
left_clamped = torch.clamp(left, min=t_min_t)
|
||||
log_left = torch.log(left_clamped)
|
||||
is_inf = torch.isinf(right)
|
||||
# right might be +inf for last bin; avoid log(+inf) by substituting a safe finite value.
|
||||
right_safe = torch.where(is_inf, left_clamped,
|
||||
torch.clamp(right, min=t_min_t))
|
||||
log_right = torch.log(right_safe)
|
||||
|
||||
sigma = torch.clamp(self.log_sigma.to(
|
||||
device=device, dtype=dtype).exp(), self.bandwidth_min, self.bandwidth_max)
|
||||
|
||||
z_left = (log_left.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
|
||||
z_right = (log_right.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
|
||||
z_left = torch.clamp(z_left, -12.0, 12.0)
|
||||
z_right = torch.clamp(z_right, -12.0, 12.0)
|
||||
|
||||
cdf_left = self._normal_cdf(z_left)
|
||||
# Treat the first-bin left boundary exactly as 0 in CDF.
|
||||
if left_is_zero.any():
|
||||
cdf_left = torch.where(
|
||||
left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left)
|
||||
cdf_right = self._normal_cdf(z_right)
|
||||
delta_finite = cdf_right - cdf_left
|
||||
delta_last = self._normal_sf(z_left)
|
||||
# If left<=0, SF(left)=1 exactly.
|
||||
if left_is_zero.any():
|
||||
delta_last = torch.where(
|
||||
left_is_zero.unsqueeze(-1), torch.ones_like(delta_last), delta_last)
|
||||
delta_basis = torch.where(
|
||||
is_inf.unsqueeze(-1), delta_last, delta_finite)
|
||||
delta_basis = torch.clamp(delta_basis, min=0.0)
|
||||
|
||||
# 4) Gather per-sample cause weights and compute event mass.
|
||||
cause = target_events.to(device=device, dtype=torch.long)
|
||||
if (cause < 0).any() or (cause >= j).any():
|
||||
raise ValueError(f"target_events must be in [0, J-1] where J={j}")
|
||||
|
||||
b_idx = torch.arange(bsz, device=device)
|
||||
w_cause = w[b_idx, cause, :] # (B, R)
|
||||
|
||||
m = (w_cause * delta_basis).sum(dim=-1)
|
||||
m = torch.clamp(m, min=self.eps)
|
||||
nll_vec = -torch.log(m)
|
||||
|
||||
if reduction == "mean":
|
||||
nll: torch.Tensor = nll_vec.mean()
|
||||
elif reduction == "sum":
|
||||
nll = nll_vec.sum()
|
||||
else:
|
||||
nll = nll_vec
|
||||
|
||||
sigma_penalty = torch.zeros((), device=device, dtype=dtype)
|
||||
if self.lambda_sigma_reg > 0.0:
|
||||
target = self.bandwidth_init if self.sigma_reg_target is None else self.sigma_reg_target
|
||||
sigma_penalty = (self.log_sigma.to(
|
||||
device=device, dtype=dtype) - math.log(float(target))) ** 2
|
||||
reg = sigma_penalty * float(self.lambda_sigma_reg)
|
||||
|
||||
if not self.return_dict:
|
||||
return nll, reg
|
||||
|
||||
return {
|
||||
"nll": nll,
|
||||
"reg": reg,
|
||||
"nll_vec": nll_vec,
|
||||
"sigma": sigma.detach(),
|
||||
"avg_w0": w0.mean().detach(),
|
||||
"min_delta_basis": delta_basis.min().detach(),
|
||||
"mean_m": m.mean().detach(),
|
||||
"sigma_penalty": sigma_penalty.detach(),
|
||||
"bin_idx": bin_idx.detach(),
|
||||
}
|
||||
|
||||
1
model.py
1
model.py
@@ -2,7 +2,6 @@ import numpy as np
|
||||
from typing import Optional, List
|
||||
from backbones import Block
|
||||
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@ import numpy as np # Numerical operations
|
||||
|
||||
# CSV mapping field IDs to human-readable names
|
||||
field_map_file = "field_ids_enriched.csv"
|
||||
field_dict = {} # Map original field ID -> new column name
|
||||
# Map original field ID -> new column name
|
||||
field_dict = {}
|
||||
tabular_fields = [] # List of tabular feature column names
|
||||
with open(field_map_file, "r", encoding="utf-8") as f: # Open the field mapping file
|
||||
next(f) # skip header line
|
||||
|
||||
85
train.py
85
train.py
@@ -1,4 +1,4 @@
|
||||
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
|
||||
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisHazardLoss, get_valid_pairs_and_dt
|
||||
from model import DelphiFork, SapDelphi, SimpleHead
|
||||
from dataset import HealthDataset, health_collate_fn
|
||||
from tqdm import tqdm
|
||||
@@ -13,16 +13,16 @@ import os
|
||||
import time
|
||||
import argparse
|
||||
import math
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Literal, Sequence
|
||||
from typing import Literal, Optional, Sequence
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
# Model Parameters
|
||||
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
|
||||
loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
|
||||
loss_type: Literal['exponential', 'discrete_time_cif',
|
||||
'lognormal_basis_hazard'] = 'exponential'
|
||||
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
|
||||
full_cov: bool = False
|
||||
n_embd: int = 120
|
||||
@@ -34,6 +34,15 @@ class TrainConfig:
|
||||
default_factory=lambda: [0.0, 0.24, 0.72,
|
||||
1.61, 3.84, 10.0, 31.0, float('inf')]
|
||||
)
|
||||
# LogNormalBasisHazardLoss specific
|
||||
lognormal_centers: Optional[Sequence[float]] = field(
|
||||
default_factory=list) # mu_r in log-time
|
||||
loss_eps: float = 1e-8
|
||||
bandwidth_init: float = 0.7
|
||||
bandwidth_min: float = 1e-3
|
||||
bandwidth_max: float = 10.0
|
||||
lambda_sigma_reg: float = 1e-4
|
||||
sigma_reg_target: Optional[float] = None
|
||||
rank: int = 16
|
||||
# SapDelphi specific
|
||||
pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
|
||||
@@ -64,7 +73,7 @@ def parse_args() -> TrainConfig:
|
||||
parser.add_argument(
|
||||
"--loss_type",
|
||||
type=str,
|
||||
choices=['exponential', 'discrete_time_cif'],
|
||||
choices=['exponential', 'discrete_time_cif', 'lognormal_basis_hazard'],
|
||||
default='exponential',
|
||||
help="Type of loss function to use.")
|
||||
parser.add_argument("--age_encoder", type=str, choices=[
|
||||
@@ -79,6 +88,24 @@ def parse_args() -> TrainConfig:
|
||||
help="Dropout probability.")
|
||||
parser.add_argument("--lambda_reg", type=float,
|
||||
default=1e-4, help="Regularization weight.")
|
||||
parser.add_argument(
|
||||
"--lognormal_centers",
|
||||
type=float,
|
||||
nargs='*',
|
||||
default=None,
|
||||
help="LogNormalBasisHazardLoss centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
|
||||
parser.add_argument("--loss_eps", type=float, default=1e-8,
|
||||
help="Epsilon for LogNormalBasisHazardLoss log clamp.")
|
||||
parser.add_argument("--bandwidth_init", type=float, default=0.7,
|
||||
help="Initial sigma for LogNormalBasisHazardLoss.")
|
||||
parser.add_argument("--bandwidth_min", type=float, default=1e-3,
|
||||
help="Minimum sigma clamp for LogNormalBasisHazardLoss.")
|
||||
parser.add_argument("--bandwidth_max", type=float, default=10.0,
|
||||
help="Maximum sigma clamp for LogNormalBasisHazardLoss.")
|
||||
parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
|
||||
help="Sigma regularization strength for LogNormalBasisHazardLoss.")
|
||||
parser.add_argument("--sigma_reg_target", type=float, default=None,
|
||||
help="Optional sigma target for regularization (otherwise uses bandwidth_init).")
|
||||
parser.add_argument("--rank", type=int, default=16,
|
||||
help="Rank for low-rank parameterization (if applicable).")
|
||||
parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
|
||||
@@ -118,7 +145,36 @@ def parse_args() -> TrainConfig:
|
||||
parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
help="Device to use for training.")
|
||||
args = parser.parse_args()
|
||||
return TrainConfig(**vars(args))
|
||||
|
||||
cfg = TrainConfig(**vars(args))
|
||||
|
||||
# If lognormal_centers are not provided, derive a reasonable default from bin_edges.
|
||||
if cfg.lognormal_centers is None or len(cfg.lognormal_centers) == 0:
|
||||
edges = [float(x) for x in cfg.bin_edges]
|
||||
finite = [e for e in edges if math.isfinite(e)]
|
||||
if len(finite) < 2:
|
||||
raise ValueError(
|
||||
"bin_edges must contain at least two finite edges to derive lognormal_centers")
|
||||
e1 = finite[1]
|
||||
t_min = (e1 * 1e-3) if e1 > 0 else 1e-12
|
||||
# Build one center per bin (including the +inf last bin if present).
|
||||
centers: list[float] = []
|
||||
for i in range(1, len(edges)):
|
||||
left = float(edges[i - 1])
|
||||
right = float(edges[i])
|
||||
if i == 1 and left <= 0.0:
|
||||
left_pos = t_min
|
||||
else:
|
||||
left_pos = max(left, t_min)
|
||||
if math.isinf(right):
|
||||
mid = max(left_pos * 2.0, left_pos + 1e-6)
|
||||
else:
|
||||
right_pos = max(right, t_min)
|
||||
mid = math.sqrt(left_pos * right_pos)
|
||||
centers.append(math.log(max(mid, t_min)))
|
||||
cfg.lognormal_centers = centers
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def get_num_params(model: nn.Module) -> int:
|
||||
@@ -205,6 +261,23 @@ class Trainer:
|
||||
).to(self.device)
|
||||
# logits shape (M, K+1, n_bins+1)
|
||||
out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
|
||||
elif cfg.loss_type == "lognormal_basis_hazard":
|
||||
r = len(cfg.lognormal_centers)
|
||||
if r <= 0:
|
||||
raise ValueError(
|
||||
"lognormal_centers must be non-empty for lognormal_basis_hazard")
|
||||
self.criterion = LogNormalBasisHazardLoss(
|
||||
bin_edges=cfg.bin_edges,
|
||||
centers=cfg.lognormal_centers,
|
||||
eps=cfg.loss_eps,
|
||||
bandwidth_init=cfg.bandwidth_init,
|
||||
bandwidth_min=cfg.bandwidth_min,
|
||||
bandwidth_max=cfg.bandwidth_max,
|
||||
lambda_sigma_reg=cfg.lambda_sigma_reg,
|
||||
sigma_reg_target=cfg.sigma_reg_target,
|
||||
).to(self.device)
|
||||
# logits shape (M, 1 + J*R)
|
||||
out_dims = [1 + dataset.n_disease * r]
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user