Refactor CIF computation to support arbitrary non-negative horizons and improve error handling for finite bin edges
This commit is contained in:
@@ -360,7 +360,7 @@ def cifs_from_discrete_time_logits(
|
||||
|
||||
logits: (B, K+1, n_bins+1)
|
||||
bin_edges: len=n_bins+1 (including 0 and inf)
|
||||
taus: subset of finite bin edges (recommended)
|
||||
taus: arbitrary non-negative horizons (need not align to bin edges)
|
||||
|
||||
returns: (B, K, H) or (cif, survival) if return_survival
|
||||
"""
|
||||
@@ -392,31 +392,87 @@ def cifs_from_discrete_time_logits(
|
||||
cif_bins = torch.cumsum(s_prev.unsqueeze(
|
||||
1) * hazards, dim=2) # (B,K,n_bins)
|
||||
|
||||
# Robust mapping from tau -> edge index (floating-point safe).
|
||||
# taus are expected to align with bin edges, but may differ slightly due to parsing/serialization.
|
||||
finite_edges_arr = np.asarray(finite_edges, dtype=float)
|
||||
tau_to_idx: List[int] = []
|
||||
if finite_edges_arr.ndim != 1 or finite_edges_arr.size != n_bins:
|
||||
raise ValueError("Unexpected finite_edges shape")
|
||||
if finite_edges_arr.size == 0:
|
||||
raise ValueError("No finite bin edges provided")
|
||||
if np.any(~np.isfinite(finite_edges_arr)):
|
||||
raise ValueError("finite_edges must be finite")
|
||||
if np.any(np.diff(finite_edges_arr) <= 0):
|
||||
raise ValueError("finite bin edges must be strictly increasing")
|
||||
|
||||
# For an arbitrary horizon tau, compute CIF/S(tau) as full bins before the
|
||||
# bin containing tau, plus a partial-bin contribution.
|
||||
# Assumption: within each bin, total hazard is constant and cause-specific
|
||||
# hazards are proportional.
|
||||
u_list: List[int] = []
|
||||
frac_list: List[float] = []
|
||||
for tau in taus:
|
||||
tau_f = float(tau)
|
||||
if not math.isfinite(tau_f):
|
||||
raise ValueError("taus must be finite for discrete-time CIF")
|
||||
diffs = np.abs(finite_edges_arr - tau_f)
|
||||
j = int(np.argmin(diffs))
|
||||
if diffs[j] > 1e-6:
|
||||
raise ValueError(
|
||||
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
|
||||
)
|
||||
tau_to_idx.append(j)
|
||||
if tau_f < 0.0:
|
||||
raise ValueError("taus must be non-negative")
|
||||
|
||||
idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
|
||||
cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H)
|
||||
if tau_f == 0.0:
|
||||
u_list.append(0)
|
||||
frac_list.append(0.0)
|
||||
continue
|
||||
|
||||
# u is the first bin whose end edge >= tau.
|
||||
u = int(np.searchsorted(finite_edges_arr, tau_f, side="left"))
|
||||
if u >= n_bins:
|
||||
# Past the last finite edge: clamp to the last bin end.
|
||||
u_list.append(n_bins - 1)
|
||||
frac_list.append(1.0)
|
||||
continue
|
||||
|
||||
start = 0.0 if u == 0 else float(finite_edges_arr[u - 1])
|
||||
end = float(finite_edges_arr[u])
|
||||
width = end - start
|
||||
if width <= 0.0:
|
||||
raise ValueError("Invalid bin edges: non-positive bin width")
|
||||
frac = (tau_f - start) / width
|
||||
u_list.append(u)
|
||||
frac_list.append(float(min(1.0, max(0.0, frac))))
|
||||
|
||||
u_idx = torch.tensor(u_list, device=logits.device, dtype=torch.long) # (H,)
|
||||
frac_t = torch.tensor(frac_list, device=logits.device,
|
||||
dtype=probs.dtype) # (H,)
|
||||
|
||||
# CIF before the containing bin (pad with a leading zero for u==0)
|
||||
zeros_cif = torch.zeros((B, K, 1), device=logits.device, dtype=probs.dtype)
|
||||
cif_padded = torch.cat([zeros_cif, cif_bins], dim=2) # (B,K,n_bins+1)
|
||||
cif_before = cif_padded.index_select(dim=2, index=u_idx) # (B,K,H)
|
||||
|
||||
# Survival before the containing bin
|
||||
s_u = s_prev.index_select(dim=1, index=u_idx) # (B,H)
|
||||
|
||||
# Select per-bin parameters for the containing bin
|
||||
hazards_u = hazards.index_select(dim=2, index=u_idx) # (B,K,H)
|
||||
p_comp_u = p_comp.index_select(dim=1, index=u_idx) # (B,H)
|
||||
|
||||
eps = 1e-12
|
||||
p_comp_u_clamped = torch.clamp(p_comp_u, min=eps, max=1.0)
|
||||
h_total = -torch.log(p_comp_u_clamped) # (B,H)
|
||||
|
||||
denom = 1.0 - p_comp_u
|
||||
no_event = denom <= eps
|
||||
denom_safe = torch.clamp(denom, min=eps)
|
||||
ratio = hazards_u / denom_safe.unsqueeze(1)
|
||||
ratio = torch.where(no_event.unsqueeze(1), torch.zeros_like(ratio), ratio)
|
||||
|
||||
h_partial = h_total * frac_t.unsqueeze(0) # (B,H)
|
||||
one_minus_partial = -torch.expm1(-h_partial) # (B,H)
|
||||
p_event_partial = one_minus_partial.unsqueeze(1) * ratio # (B,K,H)
|
||||
|
||||
cif = cif_before + s_u.unsqueeze(1) * p_event_partial # (B,K,H)
|
||||
|
||||
if not return_survival:
|
||||
return cif
|
||||
|
||||
# Survival at each horizon = prod_{u <= idx[h]} p_comp[u]
|
||||
survival_bins = cum # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v]
|
||||
survival = survival_bins.index_select(dim=1, index=idx) # (B,H)
|
||||
survival = s_u * torch.exp(-h_partial) # (B,H)
|
||||
return cif, survival
|
||||
|
||||
|
||||
@@ -439,7 +495,7 @@ def cifs_from_lognormal_basis_binned_hazard_logits(
|
||||
"""Convert Route-3 binned hazard logits -> CIFs at taus.
|
||||
|
||||
logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
|
||||
taus are expected to align with finite bin edges.
|
||||
taus may be any non-negative horizons (need not align to bin edges).
|
||||
"""
|
||||
if logits.ndim not in {2, 3}:
|
||||
raise ValueError("logits must be 2D or 3D")
|
||||
@@ -537,26 +593,70 @@ def cifs_from_lognormal_basis_binned_hazard_logits(
|
||||
1) * p_event, dim=2) # (B,J,n_bins)
|
||||
|
||||
finite_edges_arr = np.asarray(finite_edges, dtype=float)
|
||||
tau_to_idx: List[int] = []
|
||||
if finite_edges_arr.ndim != 1 or finite_edges_arr.size != n_bins:
|
||||
raise ValueError("Unexpected finite_edges shape")
|
||||
if finite_edges_arr.size == 0:
|
||||
raise ValueError("No finite bin edges provided")
|
||||
if np.any(~np.isfinite(finite_edges_arr)):
|
||||
raise ValueError("finite_edges must be finite")
|
||||
if np.any(np.diff(finite_edges_arr) <= 0):
|
||||
raise ValueError("finite bin edges must be strictly increasing")
|
||||
|
||||
u_list: List[int] = []
|
||||
frac_list: List[float] = []
|
||||
for tau in taus:
|
||||
tau_f = float(tau)
|
||||
if not math.isfinite(tau_f):
|
||||
raise ValueError("taus must be finite for discrete-time CIF")
|
||||
diffs = np.abs(finite_edges_arr - tau_f)
|
||||
idx0 = int(np.argmin(diffs))
|
||||
if diffs[idx0] > 1e-6:
|
||||
raise ValueError(
|
||||
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})"
|
||||
)
|
||||
tau_to_idx.append(idx0)
|
||||
if tau_f < 0.0:
|
||||
raise ValueError("taus must be non-negative")
|
||||
|
||||
idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long)
|
||||
cif = cif_bins.index_select(dim=2, index=idx) # (B,J,H)
|
||||
if tau_f == 0.0:
|
||||
u_list.append(0)
|
||||
frac_list.append(0.0)
|
||||
continue
|
||||
|
||||
u = int(np.searchsorted(finite_edges_arr, tau_f, side="left"))
|
||||
if u >= n_bins:
|
||||
u_list.append(n_bins - 1)
|
||||
frac_list.append(1.0)
|
||||
continue
|
||||
|
||||
start = 0.0 if u == 0 else float(finite_edges_arr[u - 1])
|
||||
end = float(finite_edges_arr[u])
|
||||
width = end - start
|
||||
if width <= 0.0:
|
||||
raise ValueError("Invalid bin edges: non-positive bin width")
|
||||
frac = (tau_f - start) / width
|
||||
u_list.append(u)
|
||||
frac_list.append(float(min(1.0, max(0.0, frac))))
|
||||
|
||||
u_idx = torch.tensor(u_list, device=device, dtype=torch.long) # (H,)
|
||||
frac_t = torch.tensor(frac_list, device=device, dtype=dtype) # (H,)
|
||||
|
||||
zeros_cif = torch.zeros((alpha.size(0), j, 1), device=device, dtype=dtype)
|
||||
cif_padded = torch.cat([zeros_cif, cif_bins], dim=2) # (B,J,n_bins+1)
|
||||
cif_before = cif_padded.index_select(dim=2, index=u_idx) # (B,J,H)
|
||||
|
||||
s_u = s_prev.index_select(dim=1, index=u_idx) # (B,H)
|
||||
h_total = h_k.index_select(dim=1, index=u_idx) # (B,H)
|
||||
h_j_total = h_jk.index_select(dim=2, index=u_idx) # (B,J,H)
|
||||
|
||||
h_total_safe = torch.clamp(h_total, min=eps)
|
||||
ratio = h_j_total / h_total_safe.unsqueeze(1)
|
||||
ratio = torch.where(h_total.unsqueeze(1) <= eps,
|
||||
torch.zeros_like(ratio), ratio)
|
||||
|
||||
h_partial = h_total * frac_t.unsqueeze(0) # (B,H)
|
||||
one_minus_partial = -torch.expm1(-h_partial) # (B,H)
|
||||
p_event_partial = one_minus_partial.unsqueeze(1) * ratio # (B,J,H)
|
||||
|
||||
cif = cif_before + s_u.unsqueeze(1) * p_event_partial # (B,J,H)
|
||||
|
||||
if not return_survival:
|
||||
return cif
|
||||
|
||||
survival = cum.index_select(dim=1, index=idx) # (B,H)
|
||||
survival = s_u * torch.exp(-h_partial) # (B,H)
|
||||
return cif, survival
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user