diff --git a/evaluate_models.py b/evaluate_models.py
index 0bfa9b3..b506558 100644
--- a/evaluate_models.py
+++ b/evaluate_models.py
@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
 class ModelSpec:
     name: str
     model_type: str  # delphi_fork | sap_delphi
-    loss_type: str  # exponential | discrete_time_cif | lognormal_basis_hazard
+    loss_type: str  # exponential | discrete_time_cif | lognormal_basis_binned_hazard_cif
     full_cov: bool
     checkpoint_path: str
 
@@ -425,24 +425,24 @@ def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
     return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
 
 
-def cifs_from_lognormal_basis_logits(
+def cifs_from_lognormal_basis_binned_hazard_logits(
     logits: torch.Tensor,
+    *,
     centers: Sequence[float],
     sigma: torch.Tensor,
+    bin_edges: Sequence[float],
     taus: Sequence[float],
-    *,
-    bin_edges: Optional[Sequence[float]] = None,
+    eps: float = 1e-8,
+    alpha_floor: float = 0.0,
     return_survival: bool = False,
 ) -> torch.Tensor:
-    """Convert LogNormalBasisHazardLoss logits -> CIFs at taus.
+    """Convert Route-3 binned hazard logits -> CIFs at taus.
 
-    logits: (B, 1 + K*R) where K is number of diseases (causes) and R is number of basis functions.
-    centers: length R, in log-time.
-    sigma: scalar tensor (already clamped) in log-time units.
-    taus: horizons in the same units as training bin_edges (years).
+    logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
+    taus are expected to align with finite bin edges.
     """
-    if logits.ndim != 2:
-        raise ValueError("Expected logits shape (B, 1+K*R)")
+    if logits.ndim not in {2, 3}:
+        raise ValueError("logits must be 2D or 3D")
     if sigma.ndim != 0:
         raise ValueError("sigma must be a scalar tensor")
 
@@ -452,37 +452,111 @@ def cifs_from_lognormal_basis_logits(
     centers_t = torch.tensor([float(x) for x in centers], device=device, dtype=dtype)
     r = int(centers_t.numel())
 
-    jr = int(logits.shape[1] - 1)
-    if jr <= 0 or (jr % r) != 0:
-        raise ValueError("logits.shape[1]-1 must be divisible by R")
-    k = jr // r
+    if r <= 0:
+        raise ValueError("centers must be non-empty")
+
+    offset = 0
+    if logits.ndim == 3:
+        j = int(logits.shape[1])
+        if int(logits.shape[2]) != r:
+            raise ValueError(
+                f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
+            )
+    else:
+        d = int(logits.shape[1])
+        if d % r == 0:
+            jr = d
+        elif (d - 1) % r == 0:
+            offset = 1
+            jr = d - 1
+        else:
+            raise ValueError(
+                f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}")
+
+        j = jr // r
+
+    if j <= 0:
+        raise ValueError("Inferred J must be >= 1")
+
+    edges = [float(x) for x in bin_edges]
+    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
+    n_bins = len(finite_edges)
+    if n_bins <= 0:
+        raise ValueError("bin_edges must contain at least one finite edge")
+
+    # Build finite bins [edges[k-1], edges[k]) for k=1..n_bins
+    left = torch.tensor(edges[:n_bins], device=device, dtype=dtype)
+    right = torch.tensor(edges[1:1 + n_bins], device=device, dtype=dtype)
 
     # Stable t_min clamp (aligns with training loss rule).
     t_min = 1e-12
-    if bin_edges is not None:
-        edges = [float(x) for x in bin_edges]
-        if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
-            t_min = edges[1] * 1e-6
+    if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
+        t_min = edges[1] * 1e-6
     t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
 
-    taus_t = torch.tensor([float(x) for x in taus], device=device, dtype=dtype)
-    taus_t = torch.clamp(taus_t, min=t_min_t)
-    log_tau = torch.log(taus_t)  # (H,)
+    left_is_zero = left <= 0
+    left_clamped = torch.clamp(left, min=t_min_t)
+    log_left = torch.log(left_clamped)
+    right_clamped = torch.clamp(right, min=t_min_t)
+    log_right = torch.log(right_clamped)
 
-    # (H,R)
-    z = (log_tau.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma
-    cdf = _normal_cdf_stable(z)
+    sigma_c = sigma.to(device=device, dtype=dtype)
+    z_left = (log_left.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
+    z_right = (log_right.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
 
-    w_all = torch.softmax(logits, dim=-1)
-    w = w_all[:, 1:].view(logits.size(0), k, r)  # (B,K,R)
+    cdf_left = _normal_cdf_stable(z_left)
+    if left_is_zero.any():
+        cdf_left = torch.where(left_is_zero.unsqueeze(-1),
+                               torch.zeros_like(cdf_left), cdf_left)
+    cdf_right = _normal_cdf_stable(z_right)
+    delta_basis = torch.clamp(cdf_right - cdf_left, min=0.0)  # (n_bins, R)
 
-    cif = torch.einsum("bkr,hr->bkh", w, cdf)  # (B,K,H)
+    if logits.ndim == 3:
+        alpha = F.softplus(logits) + float(alpha_floor)  # (B,J,R)
+    else:
+        logits_used = logits[:, offset:]
+        alpha = (F.softplus(logits_used) + float(alpha_floor)
+                 ).view(logits.size(0), j, r)  # (B,J,R)
+
+    h_jk = torch.einsum("bjr,kr->bjk", alpha, delta_basis)  # (B,J,n_bins)
+    h_k = h_jk.sum(dim=1)  # (B,n_bins)
+
+    h_k = torch.clamp(h_k, min=eps)
+    h_jk = torch.clamp(h_jk, min=eps)
+
+    p_comp = torch.exp(-h_k)  # (B,n_bins)
+    one_minus = -torch.expm1(-h_k)  # (B,n_bins) = 1-exp(-H)
+    ratio = h_jk / torch.clamp(h_k.unsqueeze(1), min=eps)
+    p_event = one_minus.unsqueeze(1) * ratio  # (B,J,n_bins)
+
+    ones = torch.ones((alpha.size(0), 1), device=device, dtype=dtype)
+    cum = torch.cumprod(p_comp, dim=1)  # survival after each bin
+    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # survival before each bin
+
+    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * p_event, dim=2)  # (B,J,n_bins)
+
+    finite_edges_arr = np.asarray(finite_edges, dtype=float)
+    tau_to_idx: List[int] = []
+    for tau in taus:
+        tau_f = float(tau)
+        if not math.isfinite(tau_f):
+            raise ValueError("taus must be finite for discrete-time CIF")
+        diffs = np.abs(finite_edges_arr - tau_f)
+        idx0 = int(np.argmin(diffs))
+        if diffs[idx0] > 1e-6:
+            raise ValueError(
+                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})"
+            )
+        tau_to_idx.append(idx0)
+
+    idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long)
+    cif = cif_bins.index_select(dim=2, index=idx)  # (B,J,H)
 
     if not return_survival:
         return cif
 
-    survival = 1.0 - cif.sum(dim=1)  # (B,H)
-    survival = torch.clamp(survival, min=0.0, max=1.0)
+    survival = cum.index_select(dim=1, index=idx)  # (B,H)
     return cif, survival
 
 
@@ -1269,20 +1343,46 @@ def instantiate_model_and_head(
     elif loss_type == "discrete_time_cif":
         bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
         out_dims = [dataset.n_disease + 1, len(bin_edges)]
-    elif loss_type == "lognormal_basis_hazard":
+    elif loss_type == "lognormal_basis_binned_hazard_cif":
         centers = cfg.get("lognormal_centers", None)
         if centers is None:
            centers = cfg.get("centers", None)
        if not isinstance(centers, list) or len(centers) == 0:
            raise ValueError(
-                "lognormal_basis_hazard requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
+                "lognormal_basis_binned_hazard_cif requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
             )
-        out_dims = [1 + dataset.n_disease * len(centers)]
+        r = len(centers)
+        desired_total = int(dataset.n_disease) * int(r)
+        legacy_total = 1 + desired_total
+
+        # Prefer the new shape (K,R) but keep compatibility with older checkpoints
+        # that used a single flattened dimension (1 + K*R).
+        out_dims = [int(dataset.n_disease), int(r)]
+        if checkpoint_path:
+            try:
+                ckpt = torch.load(checkpoint_path, map_location="cpu")
+                head_sd = ckpt.get("head_state_dict", {})
+                w = head_sd.get("net.2.weight", None)
+                if isinstance(w, torch.Tensor) and w.ndim == 2:
+                    out_features = int(w.shape[0])
+                    if out_features == legacy_total:
+                        out_dims = [legacy_total]
+                    elif out_features == desired_total:
+                        out_dims = [int(dataset.n_disease), int(r)]
+                    else:
+                        raise ValueError(
+                            f"Checkpoint head out_features={out_features} does not match expected {desired_total} (K*R) or {legacy_total} (1+K*R)"
+                        )
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to infer head output dims from checkpoint={checkpoint_path}: {e}"
+                )
         loss_params["centers"] = centers
         loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
         loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
         loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
         loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
+        loss_params["alpha_floor"] = float(cfg.get("alpha_floor", 0.0))
     else:
         raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
 
@@ -1399,18 +1499,20 @@ def predict_cifs_for_model(
         cif_full, survival = cifs_from_discrete_time_logits(  # (B,K,H), (B,H)
             logits, bin_edges, eval_horizons, return_survival=True)
-    elif loss_type == "lognormal_basis_hazard":
+    elif loss_type == "lognormal_basis_binned_hazard_cif":
         centers = loss_params.get("centers", None)
         sigma = loss_params.get("sigma", None)
         if centers is None or sigma is None:
             raise ValueError(
-                "lognormal_basis_hazard requires loss_params['centers'] and loss_params['sigma']")
-        cif_full, survival = cifs_from_lognormal_basis_logits(
+                "lognormal_basis_binned_hazard_cif requires loss_params['centers'] and loss_params['sigma']")
+        cif_full, survival = cifs_from_lognormal_basis_binned_hazard_logits(
             logits,
             centers=centers,
             sigma=sigma,
-            taus=eval_horizons,
             bin_edges=bin_edges,
+            taus=eval_horizons,
+            eps=float(loss_params.get("loss_eps", 1e-8)),
+            alpha_floor=float(loss_params.get("alpha_floor", 0.0)),
             return_survival=True,
         )
     else:
@@ -1927,7 +2029,7 @@ def main() -> int:
     backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
     head.load_state_dict(ckpt["head_state_dict"], strict=True)
 
-    if loss_type == "lognormal_basis_hazard":
+    if loss_type == "lognormal_basis_binned_hazard_cif":
         crit_state = ckpt.get("criterion_state_dict", {})
         log_sigma = crit_state.get("log_sigma", None)
         if isinstance(log_sigma, torch.Tensor):
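Reviewer note: the converter above is easiest to sanity-check end to end on synthetic inputs. A minimal sketch (made-up shapes and values; assumes the repo root is importable; `sigma` here stands in for the criterion's clamped `exp(log_sigma)`, which `main()` threads through `loss_params`):

```python
import math
import torch

from evaluate_models import cifs_from_lognormal_basis_binned_hazard_logits

B, J = 4, 3
bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
centers = [math.log(e) for e in bin_edges[1:-1]]  # one mu_r per finite edge (illustrative)

logits = torch.randn(B, J, len(centers))  # preferred (B, J, R); 2D layouts also accepted
sigma = torch.tensor(0.7)                 # scalar tensor, as the function requires

cif, survival = cifs_from_lognormal_basis_binned_hazard_logits(
    logits,
    centers=centers,
    sigma=sigma,
    bin_edges=bin_edges,
    taus=[1.61, 10.0],  # horizons must sit on finite bin edges
    return_survival=True,
)
print(cif.shape, survival.shape)  # torch.Size([4, 3, 2]) torch.Size([4, 2])

# The discrete-time recursion telescopes: all-cause CIF plus survival
# is ~1 at every requested horizon (up to the eps clamps).
print(torch.allclose(cif.sum(dim=1) + survival, torch.ones(B, 2), atol=1e-5))
```

That identity is a cheap regression test for the `s_prev`/`cif_bins` indexing.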
diff --git a/losses.py b/losses.py
index e251e4d..6029d4f 100644
--- a/losses.py
+++ b/losses.py
@@ -260,24 +260,27 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
         return nll, reg
 
 
-class LogNormalBasisHazardLoss(nn.Module):
-    """Event-only competing risks loss using lognormal basis (Gaussian on log-time).
+class LogNormalBasisBinnedHazardCIFNLLLoss(nn.Module):
+    r"""Route-3: continuous-time lognormal-basis hazards with discrete-time CIF likelihood.
 
-    This loss models cause-specific CIF as a mixture of lognormal basis CDFs:
+    This implements a cause-specific continuous-time hazard model:
 
-        F_j(t) = sum_r w_{j,r} * Phi((log t - mu_r) / sigma)
+        \lambda_j(t) = \sum_r \alpha_{j,r} b_r(t)
 
-    Training uses *bin probability mass* (Delta CIF / interval mass). There is
-    **no censoring**: every sample is an observed event with a valid cause.
+    where b_r(t) is the lognormal PDF basis implied by a Normal on log-time.
 
-    Logits interface:
-        logits: (B, 1 + J*R)
-        logits[:, 0]  -> w0 (survival mass / never-event)
-        logits[:, 1:] -> flattened (j,r) in row-major order: j then r
-                         index = 1 + j*R + r
+    The training objective is identical in structure to DiscreteTimeCIFNLLLoss,
+    but the per-bin categorical probabilities are derived from integrated hazards.
+
+    Expected logits interface (preferred):
+        logits: (B, J*R)  reshaped to (B, J, R)
+
+    For convenience/compatibility, also accepts:
+        logits: (B, 1+J*R), ignoring the first column.
 
     Forward interface (must match):
-        forward(logits, target_events, dt, reduction)
+        forward(logits, target_events, dt, reduction) -> (nll, reg)
     """
 
     def __init__(
@@ -286,17 +289,20 @@ class LogNormalBasisHazardLoss(nn.Module):
         centers: Sequence[float],
         *,
         eps: float = 1e-8,
-        bandwidth_init: float = 0.5,
+        alpha_floor: float = 0.0,
+        bandwidth_init: float = 0.7,
         bandwidth_min: float = 1e-3,
         bandwidth_max: float = 10.0,
         lambda_sigma_reg: float = 0.0,
         sigma_reg_target: Optional[float] = None,
-        return_dict: bool = False,
+        lambda_reg: float = 0.0,
     ):
         super().__init__()
         if len(bin_edges) < 2:
-            raise ValueError("bin_edges must have length >= 2")
+            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
+        if float(bin_edges[0]) != 0.0:
+            raise ValueError("bin_edges[0] must equal 0")
         # allow last edge to be +inf
         for i in range(1, len(bin_edges)):
             prev = float(bin_edges[i - 1])
@@ -310,20 +316,19 @@ class LogNormalBasisHazardLoss(nn.Module):
             else:
                 if not (cur > prev):
                     raise ValueError("bin_edges must be strictly increasing")
-        if float(bin_edges[0]) < 0.0:
-            raise ValueError("bin_edges[0] must be >= 0")
         if len(centers) < 1:
             raise ValueError("centers must have length >= 1")
 
         self.eps = float(eps)
+        self.alpha_floor = float(alpha_floor)
         self.bandwidth_min = float(bandwidth_min)
         self.bandwidth_max = float(bandwidth_max)
+        self.bandwidth_init = float(bandwidth_init)
         self.lambda_sigma_reg = float(lambda_sigma_reg)
         self.sigma_reg_target = None if sigma_reg_target is None else float(
             sigma_reg_target)
-        self.bandwidth_init = float(bandwidth_init)
-        self.return_dict = bool(return_dict)
+        self.lambda_reg = float(lambda_reg)
 
         self.register_buffer(
             "bin_edges",
@@ -343,69 +348,32 @@ class LogNormalBasisHazardLoss(nn.Module):
 
     @staticmethod
     def _normal_cdf(z: torch.Tensor) -> torch.Tensor:
-        # Stable normal CDF via erf.
         z = torch.clamp(z, -12.0, 12.0)
         return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
 
     @staticmethod
     def _normal_sf(z: torch.Tensor) -> torch.Tensor:
-        # Stable normal survival function via erfc.
         z = torch.clamp(z, -12.0, 12.0)
         return 0.5 * torch.erfc(z / math.sqrt(2.0))
 
-    def forward(
+    def _compute_delta_basis_all_bins(
         self,
-        logits: torch.Tensor,
-        target_events: torch.Tensor,
-        dt: torch.Tensor,
-        reduction: str = "mean",
-    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, Any]]:
-        if logits.ndim != 2:
-            raise ValueError(
-                f"logits must be 2D with shape (B, 1+J*R); got {tuple(logits.shape)}")
-        if target_events.ndim != 1 or dt.ndim != 1:
-            raise ValueError("target_events and dt must be 1D tensors")
-        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
-            raise ValueError(
-                "Batch size mismatch among logits, target_events, dt")
-        if reduction not in {"mean", "sum", "none"}:
-            raise ValueError("reduction must be one of {'mean','sum','none'}")
-
-        device = logits.device
-        dtype = logits.dtype
+        *,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """Compute ΔB[k,r] for bins k=1..n_bins (shape: (n_bins, R))."""
         bin_edges = self.bin_edges.to(device=device, dtype=dtype)
         centers = self.centers.to(device=device, dtype=dtype)
 
-        bsz = logits.shape[0]
-        r = int(centers.numel())
-        jr = int(logits.shape[1] - 1)
-        if jr <= 0:
-            raise ValueError(
-                "logits.shape[1] must be >= 2 (w0 + at least one (j,r) weight)")
-        if jr % r != 0:
-            raise ValueError(
-                f"(logits.shape[1]-1) must be divisible by R={r}; got {jr}")
-        j = jr // r
-
-        # 1) Stable global weights (includes w0).
-        w_all = torch.softmax(logits, dim=-1)  # (B, 1+J*R)
-        w0 = w_all[:, 0]
-        w = w_all[:, 1:].view(bsz, j, r)
-
-        # 2) Determine event bin index.
-        k = int(bin_edges.numel() - 1)
-        if k < 1:
+        n_bins = int(bin_edges.numel() - 1)
+        if n_bins < 1:
             raise ValueError("bin_edges must define at least one bin")
 
-        # v2: dt is always continuous time (float), map to bin via searchsorted.
-        dt_f = dt.to(device=device, dtype=dtype)
-        bin_idx = torch.searchsorted(bin_edges, dt_f, right=True) - 1
-        bin_idx = torch.clamp(bin_idx, 0, k - 1).to(torch.long)
+        left = bin_edges[:-1]   # (n_bins,)
+        right = bin_edges[1:]   # (n_bins,)
 
-        left = bin_edges[bin_idx]
-        right = bin_edges[bin_idx + 1]
-
-        # 3) Stable log(t) clamp.
         if float(self.bin_edges[1]) > 0.0:
             t_min = float(self.bin_edges[1]) * 1e-6
         else:
@@ -413,19 +381,19 @@ class LogNormalBasisHazardLoss(nn.Module):
         t_min_t = torch.tensor(t_min, device=device, dtype=dtype)
 
         left_is_zero = left <= 0
-
-        # For log() we still need a positive clamp, but we will treat CDF(left)=0 exactly
-        # when left<=0 (instead of approximating via t_min).
         left_clamped = torch.clamp(left, min=t_min_t)
         log_left = torch.log(left_clamped)
+
         is_inf = torch.isinf(right)
-        # right might be +inf for last bin; avoid log(+inf) by substituting a safe finite value.
-        right_safe = torch.where(is_inf, left_clamped,
-                                 torch.clamp(right, min=t_min_t))
+        right_safe = torch.where(
+            is_inf, left_clamped, torch.clamp(right, min=t_min_t))
         log_right = torch.log(right_safe)
 
-        sigma = torch.clamp(self.log_sigma.to(
-            device=device, dtype=dtype).exp(), self.bandwidth_min, self.bandwidth_max)
+        sigma = torch.clamp(
+            self.log_sigma.to(device=device, dtype=dtype).exp(),
+            self.bandwidth_min,
+            self.bandwidth_max,
+        )
 
         z_left = (log_left.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
         z_right = (log_right.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
@@ -433,58 +401,152 @@ class LogNormalBasisHazardLoss(nn.Module):
         z_right = torch.clamp(z_right, -12.0, 12.0)
 
         cdf_left = self._normal_cdf(z_left)
-        # Treat the first-bin left boundary exactly as 0 in CDF.
         if left_is_zero.any():
             cdf_left = torch.where(
                 left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left)
+
         cdf_right = self._normal_cdf(z_right)
 
         delta_finite = cdf_right - cdf_left
+
+        # Last bin: ΔB = 1 - CDF(left) = SF(left), computed via erfc for stability.
         delta_last = self._normal_sf(z_left)
-        # If left<=0, SF(left)=1 exactly.
         if left_is_zero.any():
             delta_last = torch.where(
                 left_is_zero.unsqueeze(-1), torch.ones_like(delta_last), delta_last)
+
         delta_basis = torch.where(
             is_inf.unsqueeze(-1), delta_last, delta_finite)
         delta_basis = torch.clamp(delta_basis, min=0.0)
+        return delta_basis
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        target_events: torch.Tensor,
+        dt: torch.Tensor,
+        reduction: str = "mean",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if logits.ndim not in {2, 3}:
+            raise ValueError(
+                f"logits must be 2D (B, J*R) (or (B, 1+J*R)) or 3D (B, J, R); got {tuple(logits.shape)}"
+            )
+        if target_events.ndim != 1 or dt.ndim != 1:
+            raise ValueError(
+                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
+            )
+        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
+            raise ValueError(
+                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
+            )
+        if reduction not in {"mean", "sum", "none"}:
+            raise ValueError("reduction must be one of {'mean','sum','none'}")
+
+        if not torch.all(dt > 0):
+            raise ValueError("dt must be strictly positive")
+
+        device = logits.device
+        dtype = logits.dtype
+
+        centers = self.centers.to(device=device, dtype=dtype)
+        r = int(centers.numel())
+        if r < 1:
+            raise ValueError("centers must have length >= 1")
+
+        if logits.ndim == 3:
+            if logits.shape[2] != r:
+                raise ValueError(
+                    f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
+                )
+            j = int(logits.shape[1])
+            if j < 1:
+                raise ValueError("Inferred number of causes J must be >= 1")
+            alpha = F.softplus(logits) + self.alpha_floor  # (B, J, R)
+        else:
+            d = int(logits.shape[1])
+            offset = 0
+            if d % r == 0:
+                jr = d
+            elif (d - 1) % r == 0:
+                offset = 1
+                jr = d - 1
+            else:
+                raise ValueError(
+                    f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}"
+                )
+
+            j = jr // r
+            if j < 1:
+                raise ValueError("Inferred number of causes J must be >= 1")
+
+            logits_used = logits[:, offset:]
+            alpha = F.softplus(logits_used).view(-1, j, r) + self.alpha_floor  # (B, J, R)
+
+        delta_basis = self._compute_delta_basis_all_bins(
+            device=device,
+            dtype=dtype,
+        )  # (n_bins, R)
+        n_bins = int(delta_basis.shape[0])
+
+        # H_{j,k} = sum_r alpha_{j,r} * ΔB_{k,r}
+        h_jk = torch.einsum("mjr,kr->mjk", alpha, delta_basis)  # (B, J, n_bins)
+        h_k = h_jk.sum(dim=1)  # (B, n_bins)
+
+        # Map continuous dt to discrete event bin index k* in {1..n_bins}.
+        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
+        time_bin = torch.bucketize(dt, bin_edges)
+        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(torch.long)
 
-        # 4) Gather per-sample cause weights and compute event mass.
         cause = target_events.to(device=device, dtype=torch.long)
         if (cause < 0).any() or (cause >= j).any():
             raise ValueError(f"target_events must be in [0, J-1] where J={j}")
 
-        b_idx = torch.arange(bsz, device=device)
-        w_cause = w[b_idx, cause, :]  # (B, R)
+        # Previous survival term: sum_{u<k*} H_u, the integrated hazard of the
+        # bins fully survived before the event bin.
+        b_idx = torch.arange(logits.shape[0], device=device)
+        k0 = time_bin - 1  # 0-indexed event bin
+        h_cum = torch.cumsum(h_k, dim=1)  # (B, n_bins)
+        h_event_total = h_k[b_idx, k0]  # (B,)
+        h_prev = h_cum[b_idx, k0] - h_event_total  # (B,)
+
+        # Event-bin term: log[(1 - exp(-H_{k*})) * H_{j,k*} / H_{k*}].
+        h_event_cause = h_jk[b_idx, cause, k0]  # (B,)
+        one_minus = -torch.expm1(-h_event_total)  # 1 - exp(-H_{k*})
+        log_p_event = (
+            torch.log(torch.clamp(one_minus, min=self.eps))
+            + torch.log(torch.clamp(h_event_cause, min=self.eps))
+            - torch.log(torch.clamp(h_event_total, min=self.eps))
+        )
+
+        nll_vec = h_prev - log_p_event  # (B,)
+        if reduction == "mean":
+            nll = nll_vec.mean()
+        elif reduction == "sum":
+            nll = nll_vec.sum()
+        else:
+            nll = nll_vec
+
+        reg = torch.zeros((), device=device, dtype=dtype)
+        if self.lambda_reg > 0.0:
+            # Regularize the within-bin cause competition via NLL on log ratios.
+            # ratio_j = H_{j,k*} / H_{k*}
+            h_event_all = h_jk[b_idx, :, k0]  # (B, J)
+            denom = torch.clamp(h_event_total, min=self.eps).unsqueeze(1)
+            ratio = torch.clamp(h_event_all / denom, min=self.eps)
+            log_ratio = torch.log(ratio)
+            reg = reg + self.lambda_reg * F.nll_loss(
+                log_ratio, cause, reduction="mean")
 
-        sigma_penalty = torch.zeros((), device=device, dtype=dtype)
         if self.lambda_sigma_reg > 0.0:
             target = self.bandwidth_init if self.sigma_reg_target is None else self.sigma_reg_target
             sigma_penalty = (self.log_sigma.to(
                 device=device, dtype=dtype) - math.log(float(target))) ** 2
-            reg = sigma_penalty * float(self.lambda_sigma_reg)
+            reg = reg + sigma_penalty * self.lambda_sigma_reg
 
-        if not self.return_dict:
-            return nll, reg
-
-        return {
-            "nll": nll,
-            "reg": reg,
-            "nll_vec": nll_vec,
-            "sigma": sigma.detach(),
-            "avg_w0": w0.mean().detach(),
-            "min_delta_basis": delta_basis.min().detach(),
-            "mean_m": m.mean().detach(),
-            "sigma_penalty": sigma_penalty.detach(),
-            "bin_idx": bin_idx.detach(),
-        }
+        return nll, reg
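Reviewer note: the new loss can be exercised standalone. A minimal sketch of the training-side contract (synthetic batch; assumes the repo root is importable; constructor kwargs mirror the `Trainer` wiring in train.py below):

```python
import torch

from losses import LogNormalBasisBinnedHazardCIFNLLLoss

bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
centers = [-1.0, 0.0, 1.0, 2.0, 3.0]  # mu_r in log-time (made-up values)
B, J, R = 16, 3, len(centers)

criterion = LogNormalBasisBinnedHazardCIFNLLLoss(
    bin_edges=bin_edges,
    centers=centers,
    lambda_reg=0.1,         # within-bin cause-competition term
    lambda_sigma_reg=1e-4,  # keeps sigma near bandwidth_init
)

logits = torch.randn(B, J, R, requires_grad=True)  # preferred head shape
target_events = torch.randint(0, J, (B,))          # every sample has an observed cause
dt = torch.rand(B) * 30.0 + 0.05                   # continuous time-to-event, strictly > 0

nll, reg = criterion(logits, target_events, dt, reduction="mean")
(nll + reg).backward()
print(float(nll), float(reg))
```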
diff --git a/train.py b/train.py
index 15deed4..183424f 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,4 @@
-from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisHazardLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisBinnedHazardCIFNLLLoss, get_valid_pairs_and_dt
 from model import DelphiFork, SapDelphi, SimpleHead
 from dataset import HealthDataset, health_collate_fn
 from tqdm import tqdm
@@ -22,7 +22,7 @@ class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
     loss_type: Literal['exponential', 'discrete_time_cif',
-                       'lognormal_basis_hazard'] = 'exponential'
+                       'lognormal_basis_binned_hazard_cif'] = 'exponential'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -34,7 +34,7 @@ class TrainConfig:
         default_factory=lambda: [0.0, 0.24, 0.72, 1.61,
                                  3.84, 10.0, 31.0, float('inf')]
     )
-    # LogNormalBasisHazardLoss specific
+    # LogNormal basis (shared by Route-3 binned hazard)
     lognormal_centers: Optional[Sequence[float]] = field(
         default_factory=list)  # mu_r in log-time
     loss_eps: float = 1e-8
@@ -73,7 +73,8 @@ def parse_args() -> TrainConfig:
     parser.add_argument(
         "--loss_type",
         type=str,
-        choices=['exponential', 'discrete_time_cif', 'lognormal_basis_hazard'],
+        choices=['exponential', 'discrete_time_cif',
+                 'lognormal_basis_binned_hazard_cif'],
         default='exponential',
         help="Type of loss function to use.")
     parser.add_argument("--age_encoder", type=str, choices=[
@@ -93,17 +94,17 @@ def parse_args() -> TrainConfig:
         type=float,
         nargs='*',
         default=None,
-        help="LogNormalBasisHazardLoss centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
+        help="LogNormal basis centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
     parser.add_argument("--loss_eps", type=float, default=1e-8,
-                        help="Epsilon for LogNormalBasisHazardLoss log clamp.")
+                        help="Epsilon for log clamps in lognormal-basis losses.")
     parser.add_argument("--bandwidth_init", type=float, default=0.7,
-                        help="Initial sigma for LogNormalBasisHazardLoss.")
+                        help="Initial sigma for the lognormal basis.")
     parser.add_argument("--bandwidth_min", type=float, default=1e-3,
-                        help="Minimum sigma clamp for LogNormalBasisHazardLoss.")
+                        help="Minimum sigma clamp for the lognormal basis.")
     parser.add_argument("--bandwidth_max", type=float, default=10.0,
-                        help="Maximum sigma clamp for LogNormalBasisHazardLoss.")
+                        help="Maximum sigma clamp for the lognormal basis.")
     parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
-                        help="Sigma regularization strength for LogNormalBasisHazardLoss.")
+                        help="Sigma regularization strength for the lognormal basis.")
     parser.add_argument("--sigma_reg_target", type=float, default=None,
                         help="Optional sigma target for regularization (otherwise uses bandwidth_init).")
     parser.add_argument("--rank", type=int, default=16,
@@ -261,12 +262,12 @@ class Trainer:
             ).to(self.device)
             # logits shape (M, K+1, n_bins+1)
             out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
-        elif cfg.loss_type == "lognormal_basis_hazard":
+        elif cfg.loss_type == "lognormal_basis_binned_hazard_cif":
             r = len(cfg.lognormal_centers)
             if r <= 0:
                 raise ValueError(
-                    "lognormal_centers must be non-empty for lognormal_basis_hazard")
-            self.criterion = LogNormalBasisHazardLoss(
+                    "lognormal_centers must be non-empty for lognormal_basis_binned_hazard_cif")
+            self.criterion = LogNormalBasisBinnedHazardCIFNLLLoss(
                 bin_edges=cfg.bin_edges,
                 centers=cfg.lognormal_centers,
                 eps=cfg.loss_eps,
                 bandwidth_init=cfg.bandwidth_init,
                 bandwidth_min=cfg.bandwidth_min,
                 bandwidth_max=cfg.bandwidth_max,
                 lambda_sigma_reg=cfg.lambda_sigma_reg,
                 sigma_reg_target=cfg.sigma_reg_target,
+                lambda_reg=cfg.lambda_reg,
             ).to(self.device)
-            # logits shape (M, 1 + J*R)
-            out_dims = [1 + dataset.n_disease * r]
+            # Head emits (M, J, R) for Route-3.
+            out_dims = [dataset.n_disease, r]
         else:
             raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
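Reviewer note: one last check of the compatibility shim described in the loss docstring. The preferred `(B, J, R)` layout, the flattened `(B, J*R)` layout, and a legacy `(B, 1+J*R)` head should all score identically, because the legacy leading column is sliced off before `softplus`. A sketch with made-up values:

```python
import torch

from losses import LogNormalBasisBinnedHazardCIFNLLLoss

B, J, R = 4, 2, 2
criterion = LogNormalBasisBinnedHazardCIFNLLLoss(
    bin_edges=[0.0, 1.0, 4.0, float("inf")],
    centers=[0.0, 1.0],
)

target_events = torch.randint(0, J, (B,))
dt = torch.rand(B) * 5.0 + 0.1

flat = torch.randn(B, J * R)                          # (B, J*R)
shaped = flat.view(B, J, R)                           # (B, J, R)
legacy = torch.cat([torch.zeros(B, 1), flat], dim=1)  # (B, 1+J*R), column 0 ignored

nlls = [criterion(x, target_events, dt)[0] for x in (shaped, flat, legacy)]
print([round(float(v), 6) for v in nlls])  # all three layouts agree
```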