diff --git a/evaluate_models.py b/evaluate_models.py
index cf86d2e..0bfa9b3 100644
--- a/evaluate_models.py
+++ b/evaluate_models.py
@@ -8,7 +8,7 @@
 import sys
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
 class ModelSpec:
     name: str
     model_type: str  # delphi_fork | sap_delphi
-    loss_type: str  # exponential | discrete_time_cif
+    loss_type: str  # exponential | discrete_time_cif | lognormal_basis_hazard
     full_cov: bool
     checkpoint_path: str
@@ -420,6 +420,72 @@ def cifs_from_discrete_time_logits(
     return cif, survival
 
 
+def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
+    z = torch.clamp(z, -12.0, 12.0)
+    return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
+
+
+def cifs_from_lognormal_basis_logits(
+    logits: torch.Tensor,
+    centers: Sequence[float],
+    sigma: torch.Tensor,
+    taus: Sequence[float],
+    *,
+    bin_edges: Optional[Sequence[float]] = None,
+    return_survival: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """Convert LogNormalBasisHazardLoss logits -> CIFs at taus.
+
+    logits: (B, 1 + K*R) where K is the number of diseases (causes) and R is
+        the number of basis functions.
+    centers: length R, in log-time.
+    sigma: scalar tensor (already clamped) in log-time units.
+    taus: horizons in the same units as the training bin_edges (years).
+
+    Returns the CIF tensor of shape (B, K, H); if return_survival is True,
+    returns a (cif, survival) tuple instead.
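+
+    A minimal shape-check example (illustrative values only; the random
+    logits stand in for a trained model's output):
+
+        >>> K, R = 3, 4
+        >>> logits = torch.randn(8, 1 + K * R)
+        >>> sigma = torch.tensor(0.7)
+        >>> centers = [-1.0, 0.0, 1.0, 2.0]  # mu_r in log-years
+        >>> cif = cifs_from_lognormal_basis_logits(
+        ...     logits, centers, sigma, taus=[1.0, 5.0, 10.0])
+        >>> cif.shape
+        torch.Size([8, 3, 3])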
+    """
+    if logits.ndim != 2:
+        raise ValueError("Expected logits shape (B, 1+K*R)")
+    if sigma.ndim != 0:
+        raise ValueError("sigma must be a scalar tensor")
+
+    device = logits.device
+    dtype = logits.dtype
+
+    centers_t = torch.tensor([float(x) for x in centers],
+                             device=device, dtype=dtype)
+    r = int(centers_t.numel())
+    jr = int(logits.shape[1] - 1)
+    if jr <= 0 or (jr % r) != 0:
+        raise ValueError("logits.shape[1]-1 must be divisible by R")
+    k = jr // r
+
+    # Stable t_min clamp (aligns with the training loss rule).
+    t_min = 1e-12
+    if bin_edges is not None:
+        edges = [float(x) for x in bin_edges]
+        if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
+            t_min = edges[1] * 1e-6
+    t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
+
+    taus_t = torch.tensor([float(x) for x in taus], device=device, dtype=dtype)
+    taus_t = torch.clamp(taus_t, min=t_min_t)
+    log_tau = torch.log(taus_t)  # (H,)
+
+    # (H,R)
+    z = (log_tau.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma
+    cdf = _normal_cdf_stable(z)
+
+    w_all = torch.softmax(logits, dim=-1)
+    w = w_all[:, 1:].view(logits.size(0), k, r)  # (B,K,R)
+
+    cif = torch.einsum("bkr,hr->bkh", w, cdf)  # (B,K,H)
+
+    if not return_survival:
+        return cif
+
+    survival = 1.0 - cif.sum(dim=1)  # (B,H)
+    survival = torch.clamp(survival, min=0.0, max=1.0)
+    return cif, survival
+
+
 # ============================================================
 # CIF integrity checks
 # ============================================================
@@ -1192,15 +1258,31 @@ def instantiate_model_and_head(
     dataset: HealthDataset,
     device: str,
     checkpoint_path: str = "",
-) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
+) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]:
     model_type = str(cfg["model_type"])
     loss_type = str(cfg["loss_type"])
 
+    loss_params: Dict[str, Any] = {}
+
     if loss_type == "exponential":
         out_dims = [dataset.n_disease]
     elif loss_type == "discrete_time_cif":
         bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
         out_dims = [dataset.n_disease + 1, len(bin_edges)]
+    elif loss_type == "lognormal_basis_hazard":
+        centers = cfg.get("lognormal_centers", None)
+        if centers is None:
+            centers = cfg.get("centers", None)
+        if not isinstance(centers, list) or len(centers) == 0:
+            raise ValueError(
+                "lognormal_basis_hazard requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
+            )
+        out_dims = [1 + dataset.n_disease * len(centers)]
+        loss_params["centers"] = centers
+        loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
+        loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
+        loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
+        loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
     else:
         raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
@@ -1249,7 +1331,7 @@ def instantiate_model_and_head(
     head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
 
     bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
-    return backbone, head, loss_type, bin_edges
+    return backbone, head, loss_type, bin_edges, loss_params
 
 
 @torch.no_grad()
@@ -1258,6 +1340,7 @@ def predict_cifs_for_model(
     head: torch.nn.Module,
     loss_type: str,
     bin_edges: Sequence[float],
+    loss_params: Dict[str, Any],
     loader: DataLoader,
     device: str,
     offset_years: float,
@@ -1316,6 +1399,20 @@
         cif_full, survival = cifs_from_discrete_time_logits(  # (B,K,H), (B,H)
             logits, bin_edges, eval_horizons, return_survival=True)
+    elif loss_type == "lognormal_basis_hazard":
+        centers = loss_params.get("centers", None)
+        sigma = loss_params.get("sigma", None)
+        if centers is None or sigma is None:
+            raise ValueError(
+                "lognormal_basis_hazard requires loss_params['centers'] and loss_params['sigma']")
+        cif_full, survival = cifs_from_lognormal_basis_logits(
+            logits,
+            centers=centers,
+            sigma=sigma,
+            taus=eval_horizons,
+            bin_edges=bin_edges,
+            return_survival=True,
+        )
     else:
         raise ValueError(f"Unsupported loss_type: {loss_type}")
@@ -1823,12 +1920,27 @@ def main() -> int:
         collate_fn=health_collate_fn,
     )
 
-    backbone, head, loss_type, bin_edges = instantiate_model_and_head(
-        cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
     ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
+
+    backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head(
+        cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
     backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
     head.load_state_dict(ckpt["head_state_dict"], strict=True)
 
+    if loss_type == "lognormal_basis_hazard":
+        crit_state = ckpt.get("criterion_state_dict", {})
+        log_sigma = crit_state.get("log_sigma", None)
+        if isinstance(log_sigma, torch.Tensor):
+            log_sigma_t = log_sigma.to(device=args.device)
+            sigma = torch.exp(log_sigma_t)
+        else:
+            sigma = torch.tensor(float(loss_params.get(
+                "bandwidth_init", 0.7)), device=args.device)
+        bmin = float(loss_params.get("bandwidth_min", 1e-3))
+        bmax = float(loss_params.get("bandwidth_max", 10.0))
+        sigma = torch.clamp(sigma, min=bmin, max=bmax)
+        loss_params["sigma"] = sigma
+
     (
         cif_full,
         survival,
@@ -1839,6 +1951,7 @@ def main() -> int:
         head,
         loss_type,
         bin_edges,
+        loss_params,
         loader,
         args.device,
         args.offset_years,
diff --git a/losses.py b/losses.py
index 35a5942..e251e4d 100644
--- a/losses.py
+++ b/losses.py
@@ -1,5 +1,5 @@
 import math
-from typing import Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -258,3 +258,233 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
             F.nll_loss(logp_at_event_bin, target_events, reduction="mean")
 
         return nll, reg
+
+
+class LogNormalBasisHazardLoss(nn.Module):
+    """Event-only competing-risks loss using a lognormal basis (Gaussians on log-time).
+
+    The loss models each cause-specific CIF as a mixture of lognormal basis CDFs:
+
+        F_j(t) = sum_r w_{j,r} * Phi((log t - mu_r) / sigma)
+
+    Training maximizes the probability mass assigned to the observed event's
+    time bin (the CIF increment over that bin). There is **no censoring**:
+    every sample is an observed event with a valid cause.
+
+    Logits interface:
+        logits: (B, 1 + J*R)
+        logits[:, 0]  -> w0 (survival mass / never-event)
+        logits[:, 1:] -> flattened (j, r) in row-major order, j then r:
+                         index = 1 + j*R + r
+
+    Forward interface (must match the other losses):
+        forward(logits, target_events, dt, reduction)
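+
+    A minimal usage sketch (illustrative toy shapes and values, not a
+    training recipe):
+
+        >>> J, R = 2, 3
+        >>> crit = LogNormalBasisHazardLoss(
+        ...     bin_edges=[0.0, 1.0, 5.0, float('inf')],
+        ...     centers=[-0.5, 0.5, 1.5])
+        >>> logits = torch.randn(4, 1 + J * R)
+        >>> events = torch.tensor([0, 1, 1, 0])
+        >>> dt = torch.tensor([0.3, 2.0, 7.5, 0.9])
+        >>> nll, reg = crit(logits, events, dt)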
+    """
+
+    def __init__(
+        self,
+        bin_edges: Sequence[float],
+        centers: Sequence[float],
+        *,
+        eps: float = 1e-8,
+        bandwidth_init: float = 0.5,
+        bandwidth_min: float = 1e-3,
+        bandwidth_max: float = 10.0,
+        lambda_sigma_reg: float = 0.0,
+        sigma_reg_target: Optional[float] = None,
+        return_dict: bool = False,
+    ):
+        super().__init__()
+
+        if len(bin_edges) < 2:
+            raise ValueError("bin_edges must have length >= 2")
+        # Edges must be strictly increasing; +inf is allowed only as the last edge.
+        for i in range(1, len(bin_edges)):
+            prev = float(bin_edges[i - 1])
+            cur = float(bin_edges[i])
+            if math.isinf(prev):
+                raise ValueError(
+                    "bin_edges cannot have +inf except possibly as the last edge")
+            if not (cur > prev):
+                raise ValueError("bin_edges must be strictly increasing")
+        if float(bin_edges[0]) < 0.0:
+            raise ValueError("bin_edges[0] must be >= 0")
+
+        if len(centers) < 1:
+            raise ValueError("centers must have length >= 1")
+
+        self.eps = float(eps)
+        self.bandwidth_min = float(bandwidth_min)
+        self.bandwidth_max = float(bandwidth_max)
+        self.lambda_sigma_reg = float(lambda_sigma_reg)
+        self.sigma_reg_target = None if sigma_reg_target is None else float(
+            sigma_reg_target)
+        self.bandwidth_init = float(bandwidth_init)
+        self.return_dict = bool(return_dict)
+
+        self.register_buffer(
+            "bin_edges",
+            torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
+            persistent=False,
+        )
+        self.register_buffer(
+            "centers",
+            torch.tensor([float(x) for x in centers], dtype=torch.float32),
+            persistent=False,
+        )
+
+        if self.bandwidth_init <= 0:
+            raise ValueError("bandwidth_init must be > 0")
+        self.log_sigma = nn.Parameter(torch.tensor(
+            math.log(self.bandwidth_init), dtype=torch.float32))
+
+    @staticmethod
+    def _normal_cdf(z: torch.Tensor) -> torch.Tensor:
+        # Stable normal CDF via erf.
+        z = torch.clamp(z, -12.0, 12.0)
+        return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
+
+    @staticmethod
+    def _normal_sf(z: torch.Tensor) -> torch.Tensor:
+        # Stable normal survival function via erfc.
+        z = torch.clamp(z, -12.0, 12.0)
+        return 0.5 * torch.erfc(z / math.sqrt(2.0))
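+
+    # Note on the two helpers above: for large positive z, computing the
+    # survival function as 1 - Phi(z) suffers catastrophic cancellation in
+    # float32, while erfc(z / sqrt(2)) stays accurate, which is why the
+    # survival function is evaluated directly via erfc.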
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        target_events: torch.Tensor,
+        dt: torch.Tensor,
+        reduction: str = "mean",
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, Any]]:
+        if logits.ndim != 2:
+            raise ValueError(
+                f"logits must be 2D with shape (B, 1+J*R); got {tuple(logits.shape)}")
+        if target_events.ndim != 1 or dt.ndim != 1:
+            raise ValueError("target_events and dt must be 1D tensors")
+        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
+            raise ValueError(
+                "Batch size mismatch among logits, target_events, dt")
+        if reduction not in {"mean", "sum", "none"}:
+            raise ValueError("reduction must be one of {'mean','sum','none'}")
+
+        device = logits.device
+        dtype = logits.dtype
+
+        bin_edges = self.bin_edges.to(device=device, dtype=dtype)
+        centers = self.centers.to(device=device, dtype=dtype)
+        bsz = logits.shape[0]
+        r = int(centers.numel())
+        jr = int(logits.shape[1] - 1)
+        if jr <= 0:
+            raise ValueError(
+                "logits.shape[1] must be >= 2 (w0 + at least one (j,r) weight)")
+        if jr % r != 0:
+            raise ValueError(
+                f"(logits.shape[1]-1) must be divisible by R={r}; got {jr}")
+        j = jr // r
+
+        # 1) Stable global weights (includes w0).
+        w_all = torch.softmax(logits, dim=-1)  # (B, 1+J*R)
+        w0 = w_all[:, 0]
+        w = w_all[:, 1:].view(bsz, j, r)
+
+        # 2) Determine the event bin index.
+        k = int(bin_edges.numel() - 1)
+        if k < 1:
+            raise ValueError("bin_edges must define at least one bin")
+
+        # v2: dt is always continuous time (float); map it to a bin via searchsorted.
+        dt_f = dt.to(device=device, dtype=dtype)
+        bin_idx = torch.searchsorted(bin_edges, dt_f, right=True) - 1
+        bin_idx = torch.clamp(bin_idx, 0, k - 1).to(torch.long)
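+        # Worked example of the mapping (illustrative): with
+        # bin_edges = [0, 1, 5, inf], dt = 0.3 -> bin 0, dt = 2.0 -> bin 1,
+        # dt = 7.5 -> bin 2 (the open-ended last bin, whose mass is scored
+        # below via the normal survival function rather than a CDF difference).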
+        left = bin_edges[bin_idx]
+        right = bin_edges[bin_idx + 1]
+
+        # 3) Stable log(t) clamp.
+        if float(self.bin_edges[1]) > 0.0:
+            t_min = float(self.bin_edges[1]) * 1e-6
+        else:
+            t_min = 1e-12
+        t_min_t = torch.tensor(t_min, device=device, dtype=dtype)
+
+        left_is_zero = left <= 0
+
+        # For log() we still need a positive clamp, but we will treat CDF(left)=0
+        # exactly when left<=0 (instead of approximating via t_min).
+        left_clamped = torch.clamp(left, min=t_min_t)
+        log_left = torch.log(left_clamped)
+        is_inf = torch.isinf(right)
+        # right may be +inf for the last bin; avoid log(+inf) by substituting
+        # a safe finite value (the result is unused for those rows).
+        right_safe = torch.where(is_inf, left_clamped,
+                                 torch.clamp(right, min=t_min_t))
+        log_right = torch.log(right_safe)
+
+        sigma = torch.clamp(self.log_sigma.to(
+            device=device, dtype=dtype).exp(), self.bandwidth_min, self.bandwidth_max)
+
+        z_left = (log_left.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
+        z_right = (log_right.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
+        z_left = torch.clamp(z_left, -12.0, 12.0)
+        z_right = torch.clamp(z_right, -12.0, 12.0)
+
+        cdf_left = self._normal_cdf(z_left)
+        # Treat the first-bin left boundary exactly as 0 in the CDF.
+        if left_is_zero.any():
+            cdf_left = torch.where(
+                left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left)
+        cdf_right = self._normal_cdf(z_right)
+        delta_finite = cdf_right - cdf_left
+        delta_last = self._normal_sf(z_left)
+        # If left<=0, SF(left)=1 exactly.
+        if left_is_zero.any():
+            delta_last = torch.where(
+                left_is_zero.unsqueeze(-1), torch.ones_like(delta_last), delta_last)
+        delta_basis = torch.where(
+            is_inf.unsqueeze(-1), delta_last, delta_finite)
+        delta_basis = torch.clamp(delta_basis, min=0.0)
+
+        # 4) Gather per-sample cause weights and compute the event mass.
+        cause = target_events.to(device=device, dtype=torch.long)
+        if (cause < 0).any() or (cause >= j).any():
+            raise ValueError(f"target_events must be in [0, J-1] where J={j}")
+
+        b_idx = torch.arange(bsz, device=device)
+        w_cause = w[b_idx, cause, :]  # (B, R)
+
+        m = (w_cause * delta_basis).sum(dim=-1)
+        m = torch.clamp(m, min=self.eps)
+        nll_vec = -torch.log(m)
+
+        if reduction == "mean":
+            nll: torch.Tensor = nll_vec.mean()
+        elif reduction == "sum":
+            nll = nll_vec.sum()
+        else:
+            nll = nll_vec
+
+        sigma_penalty = torch.zeros((), device=device, dtype=dtype)
+        if self.lambda_sigma_reg > 0.0:
+            target = self.bandwidth_init if self.sigma_reg_target is None else self.sigma_reg_target
+            sigma_penalty = (self.log_sigma.to(
+                device=device, dtype=dtype) - math.log(float(target))) ** 2
+        reg = sigma_penalty * float(self.lambda_sigma_reg)
+
+        if not self.return_dict:
+            return nll, reg
+
+        return {
+            "nll": nll,
+            "reg": reg,
+            "nll_vec": nll_vec,
+            "sigma": sigma.detach(),
+            "avg_w0": w0.mean().detach(),
+            "min_delta_basis": delta_basis.min().detach(),
+            "mean_m": m.mean().detach(),
+            "sigma_penalty": sigma_penalty.detach(),
+            "bin_idx": bin_idx.detach(),
+        }
diff --git a/model.py b/model.py
index 462cc82..0f01d07 100644
--- a/model.py
+++ b/model.py
@@ -2,7 +2,6 @@
 import numpy as np
 from typing import Optional, List
 from backbones import Block
 from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
-import torch.nn.functional as F
 import torch.nn as nn
 import torch
diff --git a/prepare_data.py b/prepare_data.py
index e16f690..9b0b287 100644
--- a/prepare_data.py
+++ b/prepare_data.py
@@ -4,7 +4,8 @@ import numpy as np  # Numerical operations
 
 # CSV mapping field IDs to human-readable names
 field_map_file = "field_ids_enriched.csv"
-field_dict = {}  # Map original field ID -> new column name
+# Map original field ID -> new column name
+field_dict = {}
 tabular_fields = []  # List of tabular feature column names
 with open(field_map_file, "r", encoding="utf-8") as f:  # Open the field mapping file
     next(f)  # skip header line
diff --git a/train.py b/train.py
index 947bf1b..15deed4 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,4 @@
-from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisHazardLoss, get_valid_pairs_and_dt
 from model import DelphiFork, SapDelphi, SimpleHead
 from dataset import HealthDataset, health_collate_fn
 from tqdm import tqdm
@@ -13,16 +13,16 @@ import os
 import time
 import argparse
 import math
-import sys
 from dataclasses import asdict, dataclass, field
-from typing import Literal, Sequence
+from typing import Literal, Optional, Sequence
 
 
 @dataclass
 class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-    loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
+    loss_type: Literal['exponential', 'discrete_time_cif',
+                       'lognormal_basis_hazard'] = 'exponential'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -34,6 +34,15 @@ class TrainConfig:
         default_factory=lambda: [0.0, 0.24, 0.72, 1.61,
                                  3.84, 10.0, 31.0, float('inf')]
     )
+    # LogNormalBasisHazardLoss specific
+    lognormal_centers: Optional[Sequence[float]] = field(
+        default_factory=list)  # mu_r in log-time
+    loss_eps: float = 1e-8
+    bandwidth_init: float = 0.7
+    bandwidth_min: float = 1e-3
+    bandwidth_max: float = 10.0
+    lambda_sigma_reg: float = 1e-4
+    sigma_reg_target: Optional[float] = None
     rank: int = 16
     # SapDelphi specific
     pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
@@ -64,7 +73,7 @@ def parse_args() -> TrainConfig:
     parser.add_argument(
         "--loss_type",
         type=str,
-        choices=['exponential', 'discrete_time_cif'],
+        choices=['exponential', 'discrete_time_cif', 'lognormal_basis_hazard'],
        default='exponential',
        help="Type of loss function to use.")
    parser.add_argument("--age_encoder", type=str, choices=[
@@ -79,6 +88,24 @@ def parse_args() -> TrainConfig:
                         help="Dropout probability.")
     parser.add_argument("--lambda_reg", type=float, default=1e-4,
                         help="Regularization weight.")
+    parser.add_argument(
+        "--lognormal_centers",
+        type=float,
+        nargs='*',
+        default=None,
+        help="LogNormalBasisHazardLoss centers (mu_r) in log-time; provide as "
+             "space-separated floats. If omitted, centers are derived from bin_edges.")
+    parser.add_argument("--loss_eps", type=float, default=1e-8,
+                        help="Epsilon floor on the per-sample event mass before taking its log in LogNormalBasisHazardLoss.")
+    parser.add_argument("--bandwidth_init", type=float, default=0.7,
+                        help="Initial sigma for LogNormalBasisHazardLoss.")
+    parser.add_argument("--bandwidth_min", type=float, default=1e-3,
+                        help="Minimum sigma clamp for LogNormalBasisHazardLoss.")
+    parser.add_argument("--bandwidth_max", type=float, default=10.0,
+                        help="Maximum sigma clamp for LogNormalBasisHazardLoss.")
+    parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
+                        help="Sigma regularization strength for LogNormalBasisHazardLoss.")
+    parser.add_argument("--sigma_reg_target", type=float, default=None,
+                        help="Optional sigma target for regularization (otherwise uses bandwidth_init).")
     parser.add_argument("--rank", type=int, default=16,
                         help="Rank for low-rank parameterization (if applicable).")
     parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
@@ -118,7 +145,36 @@ def parse_args() -> TrainConfig:
     parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                         help="Device to use for training.")
     args = parser.parse_args()
-    return TrainConfig(**vars(args))
+
+    cfg = TrainConfig(**vars(args))
+
+    # If lognormal_centers are not provided, derive a reasonable default from bin_edges.
+    if cfg.lognormal_centers is None or len(cfg.lognormal_centers) == 0:
+        edges = [float(x) for x in cfg.bin_edges]
+        finite = [e for e in edges if math.isfinite(e)]
+        if len(finite) < 2:
+            raise ValueError(
+                "bin_edges must contain at least two finite edges to derive lognormal_centers")
+        e1 = finite[1]
+        t_min = (e1 * 1e-3) if e1 > 0 else 1e-12
+        # Build one center per bin (including the +inf last bin if present).
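+        # For the default bin_edges [0, 0.24, 0.72, 1.61, 3.84, 10, 31, inf],
+        # the geometric-midpoint rule below yields roughly (in log-years):
+        # [-4.88, -0.88, 0.07, 0.91, 1.82, 2.87, 4.13].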
+        centers: list[float] = []
+        for i in range(1, len(edges)):
+            left = float(edges[i - 1])
+            right = float(edges[i])
+            if i == 1 and left <= 0.0:
+                left_pos = t_min
+            else:
+                left_pos = max(left, t_min)
+            if math.isinf(right):
+                mid = max(left_pos * 2.0, left_pos + 1e-6)
+            else:
+                right_pos = max(right, t_min)
+                mid = math.sqrt(left_pos * right_pos)
+            centers.append(math.log(max(mid, t_min)))
+        cfg.lognormal_centers = centers
+
+    return cfg
 
 
 def get_num_params(model: nn.Module) -> int:
@@ -205,6 +261,23 @@ class Trainer:
             ).to(self.device)
             # logits shape (M, K+1, n_bins+1)
             out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
+        elif cfg.loss_type == "lognormal_basis_hazard":
+            r = len(cfg.lognormal_centers)
+            if r <= 0:
+                raise ValueError(
+                    "lognormal_centers must be non-empty for lognormal_basis_hazard")
+            self.criterion = LogNormalBasisHazardLoss(
+                bin_edges=cfg.bin_edges,
+                centers=cfg.lognormal_centers,
+                eps=cfg.loss_eps,
+                bandwidth_init=cfg.bandwidth_init,
+                bandwidth_min=cfg.bandwidth_min,
+                bandwidth_max=cfg.bandwidth_max,
+                lambda_sigma_reg=cfg.lambda_sigma_reg,
+                sigma_reg_target=cfg.sigma_reg_target,
+            ).to(self.device)
+            # logits shape (M, 1 + J*R)
+            out_dims = [1 + dataset.n_disease * r]
         else:
             raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
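+
+        # Note: unlike the other criteria, LogNormalBasisHazardLoss carries a
+        # learnable parameter (log_sigma), so its state_dict belongs in the
+        # training checkpoint; evaluate_models.py looks for it under
+        # ckpt["criterion_state_dict"] and falls back to bandwidth_init if absent.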