diff --git a/evaluate_models.py b/evaluate_models.py index 1e15dcc..7ed5cf7 100644 --- a/evaluate_models.py +++ b/evaluate_models.py @@ -360,7 +360,7 @@ def cifs_from_discrete_time_logits( logits: (B, K+1, n_bins+1) bin_edges: len=n_bins+1 (including 0 and inf) - taus: subset of finite bin edges (recommended) + taus: arbitrary non-negative horizons (need not align to bin edges) returns: (B, K, H) or (cif, survival) if return_survival """ @@ -392,31 +392,87 @@ def cifs_from_discrete_time_logits( cif_bins = torch.cumsum(s_prev.unsqueeze( 1) * hazards, dim=2) # (B,K,n_bins) - # Robust mapping from tau -> edge index (floating-point safe). - # taus are expected to align with bin edges, but may differ slightly due to parsing/serialization. finite_edges_arr = np.asarray(finite_edges, dtype=float) - tau_to_idx: List[int] = [] + if finite_edges_arr.ndim != 1 or finite_edges_arr.size != n_bins: + raise ValueError("Unexpected finite_edges shape") + if finite_edges_arr.size == 0: + raise ValueError("No finite bin edges provided") + if np.any(~np.isfinite(finite_edges_arr)): + raise ValueError("finite_edges must be finite") + if np.any(np.diff(finite_edges_arr) <= 0): + raise ValueError("finite bin edges must be strictly increasing") + + # For an arbitrary horizon tau, compute CIF/S(tau) as full bins before the + # bin containing tau, plus a partial-bin contribution. + # Assumption: within each bin, total hazard is constant and cause-specific + # hazards are proportional. + u_list: List[int] = [] + frac_list: List[float] = [] for tau in taus: tau_f = float(tau) if not math.isfinite(tau_f): raise ValueError("taus must be finite for discrete-time CIF") - diffs = np.abs(finite_edges_arr - tau_f) - j = int(np.argmin(diffs)) - if diffs[j] > 1e-6: - raise ValueError( - f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})" - ) - tau_to_idx.append(j) + if tau_f < 0.0: + raise ValueError("taus must be non-negative") - idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long) - cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H) + if tau_f == 0.0: + u_list.append(0) + frac_list.append(0.0) + continue + + # u is the first bin whose end edge >= tau. + u = int(np.searchsorted(finite_edges_arr, tau_f, side="left")) + if u >= n_bins: + # Past the last finite edge: clamp to the last bin end. + u_list.append(n_bins - 1) + frac_list.append(1.0) + continue + + start = 0.0 if u == 0 else float(finite_edges_arr[u - 1]) + end = float(finite_edges_arr[u]) + width = end - start + if width <= 0.0: + raise ValueError("Invalid bin edges: non-positive bin width") + frac = (tau_f - start) / width + u_list.append(u) + frac_list.append(float(min(1.0, max(0.0, frac)))) + + u_idx = torch.tensor(u_list, device=logits.device, dtype=torch.long) # (H,) + frac_t = torch.tensor(frac_list, device=logits.device, + dtype=probs.dtype) # (H,) + + # CIF before the containing bin (pad with a leading zero for u==0) + zeros_cif = torch.zeros((B, K, 1), device=logits.device, dtype=probs.dtype) + cif_padded = torch.cat([zeros_cif, cif_bins], dim=2) # (B,K,n_bins+1) + cif_before = cif_padded.index_select(dim=2, index=u_idx) # (B,K,H) + + # Survival before the containing bin + s_u = s_prev.index_select(dim=1, index=u_idx) # (B,H) + + # Select per-bin parameters for the containing bin + hazards_u = hazards.index_select(dim=2, index=u_idx) # (B,K,H) + p_comp_u = p_comp.index_select(dim=1, index=u_idx) # (B,H) + + eps = 1e-12 + p_comp_u_clamped = torch.clamp(p_comp_u, min=eps, max=1.0) + h_total = -torch.log(p_comp_u_clamped) # (B,H) + + denom = 1.0 - p_comp_u + no_event = denom <= eps + denom_safe = torch.clamp(denom, min=eps) + ratio = hazards_u / denom_safe.unsqueeze(1) + ratio = torch.where(no_event.unsqueeze(1), torch.zeros_like(ratio), ratio) + + h_partial = h_total * frac_t.unsqueeze(0) # (B,H) + one_minus_partial = -torch.expm1(-h_partial) # (B,H) + p_event_partial = one_minus_partial.unsqueeze(1) * ratio # (B,K,H) + + cif = cif_before + s_u.unsqueeze(1) * p_event_partial # (B,K,H) if not return_survival: return cif - # Survival at each horizon = prod_{u <= idx[h]} p_comp[u] - survival_bins = cum # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v] - survival = survival_bins.index_select(dim=1, index=idx) # (B,H) + survival = s_u * torch.exp(-h_partial) # (B,H) return cif, survival @@ -439,7 +495,7 @@ def cifs_from_lognormal_basis_binned_hazard_logits( """Convert Route-3 binned hazard logits -> CIFs at taus. logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored). - taus are expected to align with finite bin edges. + taus may be any non-negative horizons (need not align to bin edges). """ if logits.ndim not in {2, 3}: raise ValueError("logits must be 2D or 3D") @@ -537,26 +593,70 @@ def cifs_from_lognormal_basis_binned_hazard_logits( 1) * p_event, dim=2) # (B,J,n_bins) finite_edges_arr = np.asarray(finite_edges, dtype=float) - tau_to_idx: List[int] = [] + if finite_edges_arr.ndim != 1 or finite_edges_arr.size != n_bins: + raise ValueError("Unexpected finite_edges shape") + if finite_edges_arr.size == 0: + raise ValueError("No finite bin edges provided") + if np.any(~np.isfinite(finite_edges_arr)): + raise ValueError("finite_edges must be finite") + if np.any(np.diff(finite_edges_arr) <= 0): + raise ValueError("finite bin edges must be strictly increasing") + + u_list: List[int] = [] + frac_list: List[float] = [] for tau in taus: tau_f = float(tau) if not math.isfinite(tau_f): raise ValueError("taus must be finite for discrete-time CIF") - diffs = np.abs(finite_edges_arr - tau_f) - idx0 = int(np.argmin(diffs)) - if diffs[idx0] > 1e-6: - raise ValueError( - f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})" - ) - tau_to_idx.append(idx0) + if tau_f < 0.0: + raise ValueError("taus must be non-negative") - idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long) - cif = cif_bins.index_select(dim=2, index=idx) # (B,J,H) + if tau_f == 0.0: + u_list.append(0) + frac_list.append(0.0) + continue + + u = int(np.searchsorted(finite_edges_arr, tau_f, side="left")) + if u >= n_bins: + u_list.append(n_bins - 1) + frac_list.append(1.0) + continue + + start = 0.0 if u == 0 else float(finite_edges_arr[u - 1]) + end = float(finite_edges_arr[u]) + width = end - start + if width <= 0.0: + raise ValueError("Invalid bin edges: non-positive bin width") + frac = (tau_f - start) / width + u_list.append(u) + frac_list.append(float(min(1.0, max(0.0, frac)))) + + u_idx = torch.tensor(u_list, device=device, dtype=torch.long) # (H,) + frac_t = torch.tensor(frac_list, device=device, dtype=dtype) # (H,) + + zeros_cif = torch.zeros((alpha.size(0), j, 1), device=device, dtype=dtype) + cif_padded = torch.cat([zeros_cif, cif_bins], dim=2) # (B,J,n_bins+1) + cif_before = cif_padded.index_select(dim=2, index=u_idx) # (B,J,H) + + s_u = s_prev.index_select(dim=1, index=u_idx) # (B,H) + h_total = h_k.index_select(dim=1, index=u_idx) # (B,H) + h_j_total = h_jk.index_select(dim=2, index=u_idx) # (B,J,H) + + h_total_safe = torch.clamp(h_total, min=eps) + ratio = h_j_total / h_total_safe.unsqueeze(1) + ratio = torch.where(h_total.unsqueeze(1) <= eps, + torch.zeros_like(ratio), ratio) + + h_partial = h_total * frac_t.unsqueeze(0) # (B,H) + one_minus_partial = -torch.expm1(-h_partial) # (B,H) + p_event_partial = one_minus_partial.unsqueeze(1) * ratio # (B,J,H) + + cif = cif_before + s_u.unsqueeze(1) * p_event_partial # (B,J,H) if not return_survival: return cif - survival = cum.index_select(dim=1, index=idx) # (B,H) + survival = s_u * torch.exp(-h_partial) # (B,H) return cif, survival