Refactor CIF computation to support arbitrary non-negative horizons via a partial-bin contribution, and validate finite bin edges (shape, non-empty, finite, strictly increasing)
This commit is contained in:
@@ -360,7 +360,7 @@ def cifs_from_discrete_time_logits(
|
|||||||
|
|
||||||
logits: (B, K+1, n_bins+1)
|
logits: (B, K+1, n_bins+1)
|
||||||
bin_edges: len=n_bins+1 (including 0 and inf)
|
bin_edges: len=n_bins+1 (including 0 and inf)
|
||||||
taus: subset of finite bin edges (recommended)
|
taus: arbitrary non-negative horizons (need not align to bin edges)
|
||||||
|
|
||||||
returns: (B, K, H) or (cif, survival) if return_survival
|
returns: (B, K, H) or (cif, survival) if return_survival
|
||||||
"""
|
"""
|
||||||
@@ -392,31 +392,87 @@ def cifs_from_discrete_time_logits(
|
|||||||
cif_bins = torch.cumsum(s_prev.unsqueeze(
|
cif_bins = torch.cumsum(s_prev.unsqueeze(
|
||||||
1) * hazards, dim=2) # (B,K,n_bins)
|
1) * hazards, dim=2) # (B,K,n_bins)
|
||||||
|
|
||||||
# Robust mapping from tau -> edge index (floating-point safe).
|
|
||||||
# taus are expected to align with bin edges, but may differ slightly due to parsing/serialization.
|
|
||||||
finite_edges_arr = np.asarray(finite_edges, dtype=float)
|
finite_edges_arr = np.asarray(finite_edges, dtype=float)
|
||||||
tau_to_idx: List[int] = []
|
if finite_edges_arr.ndim != 1 or finite_edges_arr.size != n_bins:
|
||||||
|
raise ValueError("Unexpected finite_edges shape")
|
||||||
|
if finite_edges_arr.size == 0:
|
||||||
|
raise ValueError("No finite bin edges provided")
|
||||||
|
if np.any(~np.isfinite(finite_edges_arr)):
|
||||||
|
raise ValueError("finite_edges must be finite")
|
||||||
|
if np.any(np.diff(finite_edges_arr) <= 0):
|
||||||
|
raise ValueError("finite bin edges must be strictly increasing")
|
||||||
|
|
||||||
|
# For an arbitrary horizon tau, compute CIF/S(tau) as full bins before the
|
||||||
|
# bin containing tau, plus a partial-bin contribution.
|
||||||
|
# Assumption: within each bin, total hazard is constant and cause-specific
|
||||||
|
# hazards are proportional.
|
||||||
|
u_list: List[int] = []
|
||||||
|
frac_list: List[float] = []
|
||||||
for tau in taus:
|
for tau in taus:
|
||||||
tau_f = float(tau)
|
tau_f = float(tau)
|
||||||
if not math.isfinite(tau_f):
|
if not math.isfinite(tau_f):
|
||||||
raise ValueError("taus must be finite for discrete-time CIF")
|
raise ValueError("taus must be finite for discrete-time CIF")
|
||||||
diffs = np.abs(finite_edges_arr - tau_f)
|
if tau_f < 0.0:
|
||||||
j = int(np.argmin(diffs))
|
raise ValueError("taus must be non-negative")
|
||||||
if diffs[j] > 1e-6:
|
|
||||||
raise ValueError(
|
|
||||||
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
|
|
||||||
)
|
|
||||||
tau_to_idx.append(j)
|
|
||||||
|
|
||||||
idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
|
if tau_f == 0.0:
|
||||||
cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H)
|
u_list.append(0)
|
||||||
|
frac_list.append(0.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# u is the first bin whose end edge >= tau.
|
||||||
|
u = int(np.searchsorted(finite_edges_arr, tau_f, side="left"))
|
||||||
|
if u >= n_bins:
|
||||||
|
# Past the last finite edge: clamp to the last bin end.
|
||||||
|
u_list.append(n_bins - 1)
|
||||||
|
frac_list.append(1.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
start = 0.0 if u == 0 else float(finite_edges_arr[u - 1])
|
||||||
|
end = float(finite_edges_arr[u])
|
||||||
|
width = end - start
|
||||||
|
if width <= 0.0:
|
||||||
|
raise ValueError("Invalid bin edges: non-positive bin width")
|
||||||
|
frac = (tau_f - start) / width
|
||||||
|
u_list.append(u)
|
||||||
|
frac_list.append(float(min(1.0, max(0.0, frac))))
|
||||||
|
|
||||||
|
u_idx = torch.tensor(u_list, device=logits.device, dtype=torch.long) # (H,)
|
||||||
|
frac_t = torch.tensor(frac_list, device=logits.device,
|
||||||
|
dtype=probs.dtype) # (H,)
|
||||||
|
|
||||||
|
# CIF before the containing bin (pad with a leading zero for u==0)
|
||||||
|
zeros_cif = torch.zeros((B, K, 1), device=logits.device, dtype=probs.dtype)
|
||||||
|
cif_padded = torch.cat([zeros_cif, cif_bins], dim=2) # (B,K,n_bins+1)
|
||||||
|
cif_before = cif_padded.index_select(dim=2, index=u_idx) # (B,K,H)
|
||||||
|
|
||||||
|
# Survival before the containing bin
|
||||||
|
s_u = s_prev.index_select(dim=1, index=u_idx) # (B,H)
|
||||||
|
|
||||||
|
# Select per-bin parameters for the containing bin
|
||||||
|
hazards_u = hazards.index_select(dim=2, index=u_idx) # (B,K,H)
|
||||||
|
p_comp_u = p_comp.index_select(dim=1, index=u_idx) # (B,H)
|
||||||
|
|
||||||
|
eps = 1e-12
|
||||||
|
p_comp_u_clamped = torch.clamp(p_comp_u, min=eps, max=1.0)
|
||||||
|
h_total = -torch.log(p_comp_u_clamped) # (B,H)
|
||||||
|
|
||||||
|
denom = 1.0 - p_comp_u
|
||||||
|
no_event = denom <= eps
|
||||||
|
denom_safe = torch.clamp(denom, min=eps)
|
||||||
|
ratio = hazards_u / denom_safe.unsqueeze(1)
|
||||||
|
ratio = torch.where(no_event.unsqueeze(1), torch.zeros_like(ratio), ratio)
|
||||||
|
|
||||||
|
h_partial = h_total * frac_t.unsqueeze(0) # (B,H)
|
||||||
|
one_minus_partial = -torch.expm1(-h_partial) # (B,H)
|
||||||
|
p_event_partial = one_minus_partial.unsqueeze(1) * ratio # (B,K,H)
|
||||||
|
|
||||||
|
cif = cif_before + s_u.unsqueeze(1) * p_event_partial # (B,K,H)
|
||||||
|
|
||||||
if not return_survival:
|
if not return_survival:
|
||||||
return cif
|
return cif
|
||||||
|
|
||||||
# Survival at each horizon = prod_{u <= idx[h]} p_comp[u]
|
survival = s_u * torch.exp(-h_partial) # (B,H)
|
||||||
survival_bins = cum # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v]
|
|
||||||
survival = survival_bins.index_select(dim=1, index=idx) # (B,H)
|
|
||||||
return cif, survival
|
return cif, survival
|
||||||
|
|
||||||
|
|
||||||
@@ -439,7 +495,7 @@ def cifs_from_lognormal_basis_binned_hazard_logits(
|
|||||||
"""Convert Route-3 binned hazard logits -> CIFs at taus.
|
"""Convert Route-3 binned hazard logits -> CIFs at taus.
|
||||||
|
|
||||||
logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
|
logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
|
||||||
taus are expected to align with finite bin edges.
|
taus may be any non-negative horizons (need not align to bin edges).
|
||||||
"""
|
"""
|
||||||
if logits.ndim not in {2, 3}:
|
if logits.ndim not in {2, 3}:
|
||||||
raise ValueError("logits must be 2D or 3D")
|
raise ValueError("logits must be 2D or 3D")
|
||||||
@@ -537,26 +593,70 @@ def cifs_from_lognormal_basis_binned_hazard_logits(
|
|||||||
1) * p_event, dim=2) # (B,J,n_bins)
|
1) * p_event, dim=2) # (B,J,n_bins)
|
||||||
|
|
||||||
finite_edges_arr = np.asarray(finite_edges, dtype=float)
|
finite_edges_arr = np.asarray(finite_edges, dtype=float)
|
||||||
tau_to_idx: List[int] = []
|
if finite_edges_arr.ndim != 1 or finite_edges_arr.size != n_bins:
|
||||||
|
raise ValueError("Unexpected finite_edges shape")
|
||||||
|
if finite_edges_arr.size == 0:
|
||||||
|
raise ValueError("No finite bin edges provided")
|
||||||
|
if np.any(~np.isfinite(finite_edges_arr)):
|
||||||
|
raise ValueError("finite_edges must be finite")
|
||||||
|
if np.any(np.diff(finite_edges_arr) <= 0):
|
||||||
|
raise ValueError("finite bin edges must be strictly increasing")
|
||||||
|
|
||||||
|
u_list: List[int] = []
|
||||||
|
frac_list: List[float] = []
|
||||||
for tau in taus:
|
for tau in taus:
|
||||||
tau_f = float(tau)
|
tau_f = float(tau)
|
||||||
if not math.isfinite(tau_f):
|
if not math.isfinite(tau_f):
|
||||||
raise ValueError("taus must be finite for discrete-time CIF")
|
raise ValueError("taus must be finite for discrete-time CIF")
|
||||||
diffs = np.abs(finite_edges_arr - tau_f)
|
if tau_f < 0.0:
|
||||||
idx0 = int(np.argmin(diffs))
|
raise ValueError("taus must be non-negative")
|
||||||
if diffs[idx0] > 1e-6:
|
|
||||||
raise ValueError(
|
|
||||||
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})"
|
|
||||||
)
|
|
||||||
tau_to_idx.append(idx0)
|
|
||||||
|
|
||||||
idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long)
|
if tau_f == 0.0:
|
||||||
cif = cif_bins.index_select(dim=2, index=idx) # (B,J,H)
|
u_list.append(0)
|
||||||
|
frac_list.append(0.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
u = int(np.searchsorted(finite_edges_arr, tau_f, side="left"))
|
||||||
|
if u >= n_bins:
|
||||||
|
u_list.append(n_bins - 1)
|
||||||
|
frac_list.append(1.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
start = 0.0 if u == 0 else float(finite_edges_arr[u - 1])
|
||||||
|
end = float(finite_edges_arr[u])
|
||||||
|
width = end - start
|
||||||
|
if width <= 0.0:
|
||||||
|
raise ValueError("Invalid bin edges: non-positive bin width")
|
||||||
|
frac = (tau_f - start) / width
|
||||||
|
u_list.append(u)
|
||||||
|
frac_list.append(float(min(1.0, max(0.0, frac))))
|
||||||
|
|
||||||
|
u_idx = torch.tensor(u_list, device=device, dtype=torch.long) # (H,)
|
||||||
|
frac_t = torch.tensor(frac_list, device=device, dtype=dtype) # (H,)
|
||||||
|
|
||||||
|
zeros_cif = torch.zeros((alpha.size(0), j, 1), device=device, dtype=dtype)
|
||||||
|
cif_padded = torch.cat([zeros_cif, cif_bins], dim=2) # (B,J,n_bins+1)
|
||||||
|
cif_before = cif_padded.index_select(dim=2, index=u_idx) # (B,J,H)
|
||||||
|
|
||||||
|
s_u = s_prev.index_select(dim=1, index=u_idx) # (B,H)
|
||||||
|
h_total = h_k.index_select(dim=1, index=u_idx) # (B,H)
|
||||||
|
h_j_total = h_jk.index_select(dim=2, index=u_idx) # (B,J,H)
|
||||||
|
|
||||||
|
h_total_safe = torch.clamp(h_total, min=eps)
|
||||||
|
ratio = h_j_total / h_total_safe.unsqueeze(1)
|
||||||
|
ratio = torch.where(h_total.unsqueeze(1) <= eps,
|
||||||
|
torch.zeros_like(ratio), ratio)
|
||||||
|
|
||||||
|
h_partial = h_total * frac_t.unsqueeze(0) # (B,H)
|
||||||
|
one_minus_partial = -torch.expm1(-h_partial) # (B,H)
|
||||||
|
p_event_partial = one_minus_partial.unsqueeze(1) * ratio # (B,J,H)
|
||||||
|
|
||||||
|
cif = cif_before + s_u.unsqueeze(1) * p_event_partial # (B,J,H)
|
||||||
|
|
||||||
if not return_survival:
|
if not return_survival:
|
||||||
return cif
|
return cif
|
||||||
|
|
||||||
survival = cum.index_select(dim=1, index=idx) # (B,H)
|
survival = s_u * torch.exp(-h_partial) # (B,H)
|
||||||
return cif, survival
|
return cif, survival
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user