Add LogNormalBasisHazardLoss implementation and update training configuration

This commit is contained in:
2026-01-13 15:59:20 +08:00
parent d8b322cbee
commit 1df02d85d7
5 changed files with 431 additions and 15 deletions

View File

@@ -8,7 +8,7 @@ import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
class ModelSpec:
name: str
model_type: str # delphi_fork | sap_delphi
loss_type: str # exponential | discrete_time_cif
loss_type: str # exponential | discrete_time_cif | lognormal_basis_hazard
full_cov: bool
checkpoint_path: str
@@ -420,6 +420,72 @@ def cifs_from_discrete_time_logits(
return cif, survival
def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
    """Evaluate the standard normal CDF element-wise via the error function.

    The input is first limited to the interval [-12, 12]; the result has the
    same shape as ``z``.
    """
    # NOTE(review): the clamp presumably keeps extreme/infinite z-scores
    # bounded before erf is applied -- confirm against the training loss.
    bounded = z.clamp(min=-12.0, max=12.0)
    erf_term = torch.erf(bounded / math.sqrt(2.0))
    # Phi(z) = (1 + erf(z / sqrt(2))) / 2
    return 0.5 * (1.0 + erf_term)
def cifs_from_lognormal_basis_logits(
    logits: torch.Tensor,
    centers: Sequence[float],
    sigma: torch.Tensor,
    taus: Sequence[float],
    *,
    bin_edges: Optional[Sequence[float]] = None,
    return_survival: bool = False,
) -> "torch.Tensor | Tuple[torch.Tensor, torch.Tensor]":
    """Convert LogNormalBasisHazardLoss logits -> CIFs at taus.

    Args:
        logits: (B, 1 + K*R) where K is number of diseases (causes) and R is
            number of basis functions. Column 0 is dropped after the softmax;
            the remaining columns are read as K groups of R basis weights.
        centers: length R, in log-time. Must not be empty.
        sigma: scalar (0-dim tensor or plain float), already clamped, in
            log-time units. It is moved to the device/dtype of ``logits``.
        taus: horizons in the same units as training bin_edges (years).
        bin_edges: optional training bin edges. When ``bin_edges[1]`` is
            finite and positive, the lower clamp applied to ``taus`` becomes
            ``bin_edges[1] * 1e-6`` instead of ``1e-12``.
        return_survival: when True, also return the survival curve.

    Returns:
        ``cif`` of shape (B, K, H) when ``return_survival`` is False,
        otherwise the tuple ``(cif, survival)`` where ``survival`` has shape
        (B, H) and equals ``1 - cif.sum(dim=1)`` clamped to [0, 1].

    Raises:
        ValueError: if ``logits`` is not 2-D, ``sigma`` is not a scalar,
            ``centers`` is empty, or ``logits.shape[1] - 1`` is not a
            positive multiple of ``len(centers)``.
    """
    if logits.ndim != 2:
        raise ValueError("Expected logits shape (B, 1+K*R)")

    device = logits.device
    dtype = logits.dtype

    # Accept a plain float as well as a tensor, and keep sigma on the same
    # device as the logits so the division below cannot hit a device mismatch.
    sigma_t = torch.as_tensor(sigma, device=device, dtype=dtype)
    if sigma_t.ndim != 0:
        raise ValueError("sigma must be a scalar tensor")

    centers_t = torch.tensor(
        [float(x) for x in centers], device=device, dtype=dtype)
    r = int(centers_t.numel())
    if r == 0:
        # Without this guard the modulo below raises ZeroDivisionError.
        raise ValueError("centers must contain at least one basis center")

    jr = int(logits.shape[1] - 1)
    if jr <= 0 or (jr % r) != 0:
        raise ValueError("logits.shape[1]-1 must be divisible by R")
    k = jr // r

    # Stable t_min clamp (aligns with training loss rule).
    t_min = 1e-12
    if bin_edges is not None:
        edges = [float(x) for x in bin_edges]
        if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
            t_min = edges[1] * 1e-6
    t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)

    taus_t = torch.tensor([float(x) for x in taus], device=device, dtype=dtype)
    # Keep horizons strictly positive so the log below stays finite.
    taus_t = torch.clamp(taus_t, min=t_min_t)
    log_tau = torch.log(taus_t)  # (H,)

    # Standardised log-time distance of every horizon to every center: (H,R)
    z = (log_tau.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_t
    cdf = _normal_cdf_stable(z)

    # Softmax over all 1 + K*R outputs; column 0 is excluded from the CIFs.
    w_all = torch.softmax(logits, dim=-1)
    # reshape (not view): the slice is non-contiguous, reshape never depends
    # on the stride layout.
    w = w_all[:, 1:].reshape(logits.size(0), k, r)  # (B,K,R)

    cif = torch.einsum("bkr,hr->bkh", w, cdf)  # (B,K,H)
    if not return_survival:
        return cif

    survival = 1.0 - cif.sum(dim=1)  # (B,H)
    survival = torch.clamp(survival, min=0.0, max=1.0)
    return cif, survival
# ============================================================
# CIF integrity checks
# ============================================================
@@ -1192,15 +1258,31 @@ def instantiate_model_and_head(
dataset: HealthDataset,
device: str,
checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]:
model_type = str(cfg["model_type"])
loss_type = str(cfg["loss_type"])
loss_params: Dict[str, Any] = {}
if loss_type == "exponential":
out_dims = [dataset.n_disease]
elif loss_type == "discrete_time_cif":
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
out_dims = [dataset.n_disease + 1, len(bin_edges)]
elif loss_type == "lognormal_basis_hazard":
centers = cfg.get("lognormal_centers", None)
if centers is None:
centers = cfg.get("centers", None)
if not isinstance(centers, list) or len(centers) == 0:
raise ValueError(
"lognormal_basis_hazard requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
)
out_dims = [1 + dataset.n_disease * len(centers)]
loss_params["centers"] = centers
loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
else:
raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
@@ -1249,7 +1331,7 @@ def instantiate_model_and_head(
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
return backbone, head, loss_type, bin_edges
return backbone, head, loss_type, bin_edges, loss_params
@torch.no_grad()
@@ -1258,6 +1340,7 @@ def predict_cifs_for_model(
head: torch.nn.Module,
loss_type: str,
bin_edges: Sequence[float],
loss_params: Dict[str, Any],
loader: DataLoader,
device: str,
offset_years: float,
@@ -1316,6 +1399,20 @@ def predict_cifs_for_model(
cif_full, survival = cifs_from_discrete_time_logits(
# (B,K,H), (B,H)
logits, bin_edges, eval_horizons, return_survival=True)
elif loss_type == "lognormal_basis_hazard":
centers = loss_params.get("centers", None)
sigma = loss_params.get("sigma", None)
if centers is None or sigma is None:
raise ValueError(
"lognormal_basis_hazard requires loss_params['centers'] and loss_params['sigma']")
cif_full, survival = cifs_from_lognormal_basis_logits(
logits,
centers=centers,
sigma=sigma,
taus=eval_horizons,
bin_edges=bin_edges,
return_survival=True,
)
else:
raise ValueError(f"Unsupported loss_type: {loss_type}")
@@ -1823,12 +1920,27 @@ def main() -> int:
collate_fn=health_collate_fn,
)
backbone, head, loss_type, bin_edges = instantiate_model_and_head(
cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head(
cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
head.load_state_dict(ckpt["head_state_dict"], strict=True)
if loss_type == "lognormal_basis_hazard":
crit_state = ckpt.get("criterion_state_dict", {})
log_sigma = crit_state.get("log_sigma", None)
if isinstance(log_sigma, torch.Tensor):
log_sigma_t = log_sigma.to(device=args.device)
sigma = torch.exp(log_sigma_t)
else:
sigma = torch.tensor(float(loss_params.get(
"bandwidth_init", 0.7)), device=args.device)
bmin = float(loss_params.get("bandwidth_min", 1e-3))
bmax = float(loss_params.get("bandwidth_max", 10.0))
sigma = torch.clamp(sigma, min=bmin, max=bmax)
loss_params["sigma"] = sigma
(
cif_full,
survival,
@@ -1839,6 +1951,7 @@ def main() -> int:
head,
loss_type,
bin_edges,
loss_params,
loader,
args.device,
args.offset_years,