Refactor CIF computation to support arbitrary non-negative horizons and improve error handling for finite bin edges

This commit is contained in:
2026-01-14 11:29:40 +08:00
parent f121574872
commit 2170b61d87

View File

@@ -360,7 +360,7 @@ def cifs_from_discrete_time_logits(
logits: (B, K+1, n_bins+1)
bin_edges: len=n_bins+1 (including 0 and inf)
taus: subset of finite bin edges (recommended)
taus: arbitrary non-negative horizons (need not align to bin edges)
returns: (B, K, H) or (cif, survival) if return_survival
"""
@@ -392,31 +392,87 @@ def cifs_from_discrete_time_logits(
cif_bins = torch.cumsum(s_prev.unsqueeze(
1) * hazards, dim=2) # (B,K,n_bins)
# Robust mapping from tau -> edge index (floating-point safe).
# taus are expected to align with bin edges, but may differ slightly due to parsing/serialization.
finite_edges_arr = np.asarray(finite_edges, dtype=float)
tau_to_idx: List[int] = []
if finite_edges_arr.ndim != 1 or finite_edges_arr.size != n_bins:
raise ValueError("Unexpected finite_edges shape")
if finite_edges_arr.size == 0:
raise ValueError("No finite bin edges provided")
if np.any(~np.isfinite(finite_edges_arr)):
raise ValueError("finite_edges must be finite")
if np.any(np.diff(finite_edges_arr) <= 0):
raise ValueError("finite bin edges must be strictly increasing")
# For an arbitrary horizon tau, compute CIF(tau) and S(tau) from the full bins
# that end before the bin containing tau, plus a partial-bin contribution.
# Assumption: within each bin the total hazard is constant and each
# cause-specific hazard is a fixed proportion of that total hazard.
u_list: List[int] = []
frac_list: List[float] = []
for tau in taus:
tau_f = float(tau)
if not math.isfinite(tau_f):
raise ValueError("taus must be finite for discrete-time CIF")
diffs = np.abs(finite_edges_arr - tau_f)
j = int(np.argmin(diffs))
if diffs[j] > 1e-6:
raise ValueError(
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
)
tau_to_idx.append(j)
if tau_f < 0.0:
raise ValueError("taus must be non-negative")
idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H)
if tau_f == 0.0:
u_list.append(0)
frac_list.append(0.0)
continue
# u is the first bin whose end edge >= tau.
u = int(np.searchsorted(finite_edges_arr, tau_f, side="left"))
if u >= n_bins:
# Past the last finite edge: clamp to the last bin end.
u_list.append(n_bins - 1)
frac_list.append(1.0)
continue
start = 0.0 if u == 0 else float(finite_edges_arr[u - 1])
end = float(finite_edges_arr[u])
width = end - start
if width <= 0.0:
raise ValueError("Invalid bin edges: non-positive bin width")
frac = (tau_f - start) / width
u_list.append(u)
frac_list.append(float(min(1.0, max(0.0, frac))))
u_idx = torch.tensor(u_list, device=logits.device, dtype=torch.long) # (H,)
frac_t = torch.tensor(frac_list, device=logits.device,
dtype=probs.dtype) # (H,)
# CIF before the containing bin (pad with a leading zero for u==0)
zeros_cif = torch.zeros((B, K, 1), device=logits.device, dtype=probs.dtype)
cif_padded = torch.cat([zeros_cif, cif_bins], dim=2) # (B,K,n_bins+1)
cif_before = cif_padded.index_select(dim=2, index=u_idx) # (B,K,H)
# Survival before the containing bin
s_u = s_prev.index_select(dim=1, index=u_idx) # (B,H)
# Select per-bin parameters for the containing bin
hazards_u = hazards.index_select(dim=2, index=u_idx) # (B,K,H)
p_comp_u = p_comp.index_select(dim=1, index=u_idx) # (B,H)
eps = 1e-12
p_comp_u_clamped = torch.clamp(p_comp_u, min=eps, max=1.0)
h_total = -torch.log(p_comp_u_clamped) # (B,H)
denom = 1.0 - p_comp_u
no_event = denom <= eps
denom_safe = torch.clamp(denom, min=eps)
ratio = hazards_u / denom_safe.unsqueeze(1)
ratio = torch.where(no_event.unsqueeze(1), torch.zeros_like(ratio), ratio)
h_partial = h_total * frac_t.unsqueeze(0) # (B,H)
one_minus_partial = -torch.expm1(-h_partial) # (B,H)
p_event_partial = one_minus_partial.unsqueeze(1) * ratio # (B,K,H)
cif = cif_before + s_u.unsqueeze(1) * p_event_partial # (B,K,H)
if not return_survival:
return cif
# Survival at each horizon = prod_{u <= idx[h]} p_comp[u]
survival_bins = cum # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v]
survival = survival_bins.index_select(dim=1, index=idx) # (B,H)
survival = s_u * torch.exp(-h_partial) # (B,H)
return cif, survival
@@ -439,7 +495,7 @@ def cifs_from_lognormal_basis_binned_hazard_logits(
"""Convert Route-3 binned hazard logits -> CIFs at taus.
logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
taus are expected to align with finite bin edges.
taus may be any non-negative horizons (need not align to bin edges).
"""
if logits.ndim not in {2, 3}:
raise ValueError("logits must be 2D or 3D")
@@ -537,26 +593,70 @@ def cifs_from_lognormal_basis_binned_hazard_logits(
1) * p_event, dim=2) # (B,J,n_bins)
finite_edges_arr = np.asarray(finite_edges, dtype=float)
tau_to_idx: List[int] = []
if finite_edges_arr.ndim != 1 or finite_edges_arr.size != n_bins:
raise ValueError("Unexpected finite_edges shape")
if finite_edges_arr.size == 0:
raise ValueError("No finite bin edges provided")
if np.any(~np.isfinite(finite_edges_arr)):
raise ValueError("finite_edges must be finite")
if np.any(np.diff(finite_edges_arr) <= 0):
raise ValueError("finite bin edges must be strictly increasing")
u_list: List[int] = []
frac_list: List[float] = []
for tau in taus:
tau_f = float(tau)
if not math.isfinite(tau_f):
raise ValueError("taus must be finite for discrete-time CIF")
diffs = np.abs(finite_edges_arr - tau_f)
idx0 = int(np.argmin(diffs))
if diffs[idx0] > 1e-6:
raise ValueError(
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})"
)
tau_to_idx.append(idx0)
if tau_f < 0.0:
raise ValueError("taus must be non-negative")
idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long)
cif = cif_bins.index_select(dim=2, index=idx) # (B,J,H)
if tau_f == 0.0:
u_list.append(0)
frac_list.append(0.0)
continue
u = int(np.searchsorted(finite_edges_arr, tau_f, side="left"))
if u >= n_bins:
u_list.append(n_bins - 1)
frac_list.append(1.0)
continue
start = 0.0 if u == 0 else float(finite_edges_arr[u - 1])
end = float(finite_edges_arr[u])
width = end - start
if width <= 0.0:
raise ValueError("Invalid bin edges: non-positive bin width")
frac = (tau_f - start) / width
u_list.append(u)
frac_list.append(float(min(1.0, max(0.0, frac))))
u_idx = torch.tensor(u_list, device=device, dtype=torch.long) # (H,)
frac_t = torch.tensor(frac_list, device=device, dtype=dtype) # (H,)
zeros_cif = torch.zeros((alpha.size(0), j, 1), device=device, dtype=dtype)
cif_padded = torch.cat([zeros_cif, cif_bins], dim=2) # (B,J,n_bins+1)
cif_before = cif_padded.index_select(dim=2, index=u_idx) # (B,J,H)
s_u = s_prev.index_select(dim=1, index=u_idx) # (B,H)
h_total = h_k.index_select(dim=1, index=u_idx) # (B,H)
h_j_total = h_jk.index_select(dim=2, index=u_idx) # (B,J,H)
h_total_safe = torch.clamp(h_total, min=eps)
ratio = h_j_total / h_total_safe.unsqueeze(1)
ratio = torch.where(h_total.unsqueeze(1) <= eps,
torch.zeros_like(ratio), ratio)
h_partial = h_total * frac_t.unsqueeze(0) # (B,H)
one_minus_partial = -torch.expm1(-h_partial) # (B,H)
p_event_partial = one_minus_partial.unsqueeze(1) * ratio # (B,J,H)
cif = cif_before + s_u.unsqueeze(1) * p_event_partial # (B,J,H)
if not return_survival:
return cif
survival = cum.index_select(dim=1, index=idx) # (B,H)
survival = s_u * torch.exp(-h_partial) # (B,H)
return cif, survival