Add Piecewise Exponential CIF Loss and update model evaluation for PWE
This commit is contained in:
@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
|
||||
class ModelSpec:
|
||||
name: str
|
||||
model_type: str # delphi_fork | sap_delphi
|
||||
loss_type: str # exponential | discrete_time_cif
|
||||
loss_type: str # exponential | discrete_time_cif | pwe_cif
|
||||
full_cov: bool
|
||||
checkpoint_path: str
|
||||
|
||||
@@ -420,6 +420,94 @@ def cifs_from_discrete_time_logits(
|
||||
return cif, survival
|
||||
|
||||
|
||||
def cifs_from_pwe_logits(
|
||||
logits: torch.Tensor,
|
||||
bin_edges: Sequence[float],
|
||||
taus: Sequence[float],
|
||||
eps: float = 1e-6,
|
||||
return_survival: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Convert piecewise-exponential (PWE) hazard logits -> CIFs at taus.
|
||||
|
||||
logits: (B, K, n_bins) # hazard logits per cause per bin
|
||||
bin_edges: length n_bins+1, strictly increasing, finite last edge
|
||||
taus: subset of finite bin edges (recommended)
|
||||
|
||||
returns: (B, K, H) or (cif, survival) if return_survival
|
||||
"""
|
||||
if logits.ndim != 3:
|
||||
raise ValueError("Expected logits shape (B, K, n_bins) for pwe_cif")
|
||||
|
||||
edges = [float(x) for x in bin_edges]
|
||||
if len(edges) < 2:
|
||||
raise ValueError("bin_edges must have length >= 2")
|
||||
if edges[0] != 0.0:
|
||||
raise ValueError("bin_edges[0] must equal 0.0")
|
||||
if not math.isfinite(edges[-1]):
|
||||
raise ValueError(
|
||||
"pwe_cif requires a finite last bin edge (no +inf). "
|
||||
"If your training config uses +inf, drop it for PWE evaluation."
|
||||
)
|
||||
|
||||
B, K, n_bins = logits.shape
|
||||
if n_bins != (len(edges) - 1):
|
||||
raise ValueError(
|
||||
f"logits last dim n_bins={n_bins} must equal len(bin_edges)-1={len(edges)-1}"
|
||||
)
|
||||
|
||||
# Convert logits -> hazards, then integrated hazards per bin.
|
||||
hazards = F.softplus(logits) + eps # (B,K,n_bins)
|
||||
dt_bins = torch.tensor(
|
||||
[edges[i + 1] - edges[i] for i in range(n_bins)],
|
||||
device=logits.device,
|
||||
dtype=hazards.dtype,
|
||||
) # (n_bins,)
|
||||
if not torch.isfinite(dt_bins).all() or not (dt_bins > 0).all():
|
||||
raise ValueError("All PWE bin widths must be finite and > 0")
|
||||
|
||||
H_cause = hazards * dt_bins.view(1, 1, n_bins) # (B,K,n_bins)
|
||||
H_total = H_cause.sum(dim=1) # (B,n_bins)
|
||||
|
||||
# Survival at START of each bin u.
|
||||
cum_total = torch.cumsum(H_total, dim=1) # (B,n_bins)
|
||||
zeros = torch.zeros((B, 1), device=logits.device, dtype=hazards.dtype)
|
||||
cum_prev = torch.cat([zeros, cum_total[:, :-1]], dim=1) # (B,n_bins)
|
||||
S_prev = torch.exp(-cum_prev) # (B,n_bins)
|
||||
|
||||
one_minus_surv_bin = 1.0 - torch.exp(-H_total) # (B,n_bins)
|
||||
frac = H_cause / torch.clamp(H_total.unsqueeze(1), min=eps) # (B,K,n_bins)
|
||||
|
||||
cif_incr = S_prev.unsqueeze(1) * frac * one_minus_surv_bin.unsqueeze(1)
|
||||
cif_bins = torch.cumsum(cif_incr, dim=2) # (B,K,n_bins) at edges[1:]
|
||||
|
||||
# Map tau -> edge index in edges[1:]
|
||||
finite_edges = edges[1:]
|
||||
finite_edges_arr = np.asarray(finite_edges, dtype=float)
|
||||
tau_to_idx: List[int] = []
|
||||
for tau in taus:
|
||||
tau_f = float(tau)
|
||||
if not math.isfinite(tau_f):
|
||||
raise ValueError("taus must be finite for pwe_cif")
|
||||
diffs = np.abs(finite_edges_arr - tau_f)
|
||||
j = int(np.argmin(diffs))
|
||||
if diffs[j] > 1e-6:
|
||||
raise ValueError(
|
||||
f"tau={tau_f} not close to any bin edge (min |edge-tau|={diffs[j]})"
|
||||
)
|
||||
tau_to_idx.append(j)
|
||||
|
||||
idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
|
||||
cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H)
|
||||
|
||||
if not return_survival:
|
||||
return cif
|
||||
|
||||
# Survival at each horizon is exp(-cum_total at that edge)
|
||||
survival_bins = torch.exp(-cum_total) # (B,n_bins)
|
||||
survival = survival_bins.index_select(dim=1, index=idx) # (B,H)
|
||||
return cif, survival
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CIF integrity checks
|
||||
# ============================================================
|
||||
@@ -1196,11 +1284,21 @@ def instantiate_model_and_head(
|
||||
model_type = str(cfg["model_type"])
|
||||
loss_type = str(cfg["loss_type"])
|
||||
|
||||
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
|
||||
if loss_type == "exponential":
|
||||
out_dims = [dataset.n_disease]
|
||||
elif loss_type == "discrete_time_cif":
|
||||
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
|
||||
out_dims = [dataset.n_disease + 1, len(bin_edges)]
|
||||
elif loss_type == "pwe_cif":
|
||||
# Match training: drop +inf if present and evaluate up to the last finite edge.
|
||||
pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
|
||||
if len(pwe_edges) < 2:
|
||||
raise ValueError(
|
||||
f"pwe_cif requires >=2 finite edges; got bin_edges={list(bin_edges)}"
|
||||
)
|
||||
n_bins = len(pwe_edges) - 1
|
||||
out_dims = [dataset.n_disease, n_bins]
|
||||
bin_edges = pwe_edges
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
|
||||
|
||||
@@ -1248,7 +1346,6 @@ def instantiate_model_and_head(
|
||||
raise ValueError(f"Unsupported model_type: {model_type}")
|
||||
|
||||
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
|
||||
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
|
||||
return backbone, head, loss_type, bin_edges
|
||||
|
||||
|
||||
@@ -1316,6 +1413,9 @@ def predict_cifs_for_model(
|
||||
cif_full, survival = cifs_from_discrete_time_logits(
|
||||
# (B,K,H), (B,H)
|
||||
logits, bin_edges, eval_horizons, return_survival=True)
|
||||
elif loss_type == "pwe_cif":
|
||||
cif_full, survival = cifs_from_pwe_logits(
|
||||
logits, bin_edges, eval_horizons, return_survival=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss_type: {loss_type}")
|
||||
|
||||
|
||||
146
losses.py
146
losses.py
@@ -258,3 +258,149 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
|
||||
F.nll_loss(logp_at_event_bin, target_events, reduction="mean")
|
||||
|
||||
return nll, reg
|
||||
|
||||
|
||||
class PiecewiseExponentialCIFNLLLoss(nn.Module):
|
||||
"""
|
||||
Piecewise-Exponential (PWE) cause-specific hazards with discrete-time CIF likelihood.
|
||||
- No censoring
|
||||
- No regularization (reg always 0)
|
||||
- Forward signature matches DiscreteTimeCIFNLLLoss:
|
||||
forward(logits, target_events, dt, reduction) -> (nll, reg)
|
||||
|
||||
Expected shapes:
|
||||
logits: (M, K, n_bins) # hazard logits per cause per bin
|
||||
target_events: (M,) long in [0, K-1]
|
||||
dt: (M,) event times (strictly > 0)
|
||||
|
||||
bin_edges:
|
||||
length n_bins+1, strictly increasing, bin_edges[0]==0,
|
||||
and MUST be finite at the last edge (no +inf) for PWE.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bin_edges: Sequence[float],
|
||||
eps: float = 1e-6,
|
||||
lambda_reg: float = 0.0, # kept for signature compatibility; UNUSED
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if len(bin_edges) < 2:
|
||||
raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
|
||||
if float(bin_edges[0]) != 0.0:
|
||||
raise ValueError("bin_edges[0] must equal 0.0")
|
||||
for i in range(1, len(bin_edges)):
|
||||
if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
|
||||
raise ValueError("bin_edges must be strictly increasing")
|
||||
if math.isinf(float(bin_edges[-1])):
|
||||
raise ValueError(
|
||||
"PiecewiseExponentialCIFNLLLoss requires a finite last bin edge (no +inf). "
|
||||
"Use a finite truncation horizon for PWE."
|
||||
)
|
||||
|
||||
self.eps = float(eps)
|
||||
# unused, kept only for interface compatibility
|
||||
self.lambda_reg = float(lambda_reg)
|
||||
|
||||
self.register_buffer(
|
||||
"bin_edges",
|
||||
torch.tensor([float(x) for x in bin_edges], dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
target_events: torch.Tensor,
|
||||
dt: torch.Tensor,
|
||||
reduction: str = "mean",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if reduction not in {"mean", "sum", "none"}:
|
||||
raise ValueError("reduction must be one of {'mean','sum','none'}")
|
||||
|
||||
if logits.ndim != 3:
|
||||
raise ValueError(
|
||||
f"logits must be 3D (M, K, n_bins); got shape={tuple(logits.shape)}")
|
||||
if target_events.ndim != 1 or dt.ndim != 1:
|
||||
raise ValueError("target_events and dt must be 1D tensors")
|
||||
if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
|
||||
raise ValueError(
|
||||
"Batch size mismatch among logits, target_events, dt")
|
||||
|
||||
if not torch.all(dt > 0):
|
||||
raise ValueError(
|
||||
"dt must be strictly positive (no censoring supported here)")
|
||||
|
||||
M, K, n_bins = logits.shape
|
||||
|
||||
if target_events.dtype != torch.long:
|
||||
target_events = target_events.to(torch.long)
|
||||
if (target_events < 0).any() or (target_events >= K).any():
|
||||
raise ValueError(f"target_events must be in [0, {K-1}]")
|
||||
|
||||
# Prepare bin_edges / bin widths
|
||||
bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
|
||||
if bin_edges.numel() != n_bins + 1:
|
||||
raise ValueError(
|
||||
f"bin_edges length must be n_bins+1={n_bins+1}; got {bin_edges.numel()}"
|
||||
)
|
||||
|
||||
dt_bins = (bin_edges[1:] - bin_edges[:-1]
|
||||
).to(device=logits.device, dtype=logits.dtype) # (n_bins,)
|
||||
if not torch.isfinite(dt_bins).all():
|
||||
raise ValueError("All bin widths must be finite for PWE.")
|
||||
if not (dt_bins > 0).all():
|
||||
raise ValueError(
|
||||
"All bin widths must be strictly positive for PWE.")
|
||||
|
||||
# Map event time -> bin index k* in {1..n_bins}
|
||||
# (same convention as your discrete_time_cif: clamp to [1, n_bins])
|
||||
time_bin = torch.bucketize(dt, bin_edges)
|
||||
time_bin = torch.clamp(
|
||||
time_bin, min=1, max=n_bins).to(torch.long) # (M,)
|
||||
k0 = time_bin - 1 # 0..n_bins-1
|
||||
|
||||
# Nonnegative hazards per cause per bin
|
||||
hazards = F.softplus(logits) + self.eps # (M, K, n_bins)
|
||||
|
||||
# Integrated hazards H_{j,k} = lambda_{j,k} * Δt_k
|
||||
H_jk = hazards * dt_bins.view(1, 1, n_bins) # (M, K, n_bins)
|
||||
H_k = H_jk.sum(dim=1) # (M, n_bins)
|
||||
|
||||
# Previous survival term: Σ_{u<k*} H_u
|
||||
bins = torch.arange(
|
||||
1, n_bins + 1, device=logits.device).unsqueeze(0) # (1, n_bins)
|
||||
mask_prev = bins < time_bin.unsqueeze(1) # (M, n_bins)
|
||||
loss_prev = (H_k * mask_prev.to(H_k.dtype)).sum(dim=1) # (M,)
|
||||
|
||||
# Event term at k*: -log p_{k*}(cause)
|
||||
m_idx = torch.arange(M, device=logits.device)
|
||||
|
||||
H_event_total = torch.clamp(H_k[m_idx, k0], min=self.eps) # (M,)
|
||||
H_event_cause = torch.clamp(
|
||||
H_jk[m_idx, target_events, k0], min=self.eps) # (M,)
|
||||
|
||||
# log(1 - exp(-H)) stable
|
||||
log1mexp = torch.log(-torch.expm1(-H_event_total)) # (M,)
|
||||
loss_event = -log1mexp - \
|
||||
torch.log(H_event_cause) + torch.log(H_event_total)
|
||||
|
||||
loss_vec = loss_prev + loss_event # (M,)
|
||||
|
||||
if reduction == "mean":
|
||||
nll = loss_vec.mean()
|
||||
elif reduction == "sum":
|
||||
nll = loss_vec.sum()
|
||||
else:
|
||||
nll = loss_vec
|
||||
|
||||
if self.lambda_reg > 0.0 and n_bins >= 3:
|
||||
log_h = torch.log(hazards) # (M, K, n_bins)
|
||||
d2 = log_h[:, :, 2:] - 2.0 * log_h[:, :, 1:-1] + \
|
||||
log_h[:, :, :-2] # (M, K, n_bins-2)
|
||||
reg = self.lambda_reg * (d2.pow(2).mean())
|
||||
else:
|
||||
reg = torch.zeros((), device=logits.device, dtype=loss_vec.dtype)
|
||||
|
||||
return nll, reg
|
||||
|
||||
29
train.py
29
train.py
@@ -1,4 +1,4 @@
|
||||
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
|
||||
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, PiecewiseExponentialCIFNLLLoss, get_valid_pairs_and_dt
|
||||
from model import DelphiFork, SapDelphi, SimpleHead
|
||||
from dataset import HealthDataset, health_collate_fn
|
||||
from tqdm import tqdm
|
||||
@@ -22,7 +22,8 @@ from typing import Literal, Sequence
|
||||
class TrainConfig:
|
||||
# Model Parameters
|
||||
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
|
||||
loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
|
||||
loss_type: Literal['exponential',
|
||||
'discrete_time_cif', 'pwe_cif'] = 'exponential'
|
||||
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
|
||||
full_cov: bool = False
|
||||
n_embd: int = 120
|
||||
@@ -64,7 +65,7 @@ def parse_args() -> TrainConfig:
|
||||
parser.add_argument(
|
||||
"--loss_type",
|
||||
type=str,
|
||||
choices=['exponential', 'discrete_time_cif'],
|
||||
choices=['exponential', 'discrete_time_cif', 'pwe_cif'],
|
||||
default='exponential',
|
||||
help="Type of loss function to use.")
|
||||
parser.add_argument("--age_encoder", type=str, choices=[
|
||||
@@ -205,6 +206,28 @@ class Trainer:
|
||||
).to(self.device)
|
||||
# logits shape (M, K+1, n_bins+1)
|
||||
out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
|
||||
elif cfg.loss_type == "pwe_cif":
|
||||
# Piecewise-exponential (PWE) requires a FINITE last edge.
|
||||
# If cfg.bin_edges ends with +inf (default), drop it and train up to the last finite edge.
|
||||
pwe_edges = [float(x)
|
||||
for x in cfg.bin_edges if math.isfinite(float(x))]
|
||||
if len(pwe_edges) < 2:
|
||||
raise ValueError(
|
||||
"pwe_cif requires at least 2 finite bin edges (including 0). "
|
||||
f"Got bin_edges={list(cfg.bin_edges)}"
|
||||
)
|
||||
if pwe_edges[0] != 0.0:
|
||||
raise ValueError(
|
||||
f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}"
|
||||
)
|
||||
|
||||
self.criterion = PiecewiseExponentialCIFNLLLoss(
|
||||
bin_edges=pwe_edges,
|
||||
lambda_reg=cfg.lambda_reg,
|
||||
).to(self.device)
|
||||
n_bins = len(pwe_edges) - 1
|
||||
# logits shape (M, K, n_bins)
|
||||
out_dims = [dataset.n_disease, n_bins]
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user