Add evaluation and utility functions for time-dependent metrics
- Introduced `evaluate.py` for time-dependent evaluation of models, including data loading and model inference. - Added `evaluation_time_dependent.py` to compute various evaluation metrics such as AUC, average precision, and precision/recall at specified thresholds. - Implemented CIF calculation methods in `losses.py` for different loss types, including exponential and piecewise exponential models. - Created utility functions in `utils.py` for context selection and multi-hot encoding of events within specified horizons.
This commit is contained in:
386
losses.py
386
losses.py
@@ -131,6 +131,96 @@ class ExponentialNLLLoss(nn.Module):
|
||||
reduction="mean") * self.lambda_reg
|
||||
return nll, reg
|
||||
|
||||
def calculate_cifs(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
taus: torch.Tensor,
|
||||
eps: Optional[float] = None,
|
||||
return_survival: bool = False,
|
||||
):
|
||||
"""Compute CIFs for a competing-risks exponential model.
|
||||
|
||||
Model assumptions:
|
||||
- cause-specific hazards are constant in time within a sample.
|
||||
- hazards are obtained via softplus(logits) + eps.
|
||||
|
||||
Args:
|
||||
logits: (M, K) or (M, K, 1) tensor.
|
||||
taus: scalar, (T,), (M,), or (M, T) times (>=0 recommended).
|
||||
eps: overrides self.eps for numerical stability.
|
||||
return_survival: if True, also return survival S(tau).
|
||||
|
||||
Returns:
|
||||
cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
|
||||
survival (optional): (M,) if taus is scalar or (M,), else (M, T).
|
||||
"""
|
||||
|
||||
def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
|
||||
t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
|
||||
scalar_out = False
|
||||
kind = "T" # one of: 'T', 'per_sample', 'MT'
|
||||
if t.ndim == 0:
|
||||
t = t.view(1)
|
||||
scalar_out = True
|
||||
t = t.view(1, 1) # (1,1)
|
||||
kind = "T"
|
||||
elif t.ndim == 1:
|
||||
if t.shape[0] == batch_size:
|
||||
t = t.view(batch_size, 1) # (M,1)
|
||||
kind = "per_sample"
|
||||
else:
|
||||
t = t.view(1, -1) # (1,T)
|
||||
kind = "T"
|
||||
elif t.ndim == 2:
|
||||
if t.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
|
||||
)
|
||||
kind = "MT"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
|
||||
return t, kind, scalar_out
|
||||
|
||||
logits = logits.squeeze(-1) if logits.dim() == 3 else logits
|
||||
if logits.ndim != 2:
|
||||
raise ValueError(
|
||||
f"logits must be 2D (M,K) (or 3D with last dim 1); got shape={tuple(logits.shape)}")
|
||||
|
||||
M, K = logits.shape
|
||||
used_eps = float(self.eps if eps is None else eps)
|
||||
|
||||
hazards = F.softplus(logits) + used_eps # (M, K)
|
||||
total_hazard = hazards.sum(dim=1, keepdim=True) # (M, 1)
|
||||
total_hazard = torch.clamp(total_hazard, min=used_eps)
|
||||
|
||||
frac = hazards / total_hazard # (M, K)
|
||||
|
||||
taus_t, kind, scalar_out = _prepare_taus(
|
||||
taus, M, logits.device, hazards.dtype)
|
||||
taus_t = torch.clamp(taus_t, min=0)
|
||||
|
||||
if kind == "T":
|
||||
# taus_t: (1,T)
|
||||
exp_term = 1.0 - torch.exp(-total_hazard * taus_t) # (M,T)
|
||||
cifs = frac.unsqueeze(-1) * exp_term.unsqueeze(1) # (M,K,T)
|
||||
survival = torch.exp(-total_hazard * taus_t) # (M,T)
|
||||
else:
|
||||
# taus_t: (M,1) or (M,T)
|
||||
exp_term = 1.0 - torch.exp(-total_hazard * taus_t) # (M,1) or (M,T)
|
||||
# (M,K,1) or (M,K,T)
|
||||
cifs = frac.unsqueeze(-1) * exp_term.unsqueeze(1)
|
||||
survival = torch.exp(-total_hazard * taus_t) # (M,1) or (M,T)
|
||||
|
||||
if kind == "per_sample":
|
||||
cifs = cifs.squeeze(-1) # (M,K)
|
||||
survival = survival.squeeze(-1) # (M,)
|
||||
elif scalar_out:
|
||||
cifs = cifs.squeeze(-1) # (M,K)
|
||||
survival = survival.squeeze(-1) # (M,)
|
||||
|
||||
return (cifs, survival) if return_survival else cifs
|
||||
|
||||
|
||||
class DiscreteTimeCIFNLLLoss(nn.Module):
|
||||
"""Direct discrete-time CIF negative log-likelihood (no censoring).
|
||||
@@ -259,6 +349,122 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
|
||||
|
||||
return nll, reg
|
||||
|
||||
def calculate_cifs(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
taus: torch.Tensor,
|
||||
eps: Optional[float] = None,
|
||||
return_survival: bool = False,
|
||||
):
|
||||
"""Compute discrete-time CIFs implied by per-bin (K causes + complement) logits.
|
||||
|
||||
This matches the likelihood used in forward():
|
||||
p(event=cause k at bin j) = Π_{u=1}^{j-1} p(comp at u) * p(k at j)
|
||||
|
||||
Args:
|
||||
logits: (M, K+1, n_bins+1) where channel K is complement.
|
||||
taus: scalar, (T,), (M,), or (M,T) continuous times.
|
||||
eps: unused (kept for signature compatibility).
|
||||
return_survival: if True, also return survival probability up to the mapped bin.
|
||||
|
||||
Returns:
|
||||
cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
|
||||
survival (optional): (M,) if taus is scalar or (M,), else (M, T).
|
||||
"""
|
||||
|
||||
def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
|
||||
t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
|
||||
scalar_out = False
|
||||
kind = "T"
|
||||
if t.ndim == 0:
|
||||
t = t.view(1)
|
||||
scalar_out = True
|
||||
t = t.view(1, 1)
|
||||
kind = "T"
|
||||
elif t.ndim == 1:
|
||||
if t.shape[0] == batch_size:
|
||||
t = t.view(batch_size, 1)
|
||||
kind = "per_sample"
|
||||
else:
|
||||
t = t.view(1, -1)
|
||||
kind = "T"
|
||||
elif t.ndim == 2:
|
||||
if t.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
|
||||
)
|
||||
kind = "MT"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
|
||||
return t, kind, scalar_out
|
||||
|
||||
if logits.ndim != 3:
|
||||
raise ValueError(
|
||||
f"logits must have shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
|
||||
)
|
||||
|
||||
M, k_plus_1, n_bins_plus_1 = logits.shape
|
||||
K = k_plus_1 - 1
|
||||
if K < 1:
|
||||
raise ValueError(
|
||||
"logits.shape[1] must be at least 2 (K>=1 plus complement)")
|
||||
|
||||
n_bins = int(self.bin_edges.numel() - 1)
|
||||
if n_bins_plus_1 != n_bins + 1:
|
||||
raise ValueError(
|
||||
f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
|
||||
)
|
||||
|
||||
# probs over causes+complement per bin
|
||||
probs = F.softmax(logits, dim=1) # (M, K+1, n_bins+1)
|
||||
p_causes = probs[:, :K, 1:] # (M, K, n_bins)
|
||||
p_comp = probs[:, K, 1:] # (M, n_bins)
|
||||
|
||||
# survival up to end of each bin (1..n_bins)
|
||||
surv_end = torch.cumprod(p_comp, dim=1) # (M, n_bins)
|
||||
ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
|
||||
surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1) # (M, n_bins)
|
||||
|
||||
inc = surv_start.unsqueeze(1) * p_causes # (M, K, n_bins)
|
||||
cif_full = torch.cumsum(inc, dim=2) # (M, K, n_bins)
|
||||
|
||||
taus_t, kind, scalar_out = _prepare_taus(
|
||||
taus, M, logits.device, surv_end.dtype)
|
||||
taus_t = torch.clamp(taus_t, min=0)
|
||||
|
||||
bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
|
||||
time_bin = torch.bucketize(taus_t, bin_edges) # (..)
|
||||
time_bin = torch.clamp(time_bin, min=0, max=n_bins).to(torch.long)
|
||||
|
||||
if kind == "T":
|
||||
# (1,T) -> expand to (M,T)
|
||||
time_bin = time_bin.expand(M, -1)
|
||||
# kind per_sample gives (M,1), MT gives (M,T)
|
||||
|
||||
idx = torch.clamp(time_bin - 1, min=0) # (M,T)
|
||||
|
||||
gathered_cif = cif_full.gather(
|
||||
dim=2,
|
||||
index=idx.unsqueeze(1).expand(-1, K, -1),
|
||||
) # (M,K,T)
|
||||
gathered_surv = surv_end.gather(dim=1, index=idx) # (M,T)
|
||||
|
||||
# tau mapped to bin 0 => CIF=0, survival=1
|
||||
zero_mask = (time_bin == 0)
|
||||
if zero_mask.any():
|
||||
gathered_cif = gathered_cif.masked_fill(zero_mask.unsqueeze(1), 0.0)
|
||||
gathered_surv = gathered_surv.masked_fill(zero_mask, 1.0)
|
||||
|
||||
if kind == "per_sample":
|
||||
gathered_cif = gathered_cif.squeeze(-1) # (M,K)
|
||||
gathered_surv = gathered_surv.squeeze(-1) # (M,)
|
||||
elif scalar_out:
|
||||
gathered_cif = gathered_cif.squeeze(-1) # (M,K)
|
||||
gathered_surv = gathered_surv.squeeze(-1) # (M,)
|
||||
|
||||
return (gathered_cif, gathered_surv) if return_survival else gathered_cif
|
||||
|
||||
|
||||
class PiecewiseExponentialCIFNLLLoss(nn.Module):
|
||||
"""
|
||||
@@ -404,3 +610,183 @@ class PiecewiseExponentialCIFNLLLoss(nn.Module):
|
||||
reg = torch.zeros((), device=logits.device, dtype=loss_vec.dtype)
|
||||
|
||||
return nll, reg
|
||||
|
||||
def calculate_cifs(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
taus: torch.Tensor,
|
||||
eps: Optional[float] = None,
|
||||
return_survival: bool = False,
|
||||
):
|
||||
"""Compute CIFs for piecewise-constant cause-specific hazards.
|
||||
|
||||
Uses the same binning convention as forward(): taus are mapped to a bin via
|
||||
torch.bucketize(taus, bin_edges), clamped to [0, n_bins]. tau<=0 maps to 0.
|
||||
|
||||
Args:
|
||||
logits: (M, K, n_bins) hazard logits per cause per bin.
|
||||
taus: scalar, (T,), (M,), or (M,T) times.
|
||||
eps: overrides self.eps for numerical stability.
|
||||
return_survival: if True, also return survival S(tau).
|
||||
|
||||
Returns:
|
||||
cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
|
||||
survival (optional): (M,) if taus is scalar or (M,), else (M, T).
|
||||
"""
|
||||
|
||||
def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
|
||||
t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
|
||||
scalar_out = False
|
||||
kind = "T"
|
||||
if t.ndim == 0:
|
||||
t = t.view(1)
|
||||
scalar_out = True
|
||||
t = t.view(1, 1)
|
||||
kind = "T"
|
||||
elif t.ndim == 1:
|
||||
if t.shape[0] == batch_size:
|
||||
t = t.view(batch_size, 1)
|
||||
kind = "per_sample"
|
||||
else:
|
||||
t = t.view(1, -1)
|
||||
kind = "T"
|
||||
elif t.ndim == 2:
|
||||
if t.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
|
||||
)
|
||||
kind = "MT"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
|
||||
return t, kind, scalar_out
|
||||
|
||||
if logits.ndim != 3:
|
||||
raise ValueError(
|
||||
f"logits must be 3D (M,K,n_bins); got shape={tuple(logits.shape)}")
|
||||
|
||||
M, K, n_bins = logits.shape
|
||||
if self.bin_edges.numel() != n_bins + 1:
|
||||
raise ValueError(
|
||||
f"bin_edges length must be n_bins+1={n_bins+1}; got {self.bin_edges.numel()}"
|
||||
)
|
||||
|
||||
used_eps = float(self.eps if eps is None else eps)
|
||||
|
||||
taus_t, kind, scalar_out = _prepare_taus(
|
||||
taus, M, logits.device, logits.dtype)
|
||||
taus_t = torch.clamp(taus_t, min=0)
|
||||
|
||||
bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
|
||||
dt_bins = (bin_edges[1:] - bin_edges[:-1]
|
||||
).to(device=logits.device, dtype=logits.dtype) # (n_bins,)
|
||||
|
||||
hazards = F.softplus(logits) + used_eps # (M, K, n_bins)
|
||||
total_h = hazards.sum(dim=1) # (M, n_bins)
|
||||
total_h = torch.clamp(total_h, min=used_eps)
|
||||
|
||||
# Precompute full-bin CIF increments
|
||||
H_total_bin = total_h * dt_bins.view(1, n_bins) # (M, n_bins)
|
||||
cum_H_end = torch.cumsum(H_total_bin, dim=1) # (M, n_bins)
|
||||
surv_end = torch.exp(-cum_H_end) # (M, n_bins)
|
||||
ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
|
||||
surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1) # (M, n_bins)
|
||||
|
||||
frac = hazards / total_h.unsqueeze(1) # (M, K, n_bins)
|
||||
one_minus = 1.0 - \
|
||||
torch.exp(-total_h * dt_bins.view(1, n_bins)) # (M, n_bins)
|
||||
inc_full = surv_start.unsqueeze(
|
||||
1) * frac * one_minus.unsqueeze(1) # (M, K, n_bins)
|
||||
cif_full = torch.cumsum(inc_full, dim=2) # (M, K, n_bins)
|
||||
|
||||
# Map taus -> bin index b in [0..n_bins]
|
||||
time_bin = torch.bucketize(taus_t, bin_edges)
|
||||
time_bin = torch.clamp(time_bin, min=0, max=n_bins).to(
|
||||
torch.long) # (...)
|
||||
|
||||
if kind == "T":
|
||||
time_bin = time_bin.expand(M, -1) # (M,T)
|
||||
|
||||
# Compute within-bin length l and indices
|
||||
b = time_bin # (M,T)
|
||||
idx_bin0 = torch.clamp(b - 1, min=0) # 0..n_bins-1
|
||||
|
||||
# Start-of-bin survival for the current bin (for b==0 it's unused)
|
||||
S_start_b = surv_start.gather(dim=1, index=idx_bin0) # (M,T)
|
||||
|
||||
# Length into bin: l = tau - edge[b-1], clamped to [0, dt_bin]
|
||||
left_edge = bin_edges.gather(
|
||||
dim=0, index=idx_bin0.view(-1)).view_as(idx_bin0).to(taus_t.dtype)
|
||||
l = taus_t.expand_as(b) - left_edge
|
||||
l = torch.clamp(l, min=0)
|
||||
width_b = dt_bins.gather(
|
||||
dim=0, index=idx_bin0.view(-1)).view_as(idx_bin0)
|
||||
l = torch.min(l, width_b.to(l.dtype))
|
||||
|
||||
# CIF up to previous full bins
|
||||
# if b<=1 => 0 else cif_full at (b-2)
|
||||
prev_idx = torch.clamp(b - 2, min=0)
|
||||
cif_before = cif_full.gather(
|
||||
dim=2,
|
||||
index=prev_idx.unsqueeze(1).expand(-1, K, -1),
|
||||
) # (M,K,T)
|
||||
if (b <= 1).any():
|
||||
cif_before = cif_before.masked_fill((b <= 1).unsqueeze(1), 0.0)
|
||||
|
||||
# Partial increment in current bin
|
||||
total_h_b = total_h.gather(dim=1, index=idx_bin0) # (M,T)
|
||||
haz_b = hazards.gather(
|
||||
dim=2,
|
||||
index=idx_bin0.unsqueeze(1).expand(-1, K, -1),
|
||||
) # (M,K,T)
|
||||
frac_b = haz_b / total_h_b.unsqueeze(1) # (M,K,T)
|
||||
|
||||
one_minus_partial = 1.0 - torch.exp(-total_h_b * l) # (M,T)
|
||||
inc_partial = S_start_b.unsqueeze(
|
||||
1) * frac_b * one_minus_partial.unsqueeze(1) # (M,K,T)
|
||||
|
||||
cifs = cif_before + inc_partial
|
||||
|
||||
survival = S_start_b * torch.exp(-total_h_b * l) # (M,T)
|
||||
|
||||
# Inference-only tail extension beyond the last finite edge.
|
||||
# For tau > t_B (t_B = bin_edges[-1]), extend survival and CIFs using
|
||||
# constant hazards from the final bin B:
|
||||
# S(tau)=S(t_B) * exp(-Λ_B * (tau - t_B))
|
||||
# F_k(tau)=F_k(t_B) + S(t_B) * (λ_{k,B}/Λ_B) * (1 - exp(-Λ_B*(tau-t_B)))
|
||||
last_edge = bin_edges[-1]
|
||||
tau_full = taus_t.expand_as(b) # (M,T)
|
||||
tail_mask = tau_full > last_edge
|
||||
if tail_mask.any():
|
||||
delta = torch.clamp(tau_full - last_edge, min=0) # (M,T)
|
||||
|
||||
S_B = surv_end[:, -1].unsqueeze(1) # (M,1)
|
||||
F_B = cif_full[:, :, -1].unsqueeze(-1) # (M,K,1)
|
||||
|
||||
lambda_last = hazards[:, :, -1] # (M,K)
|
||||
Lambda_last = torch.clamp(
|
||||
total_h[:, -1], min=used_eps).unsqueeze(1) # (M,1)
|
||||
|
||||
exp_tail = torch.exp(-Lambda_last * delta) # (M,T)
|
||||
survival_tail = S_B * exp_tail # (M,T)
|
||||
cifs_tail = F_B + \
|
||||
S_B.unsqueeze(
|
||||
1) * (lambda_last / Lambda_last).unsqueeze(-1) * (1.0 - exp_tail).unsqueeze(1)
|
||||
|
||||
survival = torch.where(tail_mask, survival_tail, survival)
|
||||
cifs = torch.where(tail_mask.unsqueeze(1), cifs_tail, cifs)
|
||||
|
||||
# tau mapped to bin 0 => CIF=0, survival=1
|
||||
zero_mask = (b == 0)
|
||||
if zero_mask.any():
|
||||
cifs = cifs.masked_fill(zero_mask.unsqueeze(1), 0.0)
|
||||
survival = survival.masked_fill(zero_mask, 1.0)
|
||||
|
||||
if kind == "per_sample":
|
||||
cifs = cifs.squeeze(-1) # (M,K)
|
||||
survival = survival.squeeze(-1) # (M,)
|
||||
elif scalar_out:
|
||||
cifs = cifs.squeeze(-1) # (M,K)
|
||||
survival = survival.squeeze(-1) # (M,)
|
||||
|
||||
return (cifs, survival) if return_survival else cifs
|
||||
|
||||
Reference in New Issue
Block a user