Add LogNormalBasisHazardLoss implementation and update training configuration
This commit is contained in:
@@ -8,7 +8,7 @@ import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
|
||||
class ModelSpec:
|
||||
name: str
|
||||
model_type: str # delphi_fork | sap_delphi
|
||||
loss_type: str # exponential | discrete_time_cif
|
||||
loss_type: str # exponential | discrete_time_cif | lognormal_basis_hazard
|
||||
full_cov: bool
|
||||
checkpoint_path: str
|
||||
|
||||
@@ -420,6 +420,72 @@ def cifs_from_discrete_time_logits(
|
||||
return cif, survival
|
||||
|
||||
|
||||
def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
|
||||
z = torch.clamp(z, -12.0, 12.0)
|
||||
return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def cifs_from_lognormal_basis_logits(
|
||||
logits: torch.Tensor,
|
||||
centers: Sequence[float],
|
||||
sigma: torch.Tensor,
|
||||
taus: Sequence[float],
|
||||
*,
|
||||
bin_edges: Optional[Sequence[float]] = None,
|
||||
return_survival: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Convert LogNormalBasisHazardLoss logits -> CIFs at taus.
|
||||
|
||||
logits: (B, 1 + K*R) where K is number of diseases (causes) and R is number of basis functions.
|
||||
centers: length R, in log-time.
|
||||
sigma: scalar tensor (already clamped) in log-time units.
|
||||
taus: horizons in the same units as training bin_edges (years).
|
||||
"""
|
||||
if logits.ndim != 2:
|
||||
raise ValueError("Expected logits shape (B, 1+K*R)")
|
||||
if sigma.ndim != 0:
|
||||
raise ValueError("sigma must be a scalar tensor")
|
||||
|
||||
device = logits.device
|
||||
dtype = logits.dtype
|
||||
|
||||
centers_t = torch.tensor([float(x)
|
||||
for x in centers], device=device, dtype=dtype)
|
||||
r = int(centers_t.numel())
|
||||
jr = int(logits.shape[1] - 1)
|
||||
if jr <= 0 or (jr % r) != 0:
|
||||
raise ValueError("logits.shape[1]-1 must be divisible by R")
|
||||
k = jr // r
|
||||
|
||||
# Stable t_min clamp (aligns with training loss rule).
|
||||
t_min = 1e-12
|
||||
if bin_edges is not None:
|
||||
edges = [float(x) for x in bin_edges]
|
||||
if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
|
||||
t_min = edges[1] * 1e-6
|
||||
t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
|
||||
|
||||
taus_t = torch.tensor([float(x) for x in taus], device=device, dtype=dtype)
|
||||
taus_t = torch.clamp(taus_t, min=t_min_t)
|
||||
log_tau = torch.log(taus_t) # (H,)
|
||||
|
||||
# (H,R)
|
||||
z = (log_tau.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma
|
||||
cdf = _normal_cdf_stable(z)
|
||||
|
||||
w_all = torch.softmax(logits, dim=-1)
|
||||
w = w_all[:, 1:].view(logits.size(0), k, r) # (B,K,R)
|
||||
|
||||
cif = torch.einsum("bkr,hr->bkh", w, cdf) # (B,K,H)
|
||||
|
||||
if not return_survival:
|
||||
return cif
|
||||
|
||||
survival = 1.0 - cif.sum(dim=1) # (B,H)
|
||||
survival = torch.clamp(survival, min=0.0, max=1.0)
|
||||
return cif, survival
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CIF integrity checks
|
||||
# ============================================================
|
||||
@@ -1192,15 +1258,31 @@ def instantiate_model_and_head(
|
||||
dataset: HealthDataset,
|
||||
device: str,
|
||||
checkpoint_path: str = "",
|
||||
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
|
||||
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]:
|
||||
model_type = str(cfg["model_type"])
|
||||
loss_type = str(cfg["loss_type"])
|
||||
|
||||
loss_params: Dict[str, Any] = {}
|
||||
|
||||
if loss_type == "exponential":
|
||||
out_dims = [dataset.n_disease]
|
||||
elif loss_type == "discrete_time_cif":
|
||||
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
|
||||
out_dims = [dataset.n_disease + 1, len(bin_edges)]
|
||||
elif loss_type == "lognormal_basis_hazard":
|
||||
centers = cfg.get("lognormal_centers", None)
|
||||
if centers is None:
|
||||
centers = cfg.get("centers", None)
|
||||
if not isinstance(centers, list) or len(centers) == 0:
|
||||
raise ValueError(
|
||||
"lognormal_basis_hazard requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
|
||||
)
|
||||
out_dims = [1 + dataset.n_disease * len(centers)]
|
||||
loss_params["centers"] = centers
|
||||
loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
|
||||
loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
|
||||
loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
|
||||
loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
|
||||
|
||||
@@ -1249,7 +1331,7 @@ def instantiate_model_and_head(
|
||||
|
||||
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
|
||||
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
|
||||
return backbone, head, loss_type, bin_edges
|
||||
return backbone, head, loss_type, bin_edges, loss_params
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -1258,6 +1340,7 @@ def predict_cifs_for_model(
|
||||
head: torch.nn.Module,
|
||||
loss_type: str,
|
||||
bin_edges: Sequence[float],
|
||||
loss_params: Dict[str, Any],
|
||||
loader: DataLoader,
|
||||
device: str,
|
||||
offset_years: float,
|
||||
@@ -1316,6 +1399,20 @@ def predict_cifs_for_model(
|
||||
cif_full, survival = cifs_from_discrete_time_logits(
|
||||
# (B,K,H), (B,H)
|
||||
logits, bin_edges, eval_horizons, return_survival=True)
|
||||
elif loss_type == "lognormal_basis_hazard":
|
||||
centers = loss_params.get("centers", None)
|
||||
sigma = loss_params.get("sigma", None)
|
||||
if centers is None or sigma is None:
|
||||
raise ValueError(
|
||||
"lognormal_basis_hazard requires loss_params['centers'] and loss_params['sigma']")
|
||||
cif_full, survival = cifs_from_lognormal_basis_logits(
|
||||
logits,
|
||||
centers=centers,
|
||||
sigma=sigma,
|
||||
taus=eval_horizons,
|
||||
bin_edges=bin_edges,
|
||||
return_survival=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss_type: {loss_type}")
|
||||
|
||||
@@ -1823,12 +1920,27 @@ def main() -> int:
|
||||
collate_fn=health_collate_fn,
|
||||
)
|
||||
|
||||
backbone, head, loss_type, bin_edges = instantiate_model_and_head(
|
||||
cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
|
||||
ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
|
||||
|
||||
backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head(
|
||||
cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
|
||||
backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
|
||||
head.load_state_dict(ckpt["head_state_dict"], strict=True)
|
||||
|
||||
if loss_type == "lognormal_basis_hazard":
|
||||
crit_state = ckpt.get("criterion_state_dict", {})
|
||||
log_sigma = crit_state.get("log_sigma", None)
|
||||
if isinstance(log_sigma, torch.Tensor):
|
||||
log_sigma_t = log_sigma.to(device=args.device)
|
||||
sigma = torch.exp(log_sigma_t)
|
||||
else:
|
||||
sigma = torch.tensor(float(loss_params.get(
|
||||
"bandwidth_init", 0.7)), device=args.device)
|
||||
bmin = float(loss_params.get("bandwidth_min", 1e-3))
|
||||
bmax = float(loss_params.get("bandwidth_max", 10.0))
|
||||
sigma = torch.clamp(sigma, min=bmin, max=bmax)
|
||||
loss_params["sigma"] = sigma
|
||||
|
||||
(
|
||||
cif_full,
|
||||
survival,
|
||||
@@ -1839,6 +1951,7 @@ def main() -> int:
|
||||
head,
|
||||
loss_type,
|
||||
bin_edges,
|
||||
loss_params,
|
||||
loader,
|
||||
args.device,
|
||||
args.offset_years,
|
||||
|
||||
Reference in New Issue
Block a user