diff --git a/evaluate_models.py b/evaluate_models.py index cf86d2e..6c354db 100644 --- a/evaluate_models.py +++ b/evaluate_models.py @@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256 class ModelSpec: name: str model_type: str # delphi_fork | sap_delphi - loss_type: str # exponential | discrete_time_cif + loss_type: str # exponential | discrete_time_cif | pwe_cif full_cov: bool checkpoint_path: str @@ -420,6 +420,94 @@ def cifs_from_discrete_time_logits( return cif, survival +def cifs_from_pwe_logits( + logits: torch.Tensor, + bin_edges: Sequence[float], + taus: Sequence[float], + eps: float = 1e-6, + return_survival: bool = False, +) -> torch.Tensor: + """Convert piecewise-exponential (PWE) hazard logits -> CIFs at taus. + + logits: (B, K, n_bins) # hazard logits per cause per bin + bin_edges: length n_bins+1, strictly increasing, finite last edge + taus: subset of finite bin edges (recommended) + + returns: (B, K, H) or (cif, survival) if return_survival + """ + if logits.ndim != 3: + raise ValueError("Expected logits shape (B, K, n_bins) for pwe_cif") + + edges = [float(x) for x in bin_edges] + if len(edges) < 2: + raise ValueError("bin_edges must have length >= 2") + if edges[0] != 0.0: + raise ValueError("bin_edges[0] must equal 0.0") + if not math.isfinite(edges[-1]): + raise ValueError( + "pwe_cif requires a finite last bin edge (no +inf). " + "If your training config uses +inf, drop it for PWE evaluation." + ) + + B, K, n_bins = logits.shape + if n_bins != (len(edges) - 1): + raise ValueError( + f"logits last dim n_bins={n_bins} must equal len(bin_edges)-1={len(edges)-1}" + ) + + # Convert logits -> hazards, then integrated hazards per bin. + hazards = F.softplus(logits) + eps # (B,K,n_bins) + dt_bins = torch.tensor( + [edges[i + 1] - edges[i] for i in range(n_bins)], + device=logits.device, + dtype=hazards.dtype, + ) # (n_bins,) + if not torch.isfinite(dt_bins).all() or not (dt_bins > 0).all(): + raise ValueError("All PWE bin widths must be finite and > 0") + + H_cause = hazards * dt_bins.view(1, 1, n_bins) # (B,K,n_bins) + H_total = H_cause.sum(dim=1) # (B,n_bins) + + # Survival at START of each bin u. 
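+    # That is, S_prev[:, u] = exp(-Σ_{v<u} H_total[:, v]) with S_prev[:, 0] = 1:
+    # the probability of reaching bin u event-free under the total hazard.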
+ cum_total = torch.cumsum(H_total, dim=1) # (B,n_bins) + zeros = torch.zeros((B, 1), device=logits.device, dtype=hazards.dtype) + cum_prev = torch.cat([zeros, cum_total[:, :-1]], dim=1) # (B,n_bins) + S_prev = torch.exp(-cum_prev) # (B,n_bins) + + one_minus_surv_bin = 1.0 - torch.exp(-H_total) # (B,n_bins) + frac = H_cause / torch.clamp(H_total.unsqueeze(1), min=eps) # (B,K,n_bins) + + cif_incr = S_prev.unsqueeze(1) * frac * one_minus_surv_bin.unsqueeze(1) + cif_bins = torch.cumsum(cif_incr, dim=2) # (B,K,n_bins) at edges[1:] + + # Map tau -> edge index in edges[1:] + finite_edges = edges[1:] + finite_edges_arr = np.asarray(finite_edges, dtype=float) + tau_to_idx: List[int] = [] + for tau in taus: + tau_f = float(tau) + if not math.isfinite(tau_f): + raise ValueError("taus must be finite for pwe_cif") + diffs = np.abs(finite_edges_arr - tau_f) + j = int(np.argmin(diffs)) + if diffs[j] > 1e-6: + raise ValueError( + f"tau={tau_f} not close to any bin edge (min |edge-tau|={diffs[j]})" + ) + tau_to_idx.append(j) + + idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long) + cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H) + + if not return_survival: + return cif + + # Survival at each horizon is exp(-cum_total at that edge) + survival_bins = torch.exp(-cum_total) # (B,n_bins) + survival = survival_bins.index_select(dim=1, index=idx) # (B,H) + return cif, survival + + # ============================================================ # CIF integrity checks # ============================================================ @@ -1196,11 +1284,21 @@ def instantiate_model_and_head( model_type = str(cfg["model_type"]) loss_type = str(cfg["loss_type"]) + bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES) if loss_type == "exponential": out_dims = [dataset.n_disease] elif loss_type == "discrete_time_cif": - bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES) out_dims = [dataset.n_disease + 1, len(bin_edges)] + elif loss_type == "pwe_cif": + # Match training: drop +inf if present and evaluate up to the last finite edge. + pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))] + if len(pwe_edges) < 2: + raise ValueError( + f"pwe_cif requires >=2 finite edges; got bin_edges={list(bin_edges)}" + ) + n_bins = len(pwe_edges) - 1 + out_dims = [dataset.n_disease, n_bins] + bin_edges = pwe_edges else: raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}") @@ -1248,7 +1346,6 @@ def instantiate_model_and_head( raise ValueError(f"Unsupported model_type: {model_type}") head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device) - bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES) return backbone, head, loss_type, bin_edges @@ -1316,6 +1413,9 @@ def predict_cifs_for_model( cif_full, survival = cifs_from_discrete_time_logits( # (B,K,H), (B,H) logits, bin_edges, eval_horizons, return_survival=True) + elif loss_type == "pwe_cif": + cif_full, survival = cifs_from_pwe_logits( + logits, bin_edges, eval_horizons, return_survival=True) else: raise ValueError(f"Unsupported loss_type: {loss_type}") diff --git a/losses.py b/losses.py index 35a5942..3ef4de7 100644 --- a/losses.py +++ b/losses.py @@ -258,3 +258,149 @@ class DiscreteTimeCIFNLLLoss(nn.Module): F.nll_loss(logp_at_event_bin, target_events, reduction="mean") return nll, reg + + +class PiecewiseExponentialCIFNLLLoss(nn.Module): + """ + Piecewise-Exponential (PWE) cause-specific hazards with discrete-time CIF likelihood. 
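+
+    Per-event likelihood (cause j, event bin k*), consistent with
+    cifs_from_pwe_logits in evaluate_models.py:
+
+        P = exp(-Σ_{u<k*} H_u) * (H_{j,k*} / H_{k*}) * (1 - exp(-H_{k*}))
+
+    where H_{j,k} = (softplus(logit_{j,k}) + eps) * Δt_k and H_k = Σ_j H_{j,k}.
+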
+    - No censoring
+    - Optional regularization: squared second differences of log-hazards
+      across bins, scaled by lambda_reg (reg is 0 when lambda_reg == 0
+      or n_bins < 3)
+    - Forward signature matches DiscreteTimeCIFNLLLoss:
+        forward(logits, target_events, dt, reduction) -> (nll, reg)
+
+    Expected shapes:
+        logits:        (M, K, n_bins)  # hazard logits per cause per bin
+        target_events: (M,) long in [0, K-1]
+        dt:            (M,) event times (strictly > 0)
+
+    bin_edges:
+        length n_bins+1, strictly increasing, bin_edges[0] == 0,
+        and MUST be finite at the last edge (no +inf) for PWE.
+    """
+
+    def __init__(
+        self,
+        bin_edges: Sequence[float],
+        eps: float = 1e-6,
+        lambda_reg: float = 0.0,  # weight of the log-hazard smoothness penalty
+    ):
+        super().__init__()
+
+        if len(bin_edges) < 2:
+            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
+        if float(bin_edges[0]) != 0.0:
+            raise ValueError("bin_edges[0] must equal 0.0")
+        for i in range(1, len(bin_edges)):
+            if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
+                raise ValueError("bin_edges must be strictly increasing")
+        if math.isinf(float(bin_edges[-1])):
+            raise ValueError(
+                "PiecewiseExponentialCIFNLLLoss requires a finite last bin edge "
+                "(no +inf). Use a finite truncation horizon for PWE."
+            )
+
+        self.eps = float(eps)
+        self.lambda_reg = float(lambda_reg)
+
+        self.register_buffer(
+            "bin_edges",
+            torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
+            persistent=False,
+        )
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        target_events: torch.Tensor,
+        dt: torch.Tensor,
+        reduction: str = "mean",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if reduction not in {"mean", "sum", "none"}:
+            raise ValueError("reduction must be one of {'mean', 'sum', 'none'}")
+
+        if logits.ndim != 3:
+            raise ValueError(
+                f"logits must be 3D (M, K, n_bins); got shape={tuple(logits.shape)}")
+        if target_events.ndim != 1 or dt.ndim != 1:
+            raise ValueError("target_events and dt must be 1D tensors")
+        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
+            raise ValueError(
+                "Batch size mismatch among logits, target_events, dt")
+
+        if not torch.all(dt > 0):
+            raise ValueError(
+                "dt must be strictly positive (no censoring supported here)")
+
+        M, K, n_bins = logits.shape
+
+        if target_events.dtype != torch.long:
+            target_events = target_events.to(torch.long)
+        if (target_events < 0).any() or (target_events >= K).any():
+            raise ValueError(f"target_events must be in [0, {K-1}]")
+
+        # Prepare bin edges / bin widths.
+        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
+        if bin_edges.numel() != n_bins + 1:
+            raise ValueError(
+                f"bin_edges length must be n_bins+1={n_bins+1}; got {bin_edges.numel()}"
+            )
+
+        dt_bins = (bin_edges[1:] - bin_edges[:-1]
+                   ).to(device=logits.device, dtype=logits.dtype)  # (n_bins,)
+        if not torch.isfinite(dt_bins).all():
+            raise ValueError("All bin widths must be finite for PWE.")
+        if not (dt_bins > 0).all():
+            raise ValueError(
+                "All bin widths must be strictly positive for PWE.")
+
+        # Map event time -> bin index k* in {1..n_bins}
+        # (same convention as DiscreteTimeCIFNLLLoss: clamp to [1, n_bins]).
+        time_bin = torch.bucketize(dt, bin_edges)
+        time_bin = torch.clamp(
+            time_bin, min=1, max=n_bins).to(torch.long)  # (M,)
+        k0 = time_bin - 1  # 0..n_bins-1
+
+        # Nonnegative hazards per cause per bin.
+        hazards = F.softplus(logits) + self.eps  # (M, K, n_bins)
+
+        # Integrated hazards H_{j,k} = lambda_{j,k} * Δt_k
+        H_jk = hazards * dt_bins.view(1, 1, n_bins)  # (M, K, n_bins)
+        H_k = H_jk.sum(dim=1)  # (M, n_bins)
+
+        # Previous survival term: Σ_{u<k*} H_u, the total integrated hazard
+        # accumulated before the event bin (so S_prev = exp(-cum_prev)).
+        cum_total = torch.cumsum(H_k, dim=1)  # (M, n_bins)
+        zeros = torch.zeros((M, 1), device=logits.device, dtype=H_k.dtype)
+        cum_prev = torch.cat([zeros, cum_total[:, :-1]], dim=1)  # (M, n_bins)
+
+        # Gather per-sample terms at the event bin k0 and the event cause.
+        H_prev = cum_prev.gather(1, k0.unsqueeze(1)).squeeze(1)  # (M,)
+        H_bin_total = H_k.gather(1, k0.unsqueeze(1)).squeeze(1)  # (M,)
+        H_bin_cause = (
+            H_jk.gather(2, k0.view(-1, 1, 1).expand(-1, K, 1))  # (M, K, 1)
+            .squeeze(2)
+            .gather(1, target_events.unsqueeze(1))
+            .squeeze(1)
+        )  # (M,)
+
+        # log P(cause j, bin k*) =
+        #   -Σ_{u<k*} H_u + log H_{j,k*} - log H_{k*} + log(1 - exp(-H_{k*}))
+        log_p = (
+            -H_prev
+            + torch.log(H_bin_cause)
+            - torch.log(torch.clamp(H_bin_total, min=self.eps))
+            + torch.log(-torch.expm1(-H_bin_total) + self.eps)
+        )
+        loss_vec = -log_p  # (M,)
+
+        if reduction == "mean":
+            nll = loss_vec.mean()
+        elif reduction == "sum":
+            nll = loss_vec.sum()
+        else:
+            nll = loss_vec
+        # Optional smoothness penalty: squared second differences of
+        # log-hazards across adjacent bins (a curvature penalty).
+        if self.lambda_reg > 0.0 and n_bins >= 3:
+            log_h = torch.log(hazards)  # (M, K, n_bins)
+            d2 = (
+                log_h[:, :, 2:] - 2.0 * log_h[:, :, 1:-1] + log_h[:, :, :-2]
+            )  # (M, K, n_bins - 2)
+            reg = self.lambda_reg * d2.pow(2).mean()
+        else:
+            reg = torch.zeros((), device=logits.device, dtype=loss_vec.dtype)
+
+        return nll, reg
diff --git a/train.py b/train.py
index 947bf1b..10fd470 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,4 @@
-from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, PiecewiseExponentialCIFNLLLoss, get_valid_pairs_and_dt
 from model import DelphiFork, SapDelphi, SimpleHead
 from dataset import HealthDataset, health_collate_fn
 from tqdm import tqdm
@@ -22,7 +22,8 @@ from typing import Literal, Sequence
 class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-    loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
+    loss_type: Literal['exponential',
+                       'discrete_time_cif', 'pwe_cif'] = 'exponential'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -64,7 +65,7 @@ def parse_args() -> TrainConfig:
     parser.add_argument(
         "--loss_type",
         type=str,
-        choices=['exponential', 'discrete_time_cif'],
+        choices=['exponential', 'discrete_time_cif', 'pwe_cif'],
         default='exponential',
         help="Type of loss function to use.")
     parser.add_argument("--age_encoder", type=str, choices=[
@@ -205,6 +206,28 @@ class Trainer:
             ).to(self.device)
             # logits shape (M, K+1, n_bins+1)
             out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
+        elif cfg.loss_type == "pwe_cif":
+            # Piecewise-exponential (PWE) requires a FINITE last edge.
+            # If cfg.bin_edges ends with +inf (the default), drop it and
+            # train up to the last finite edge.
+            pwe_edges = [float(x)
+                         for x in cfg.bin_edges if math.isfinite(float(x))]
+            if len(pwe_edges) < 2:
+                raise ValueError(
+                    "pwe_cif requires at least 2 finite bin edges (including 0). "
+                    f"Got bin_edges={list(cfg.bin_edges)}"
+                )
+            if pwe_edges[0] != 0.0:
+                raise ValueError(
+                    f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}"
+                )
+
+            self.criterion = PiecewiseExponentialCIFNLLLoss(
+                bin_edges=pwe_edges,
+                lambda_reg=cfg.lambda_reg,
+            ).to(self.device)
+            n_bins = len(pwe_edges) - 1
+            # logits shape (M, K, n_bins)
+            out_dims = [dataset.n_disease, n_bins]
         else:
             raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
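A quick property to sanity-check cifs_from_pwe_logits: since each per-bin CIF increment is S_prev * frac * (1 - exp(-H_total)) and the cause fractions sum to one, the cause-specific CIFs plus overall survival should telescope to exactly 1 at every evaluated horizon. A minimal sketch of that check (not part of the diff; assumes evaluate_models is importable, and the seed, shapes, and edges are illustrative):

```python
import torch

from evaluate_models import cifs_from_pwe_logits

torch.manual_seed(0)
B, K = 4, 3                                 # batch size, number of causes
edges = [0.0, 1.0, 2.0, 5.0, 10.0]          # finite last edge, as PWE requires
logits = torch.randn(B, K, len(edges) - 1)  # (B, K, n_bins) hazard logits

cif, surv = cifs_from_pwe_logits(
    logits, edges, taus=edges[1:], return_survival=True)

# Sum over causes plus survival should be 1 at every horizon tau.
total = cif.sum(dim=1) + surv               # (B, H)
assert torch.allclose(total, torch.ones_like(total), atol=1e-5)
```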