from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================
def get_valid_pairs_and_dt(
    event_seqs: torch.Tensor,
    time_seqs: torch.Tensor,
    n_tech_tokens: int
) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Extract valid event pairs (prev -> next) and compute dt in years.

    Args:
        event_seqs (torch.Tensor): (B, T) int tensor of event token ids, with 0 as padding.
        time_seqs (torch.Tensor): (B, T) tensor of event times in days.
        n_tech_tokens (int): Number of technical (non-disease) tokens; disease tokens
            have token_id >= n_tech_tokens.

    Returns:
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
            (dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None.

    Notes:
        - Assumes strict right-padding (no real token follows a pad token).
        - Filters to next events that are disease tokens: token_id >= n_tech_tokens.
        - Filters to strictly positive dt.
    """
    real_mask = event_seqs >= 1
    idx = real_mask.nonzero(as_tuple=False)  # (N, 2) rows of (batch, position)

    if idx.size(0) <= 1:
        return None

    # Consecutive real tokens in the same batch row form (prev -> next) pairs.
    same_batch = idx[1:, 0] == idx[:-1, 0]
    if not same_batch.any():
        return None

    prev_idx = idx[:-1][same_batch]
    next_idx = idx[1:][same_batch]

    # Keep only pairs whose next event is a disease token.
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]
    valid_target = event_seqs[b_next, t_next] >= n_tech_tokens
    if not valid_target.any():
        return None

    prev_idx = prev_idx[valid_target]
    next_idx = next_idx[valid_target]

    b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]

    # Convert the day-level time difference to years.
    dt = (time_seqs[b_next, t_next] -
          time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25
    valid_dt = dt > 0
    if not valid_dt.any():
        return None

    dt = dt[valid_dt]
    b_prev = b_prev[valid_dt]
    t_prev = t_prev[valid_dt]
    b_next = b_next[valid_dt]
    t_next = t_next[valid_dt]

    return dt, b_prev, t_prev, b_next, t_next
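

# A minimal usage sketch with a hypothetical toy batch (not part of the training
# pipeline). Token 0 is padding and tokens >= n_tech_tokens count as disease
# tokens, so row 0 below yields two valid pairs while row 1's only pair is
# dropped because its dt is zero.
def _example_get_valid_pairs_and_dt() -> None:
    event_seqs = torch.tensor([[1, 3, 4, 0],
                               [2, 5, 0, 0]])
    time_seqs = torch.tensor([[0.0, 400.0, 800.0, 0.0],
                              [10.0, 10.0, 0.0, 0.0]])  # times in days
    out = get_valid_pairs_and_dt(event_seqs, time_seqs, n_tech_tokens=2)
    assert out is not None
    dt, b_prev, t_prev, b_next, t_next = out
    print(dt)  # tensor([1.0951, 1.0951]) -- both gaps are 400 days ~= 1.1 years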


# ============================================================
# Losses (clean interface): loss_fn(logits, target_events, dt) -> (nll, regularization)
# ============================================================
class ExponentialNLLLoss(nn.Module):
    """
    Competing risks exponential likelihood.

    The negative log-likelihood is given by:

    .. math::
        \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt

    Args:
        lambda_reg (float): Weight of the cross-entropy regularizer on the
            normalized hazards.
        eps (float): Small epsilon for numerical stability.
    """

    def __init__(
        self,
        lambda_reg: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K) tensor of logits (a trailing singleton
                dimension is squeezed away if present).
            target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
            dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
            reduction (str): 'mean', 'sum', or 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization), where the
            regularization term is lambda_reg times the cross-entropy between the
            normalized hazards and the target events (zero when lambda_reg == 0).
        """
        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
        hazards = F.softplus(logits) + self.eps  # (M, K)
        hazard_event = hazards.gather(
            1, target_events.unsqueeze(1)).squeeze(1)  # (M,)
        total_hazard = hazards.sum(dim=1)  # (M,)
        nll = -torch.log(hazard_event) + total_hazard * dt

        if reduction == "mean":
            nll = nll.mean()
        elif reduction == "sum":
            nll = nll.sum()
        elif reduction != "none":
            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")

        # softmax(log hazards) is exactly the normalized hazard vector, so this is
        # a cross-entropy on the event-type distribution implied by the hazards.
        reg = logits.new_zeros(())
        if self.lambda_reg != 0.0:
            log_hazards = torch.log(hazards)  # (M, K)
            reg = F.cross_entropy(log_hazards, target_events,
                                  reduction="mean") * self.lambda_reg
        return nll, reg
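

# A quick sanity sketch on random data (hypothetical, not part of the module's
# API): the reduced loss should match the closed form
# -log(hazard[k*]) + hazard.sum() * dt averaged over samples.
def _example_exponential_nll() -> None:
    torch.manual_seed(0)
    loss_fn = ExponentialNLLLoss(lambda_reg=0.1)
    logits = torch.randn(4, 3)            # (M, K)
    targets = torch.tensor([0, 2, 1, 0])  # (M,)
    dt = torch.rand(4) + 0.1              # strictly positive, in years
    nll, reg = loss_fn(logits, targets, dt)

    hazards = F.softplus(logits) + loss_fn.eps
    manual = (-torch.log(hazards[torch.arange(4), targets])
              + hazards.sum(dim=1) * dt).mean()
    assert torch.allclose(nll, manual)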


class PiecewiseExponentialLoss(nn.Module):
    """Piecewise-constant competing risks exponential likelihood.

    Uses B time bins defined by `bin_edges` (length B+1, strictly increasing,
    starting at 0). Within each bin b, hazards are constant and parameterized as:

        hazards = softplus(logits) + eps    with logits of shape (M, K, B)

    For each sample i, dt is bucketized to bin b* and the NLL is:

        nll_i = -log(hazard_{k*}(b*)) + \\int_0^{dt} sum_k hazard_k(u) du

    The integral is computed in closed form by summing full bins plus the
    partial bin b*.
    """

    def __init__(
        self,
        bin_edges: Sequence[float],
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()

        if len(bin_edges) < 2:
            raise ValueError("bin_edges must have length >= 2")
        if bin_edges[0] != 0:
            raise ValueError("bin_edges must start at 0")
        for i in range(1, len(bin_edges)):
            if not (bin_edges[i] > bin_edges[i - 1]):
                raise ValueError("bin_edges must be strictly increasing")

        self.eps = float(eps)
        self.lambda_reg = float(lambda_reg)

        edges = torch.tensor(list(bin_edges), dtype=torch.float32)
        self.register_buffer("bin_edges", edges, persistent=False)

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if logits.dim() != 3:
            raise ValueError("logits must have shape (M, K, B)")

        M, K, B = logits.shape
        if self.bin_edges.numel() != B + 1:
            raise ValueError(
                f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
            )

        device = logits.device
        dt = dt.to(device=device)
        target_events = target_events.to(device=device)

        # Build a per-sample finite mask to avoid NaN/Inf propagation.
        logits_finite = torch.isfinite(logits).view(M, -1).all(dim=1)
        dt_finite = torch.isfinite(dt)
        target_finite = torch.isfinite(target_events)
        finite_mask = logits_finite & dt_finite & target_finite

        # Under reduction='none', non-finite samples contribute a zero loss.
        nll_full = logits.new_zeros((M,))

        if not finite_mask.any():
            nll_out = nll_full if reduction == "none" else logits.new_zeros(())
            reg_out = logits.new_zeros(())
            return nll_out, reg_out

        idx = finite_mask.nonzero(as_tuple=False).squeeze(1)
        logits_v = logits[idx]
        target_v = target_events[idx].to(torch.long)
        dt_v = dt[idx].to(torch.float32)

        # Clamp dt into [eps, max_edge) to keep bucket indices valid.
        eps = self.eps
        max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype)
        dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))
        dt_v = torch.clamp(dt_v, min=eps, max=dt_max)

        hazards = F.softplus(logits_v) + eps  # (Mv, K, B)
        total_hazard = hazards.sum(dim=1)     # (Mv, B)

        edges = self.bin_edges.to(device=device, dtype=dt_v.dtype)
        widths = edges[1:] - edges[:-1]       # (B,)

        # Bin index b* in [0, B-1]; boundaries are edges[1:] (length B).
        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (Mv,)
        b_star = torch.clamp(b_star, min=0, max=B - 1)

        ar = torch.arange(logits_v.size(0), device=device)
        hazard_event = hazards[ar, target_v, b_star]  # (Mv,)

        # Integral: sum_{b < b*} total_hazard[:, b] * width_b
        #           + total_hazard[:, b*] * (dt - edge_left)
        weighted = total_hazard * widths.unsqueeze(0)  # (Mv, B)
        cum = weighted.cumsum(dim=1)                   # (Mv, B)
        full_bins_int = torch.zeros_like(dt_v)
        has_full = b_star > 0
        if has_full.any():
            full_bins_int[has_full] = cum.gather(
                1, (b_star[has_full] - 1).unsqueeze(1)
            ).squeeze(1)

        edge_left = edges[b_star]  # (Mv,)
        partial = total_hazard.gather(
            1, b_star.unsqueeze(1)).squeeze(1) * (dt_v - edge_left)
        integral = full_bins_int + partial

        nll_v = -torch.log(hazard_event) + integral
        nll_full[idx] = nll_v

        if reduction == "none":
            nll_out = nll_full
        elif reduction == "sum":
            nll_out = nll_v.sum()
        elif reduction == "mean":
            nll_out = nll_v.mean() if nll_v.numel() > 0 else logits.new_zeros(())
        else:
            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")

        # Simple L2 penalty on the (finite) logits.
        reg = logits.new_zeros(())
        if self.lambda_reg != 0.0:
            reg = reg + (self.lambda_reg * logits_v.pow(2).mean())

        return nll_out, reg
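

# A consistency sketch (hypothetical check, not part of the module's API): with
# a single bin wide enough to contain every dt, the piecewise loss reduces to
# the plain exponential loss above, since both then use one constant hazard per
# risk over the whole interval.
def _example_piecewise_matches_exponential() -> None:
    torch.manual_seed(0)
    M, K = 5, 3
    logits = torch.randn(M, K)
    targets = torch.randint(0, K, (M,))
    dt = torch.rand(M) + 0.1

    exp_nll, _ = ExponentialNLLLoss()(logits, targets, dt)
    pw = PiecewiseExponentialLoss(bin_edges=[0.0, 100.0])
    pw_nll, _ = pw(logits.unsqueeze(-1), targets, dt)  # logits as (M, K, B=1)
    assert torch.allclose(exp_nll, pw_nll, atol=1e-5)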


class WeibullNLLLoss(nn.Module):
    """
    Weibull hazard in t.

    .. math::
        \\Lambda_k(t) = \\text{scale}_k \\cdot t^{\\text{shape}_k}

        \\lambda_k(t) = \\text{shape}_k \\cdot \\text{scale}_k \\cdot t^{\\text{shape}_k-1}

    Args:
        eps (float): Small epsilon for numerical stability.
        lambda_reg (float): Weight of the squared-log penalty that pulls shapes
            and scales toward 1.
    """

    def __init__(
        self,
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K, 2) tensor of logits; [..., 0] parameterizes
                the shapes and [..., 1] the scales.
            target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
            dt (torch.Tensor): (M,) tensor of time intervals (years).
            reduction (str): 'mean', 'sum', or 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization).
        """
        shapes = F.softplus(logits[..., 0]) + self.eps  # (M, K)
        scales = F.softplus(logits[..., 1]) + self.eps  # (M, K)
        eps = self.eps
        t = torch.clamp(dt, min=eps)

        t_mat = t.unsqueeze(1)  # (M, 1)

        # Cumulative hazard Lambda_k(t) = scale_k * t^shape_k, shape (M, K).
        cum_hazard = scales * t_mat.pow(shapes)

        # Hazard lambda_k(t) = shape_k * scale_k * t^(shape_k - 1), shape (M, K).
        hazard = shapes * scales * t_mat.pow(shapes - 1.0)

        hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
        # Point-event likelihood: f_k(t) = lambda_k(t) * exp(-Lambda_total(t)), so
        # NLL_point = -log lambda_{k*}(t) + Lambda_total(t).
        nll = -torch.log(hazard_event + eps) + cum_hazard.sum(dim=1)

        if reduction == "mean":
            nll = nll.mean()
        elif reduction == "sum":
            nll = nll.sum()
        elif reduction != "none":
            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")

        reg = shapes.new_zeros(())
        if self.lambda_reg > 0:
            reg = self.lambda_reg * (
                (torch.log(scales + eps) ** 2).mean() +
                (torch.log(shapes + eps) ** 2).mean()
            )
        return nll, reg
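

# A usage sketch on random data (hypothetical, not part of the module's API).
# It re-derives the closed form -log(lambda_{k*}(t)) + sum_k Lambda_k(t) and
# checks it against the module's output.
def _example_weibull_nll() -> None:
    torch.manual_seed(0)
    M, K = 4, 3
    logits = torch.randn(M, K, 2)  # [..., 0] -> shapes, [..., 1] -> scales
    targets = torch.randint(0, K, (M,))
    dt = torch.rand(M) + 0.1
    loss_fn = WeibullNLLLoss(lambda_reg=0.01)
    nll, reg = loss_fn(logits, targets, dt)

    shapes = F.softplus(logits[..., 0]) + loss_fn.eps
    scales = F.softplus(logits[..., 1]) + loss_fn.eps
    ar = torch.arange(M)
    lam = shapes[ar, targets] * scales[ar, targets] * dt.pow(shapes[ar, targets] - 1.0)
    cum_total = (scales * dt.unsqueeze(1).pow(shapes)).sum(dim=1)
    manual = (-torch.log(lam + loss_fn.eps) + cum_total).mean()
    assert torch.allclose(nll, manual)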