Add Piecewise Exponential CIF Loss and update model evaluation for PWE

This commit is contained in:
2026-01-15 11:36:24 +08:00
parent d8b322cbee
commit 2f46acf2bd
3 changed files with 275 additions and 6 deletions

View File

@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
class ModelSpec:
name: str
model_type: str # delphi_fork | sap_delphi
loss_type: str # exponential | discrete_time_cif
loss_type: str # exponential | discrete_time_cif | pwe_cif
full_cov: bool
checkpoint_path: str
@@ -420,6 +420,94 @@ def cifs_from_discrete_time_logits(
return cif, survival
def cifs_from_pwe_logits(
logits: torch.Tensor,
bin_edges: Sequence[float],
taus: Sequence[float],
eps: float = 1e-6,
return_survival: bool = False,
) -> torch.Tensor:
"""Convert piecewise-exponential (PWE) hazard logits -> CIFs at taus.
logits: (B, K, n_bins) # hazard logits per cause per bin
bin_edges: length n_bins+1, strictly increasing, finite last edge
taus: subset of finite bin edges (recommended)
returns: (B, K, H) or (cif, survival) if return_survival
"""
if logits.ndim != 3:
raise ValueError("Expected logits shape (B, K, n_bins) for pwe_cif")
edges = [float(x) for x in bin_edges]
if len(edges) < 2:
raise ValueError("bin_edges must have length >= 2")
if edges[0] != 0.0:
raise ValueError("bin_edges[0] must equal 0.0")
if not math.isfinite(edges[-1]):
raise ValueError(
"pwe_cif requires a finite last bin edge (no +inf). "
"If your training config uses +inf, drop it for PWE evaluation."
)
B, K, n_bins = logits.shape
if n_bins != (len(edges) - 1):
raise ValueError(
f"logits last dim n_bins={n_bins} must equal len(bin_edges)-1={len(edges)-1}"
)
# Convert logits -> hazards, then integrated hazards per bin.
hazards = F.softplus(logits) + eps # (B,K,n_bins)
dt_bins = torch.tensor(
[edges[i + 1] - edges[i] for i in range(n_bins)],
device=logits.device,
dtype=hazards.dtype,
) # (n_bins,)
if not torch.isfinite(dt_bins).all() or not (dt_bins > 0).all():
raise ValueError("All PWE bin widths must be finite and > 0")
H_cause = hazards * dt_bins.view(1, 1, n_bins) # (B,K,n_bins)
H_total = H_cause.sum(dim=1) # (B,n_bins)
# Survival at START of each bin u.
cum_total = torch.cumsum(H_total, dim=1) # (B,n_bins)
zeros = torch.zeros((B, 1), device=logits.device, dtype=hazards.dtype)
cum_prev = torch.cat([zeros, cum_total[:, :-1]], dim=1) # (B,n_bins)
S_prev = torch.exp(-cum_prev) # (B,n_bins)
one_minus_surv_bin = 1.0 - torch.exp(-H_total) # (B,n_bins)
frac = H_cause / torch.clamp(H_total.unsqueeze(1), min=eps) # (B,K,n_bins)
cif_incr = S_prev.unsqueeze(1) * frac * one_minus_surv_bin.unsqueeze(1)
cif_bins = torch.cumsum(cif_incr, dim=2) # (B,K,n_bins) at edges[1:]
# Map tau -> edge index in edges[1:]
finite_edges = edges[1:]
finite_edges_arr = np.asarray(finite_edges, dtype=float)
tau_to_idx: List[int] = []
for tau in taus:
tau_f = float(tau)
if not math.isfinite(tau_f):
raise ValueError("taus must be finite for pwe_cif")
diffs = np.abs(finite_edges_arr - tau_f)
j = int(np.argmin(diffs))
if diffs[j] > 1e-6:
raise ValueError(
f"tau={tau_f} not close to any bin edge (min |edge-tau|={diffs[j]})"
)
tau_to_idx.append(j)
idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H)
if not return_survival:
return cif
# Survival at each horizon is exp(-cum_total at that edge)
survival_bins = torch.exp(-cum_total) # (B,n_bins)
survival = survival_bins.index_select(dim=1, index=idx) # (B,H)
return cif, survival
# ============================================================
# CIF integrity checks
# ============================================================
@@ -1196,11 +1284,21 @@ def instantiate_model_and_head(
model_type = str(cfg["model_type"])
loss_type = str(cfg["loss_type"])
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
if loss_type == "exponential":
out_dims = [dataset.n_disease]
elif loss_type == "discrete_time_cif":
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
out_dims = [dataset.n_disease + 1, len(bin_edges)]
elif loss_type == "pwe_cif":
# Match training: drop +inf if present and evaluate up to the last finite edge.
pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
if len(pwe_edges) < 2:
raise ValueError(
f"pwe_cif requires >=2 finite edges; got bin_edges={list(bin_edges)}"
)
n_bins = len(pwe_edges) - 1
out_dims = [dataset.n_disease, n_bins]
bin_edges = pwe_edges
else:
raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
@@ -1248,7 +1346,6 @@ def instantiate_model_and_head(
raise ValueError(f"Unsupported model_type: {model_type}")
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
return backbone, head, loss_type, bin_edges
@@ -1316,6 +1413,9 @@ def predict_cifs_for_model(
cif_full, survival = cifs_from_discrete_time_logits(
# (B,K,H), (B,H)
logits, bin_edges, eval_horizons, return_survival=True)
elif loss_type == "pwe_cif":
cif_full, survival = cifs_from_pwe_logits(
logits, bin_edges, eval_horizons, return_survival=True)
else:
raise ValueError(f"Unsupported loss_type: {loss_type}")