Refactor LogNormalBasisHazardLoss to LogNormalBasisBinnedHazardCIFNLLLoss and update related configurations

commit f16596ed58
parent 1df02d85d7
2026-01-13 21:11:38 +08:00
3 changed files with 320 additions and 154 deletions
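For orientation, the per-sample objective the renamed loss optimizes, read off the new forward pass in losses.py below. With bin edges 0 = t_0 < t_1 < ... < t_K (the last edge may be +inf), standard normal CDF \Phi, head logits \ell, and the observed cause j^* falling in bin k^*:

    \Delta B_{k,r} = \Phi\left(\frac{\log t_k - \mu_r}{\sigma}\right) - \Phi\left(\frac{\log t_{k-1} - \mu_r}{\sigma}\right)

    \alpha_{j,r} = \mathrm{softplus}(\ell_{j,r}) + \alpha_{\mathrm{floor}}, \qquad H_{j,k} = \sum_r \alpha_{j,r}\,\Delta B_{k,r}, \qquad H_k = \sum_j H_{j,k}

    \mathrm{NLL} = \sum_{u < k^*} H_u - \log\left(1 - e^{-H_{k^*}}\right) - \log\frac{H_{j^*,k^*}}{H_{k^*}}

For the last, right-open bin, \Delta B_{K,r} = 1 - \Phi\big((\log t_{K-1} - \mu_r)/\sigma\big), computed via the survival function for stability.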


@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
 class ModelSpec:
     name: str
     model_type: str  # delphi_fork | sap_delphi
-    loss_type: str  # exponential | discrete_time_cif | lognormal_basis_hazard
+    loss_type: str  # exponential | discrete_time_cif | lognormal_basis_binned_hazard_cif
     full_cov: bool
     checkpoint_path: str
@@ -425,24 +425,24 @@ def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
     return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))

-def cifs_from_lognormal_basis_logits(
+def cifs_from_lognormal_basis_binned_hazard_logits(
     logits: torch.Tensor,
+    *,
     centers: Sequence[float],
     sigma: torch.Tensor,
+    bin_edges: Sequence[float],
     taus: Sequence[float],
-    *,
-    bin_edges: Optional[Sequence[float]] = None,
     eps: float = 1e-8,
+    alpha_floor: float = 0.0,
     return_survival: bool = False,
 ) -> torch.Tensor:
-    """Convert LogNormalBasisHazardLoss logits -> CIFs at taus.
+    """Convert Route-3 binned hazard logits -> CIFs at taus.

-    logits: (B, 1 + K*R) where K is number of diseases (causes) and R is number of basis functions.
     centers: length R, in log-time.
     sigma: scalar tensor (already clamped) in log-time units.
-    taus: horizons in the same units as training bin_edges (years).
+    logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
+    taus are expected to align with finite bin edges.
     """
-    if logits.ndim != 2:
-        raise ValueError("Expected logits shape (B, 1+K*R)")
+    if logits.ndim not in {2, 3}:
+        raise ValueError("logits must be 2D or 3D")
     if sigma.ndim != 0:
         raise ValueError("sigma must be a scalar tensor")
@@ -452,37 +452,111 @@ def cifs_from_lognormal_basis_logits(
     centers_t = torch.tensor([float(x)
                               for x in centers], device=device, dtype=dtype)
     r = int(centers_t.numel())
-    jr = int(logits.shape[1] - 1)
-    if jr <= 0 or (jr % r) != 0:
-        raise ValueError("logits.shape[1]-1 must be divisible by R")
-    k = jr // r
+    if r <= 0:
+        raise ValueError("centers must be non-empty")
+    offset = 0
+    if logits.ndim == 3:
+        j = int(logits.shape[1])
+        if int(logits.shape[2]) != r:
+            raise ValueError(
+                f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
+            )
+    else:
+        d = int(logits.shape[1])
+        if d % r == 0:
+            jr = d
+        elif (d - 1) % r == 0:
+            offset = 1
+            jr = d - 1
+        else:
+            raise ValueError(
+                f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}")
+        j = jr // r
+    if j <= 0:
+        raise ValueError("Inferred J must be >= 1")
+    edges = [float(x) for x in bin_edges]
+    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
+    n_bins = len(finite_edges)
+    if n_bins <= 0:
+        raise ValueError("bin_edges must contain at least one finite edge")
+    # Build finite bins [edges[k-1], edges[k]) for k=1..n_bins
+    left = torch.tensor(edges[:n_bins], device=device, dtype=dtype)
+    right = torch.tensor(edges[1:1 + n_bins], device=device, dtype=dtype)
     # Stable t_min clamp (aligns with training loss rule).
     t_min = 1e-12
-    if bin_edges is not None:
-        edges = [float(x) for x in bin_edges]
-        if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
-            t_min = edges[1] * 1e-6
+    if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
+        t_min = edges[1] * 1e-6
     t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
-    taus_t = torch.tensor([float(x) for x in taus], device=device, dtype=dtype)
-    taus_t = torch.clamp(taus_t, min=t_min_t)
-    log_tau = torch.log(taus_t)  # (H,)
+    left_is_zero = left <= 0
+    left_clamped = torch.clamp(left, min=t_min_t)
+    log_left = torch.log(left_clamped)
+    right_clamped = torch.clamp(right, min=t_min_t)
+    log_right = torch.log(right_clamped)
-    # (H,R)
-    z = (log_tau.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma
-    cdf = _normal_cdf_stable(z)
+    sigma_c = sigma.to(device=device, dtype=dtype)
+    z_left = (log_left.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
+    z_right = (log_right.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
-    w_all = torch.softmax(logits, dim=-1)
-    w = w_all[:, 1:].view(logits.size(0), k, r)  # (B,K,R)
+    cdf_left = _normal_cdf_stable(z_left)
+    if left_is_zero.any():
+        cdf_left = torch.where(left_is_zero.unsqueeze(-1),
+                               torch.zeros_like(cdf_left), cdf_left)
+    cdf_right = _normal_cdf_stable(z_right)
+    delta_basis = torch.clamp(cdf_right - cdf_left, min=0.0)  # (n_bins, R)
-    cif = torch.einsum("bkr,hr->bkh", w, cdf)  # (B,K,H)
+    if logits.ndim == 3:
+        alpha = F.softplus(logits) + float(alpha_floor)  # (B,J,R)
+    else:
+        logits_used = logits[:, offset:]
+        alpha = (F.softplus(logits_used) + float(alpha_floor)
+                 ).view(logits.size(0), j, r)  # (B,J,R)
+    h_jk = torch.einsum("bjr,kr->bjk", alpha, delta_basis)  # (B,J,n_bins)
+    h_k = h_jk.sum(dim=1)  # (B,n_bins)
+    h_k = torch.clamp(h_k, min=eps)
+    h_jk = torch.clamp(h_jk, min=eps)
+    p_comp = torch.exp(-h_k)  # (B,n_bins)
+    one_minus = -torch.expm1(-h_k)  # (B,n_bins) = 1-exp(-H)
+    ratio = h_jk / torch.clamp(h_k.unsqueeze(1), min=eps)
+    p_event = one_minus.unsqueeze(1) * ratio  # (B,J,n_bins)
+    ones = torch.ones((alpha.size(0), 1), device=device, dtype=dtype)
+    cum = torch.cumprod(p_comp, dim=1)  # survival after each bin
+    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # survival before each bin
+    cif_bins = torch.cumsum(s_prev.unsqueeze(
+        1) * p_event, dim=2)  # (B,J,n_bins)
+    finite_edges_arr = np.asarray(finite_edges, dtype=float)
+    tau_to_idx: List[int] = []
+    for tau in taus:
+        tau_f = float(tau)
+        if not math.isfinite(tau_f):
+            raise ValueError("taus must be finite for discrete-time CIF")
+        diffs = np.abs(finite_edges_arr - tau_f)
+        idx0 = int(np.argmin(diffs))
+        if diffs[idx0] > 1e-6:
+            raise ValueError(
+                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})"
+            )
+        tau_to_idx.append(idx0)
+    idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long)
+    cif = cif_bins.index_select(dim=2, index=idx)  # (B,J,H)
     if not return_survival:
         return cif
-    survival = 1.0 - cif.sum(dim=1)  # (B,H)
-    survival = torch.clamp(survival, min=0.0, max=1.0)
+    survival = cum.index_select(dim=1, index=idx)  # (B,H)
     return cif, survival
@@ -1269,20 +1343,46 @@ def instantiate_model_and_head(
     elif loss_type == "discrete_time_cif":
         bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
         out_dims = [dataset.n_disease + 1, len(bin_edges)]
-    elif loss_type == "lognormal_basis_hazard":
+    elif loss_type == "lognormal_basis_binned_hazard_cif":
         centers = cfg.get("lognormal_centers", None)
         if centers is None:
             centers = cfg.get("centers", None)
         if not isinstance(centers, list) or len(centers) == 0:
             raise ValueError(
-                "lognormal_basis_hazard requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
+                "lognormal_basis_binned_hazard_cif requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
             )
-        out_dims = [1 + dataset.n_disease * len(centers)]
+        r = len(centers)
+        desired_total = int(dataset.n_disease) * int(r)
+        legacy_total = 1 + desired_total
+        # Prefer the new shape (K,R) but keep compatibility with older checkpoints
+        # that used a single flattened dimension (1 + K*R).
+        out_dims = [int(dataset.n_disease), int(r)]
+        if checkpoint_path:
+            try:
+                ckpt = torch.load(checkpoint_path, map_location="cpu")
+                head_sd = ckpt.get("head_state_dict", {})
+                w = head_sd.get("net.2.weight", None)
+                if isinstance(w, torch.Tensor) and w.ndim == 2:
+                    out_features = int(w.shape[0])
+                    if out_features == legacy_total:
+                        out_dims = [legacy_total]
+                    elif out_features == desired_total:
+                        out_dims = [int(dataset.n_disease), int(r)]
+                    else:
+                        raise ValueError(
+                            f"Checkpoint head out_features={out_features} does not match expected {desired_total} (K*R) or {legacy_total} (1+K*R)"
+                        )
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to infer head output dims from checkpoint={checkpoint_path}: {e}"
+                )
         loss_params["centers"] = centers
         loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
         loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
         loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
         loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
+        loss_params["alpha_floor"] = float(cfg.get("alpha_floor", 0.0))
     else:
         raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
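A simplified, standalone sketch of the checkpoint-compatibility rule above (the helper name is hypothetical; it assumes, as the loader does, that the head's final Linear weight is stored under head_state_dict['net.2.weight']; unlike the loader, unexpected shapes simply fall through to the new layout):

import torch

def infer_head_out_dims(checkpoint_path: str, n_disease: int, r: int) -> list:
    # Legacy heads emitted a single flattened vector (B, 1 + J*R);
    # new heads emit (B, J, R). Pick based on the stored weight shape.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    w = ckpt.get("head_state_dict", {}).get("net.2.weight")
    if isinstance(w, torch.Tensor) and w.ndim == 2 and int(w.shape[0]) == 1 + n_disease * r:
        return [1 + n_disease * r]  # legacy flattened head
    return [n_disease, r]           # preferred (J, R) head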
@@ -1399,18 +1499,20 @@ def predict_cifs_for_model(
         cif_full, survival = cifs_from_discrete_time_logits(
             # (B,K,H), (B,H)
             logits, bin_edges, eval_horizons, return_survival=True)
-    elif loss_type == "lognormal_basis_hazard":
+    elif loss_type == "lognormal_basis_binned_hazard_cif":
         centers = loss_params.get("centers", None)
         sigma = loss_params.get("sigma", None)
         if centers is None or sigma is None:
             raise ValueError(
-                "lognormal_basis_hazard requires loss_params['centers'] and loss_params['sigma']")
-        cif_full, survival = cifs_from_lognormal_basis_logits(
+                "lognormal_basis_binned_hazard_cif requires loss_params['centers'] and loss_params['sigma']")
+        cif_full, survival = cifs_from_lognormal_basis_binned_hazard_logits(
             logits,
             centers=centers,
             sigma=sigma,
-            taus=eval_horizons,
             bin_edges=bin_edges,
+            taus=eval_horizons,
+            eps=float(loss_params.get("loss_eps", 1e-8)),
+            alpha_floor=float(loss_params.get("alpha_floor", 0.0)),
             return_survival=True,
         )
     else:
@@ -1927,7 +2029,7 @@ def main() -> int:
     backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
     head.load_state_dict(ckpt["head_state_dict"], strict=True)
-    if loss_type == "lognormal_basis_hazard":
+    if loss_type == "lognormal_basis_binned_hazard_cif":
         crit_state = ckpt.get("criterion_state_dict", {})
         log_sigma = crit_state.get("log_sigma", None)
         if isinstance(log_sigma, torch.Tensor):
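For reference, a minimal evaluation-side sketch of the converter above. The shapes and keyword interface come from the diff; the tensors, center values, and horizons are illustrative, and the function is assumed importable from the evaluation module (whose name is not shown in this commit):

import torch

B, J, R = 4, 3, 5                        # batch, causes, basis functions
logits = torch.randn(B, J, R)            # preferred 3D layout; (B, J*R) and (B, 1+J*R) also accepted
centers = [-1.4, -0.3, 0.5, 1.3, 2.3]    # mu_r in log-time, length R
sigma = torch.tensor(0.7)                # scalar sigma, already clamped
bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]

cif, survival = cifs_from_lognormal_basis_binned_hazard_logits(
    logits,
    centers=centers,
    sigma=sigma,
    bin_edges=bin_edges,
    taus=[1.61, 10.0],                   # must match finite bin edges to within 1e-6
    return_survival=True,
)
# cif: (B, J, 2) cumulative incidence per cause and horizon; survival: (B, 2)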

losses.py (264 changed lines)

@@ -260,24 +260,27 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
         return nll, reg

-class LogNormalBasisHazardLoss(nn.Module):
-    """Event-only competing risks loss using lognormal basis (Gaussian on log-time).
+class LogNormalBasisBinnedHazardCIFNLLLoss(nn.Module):
+    r"""Route-3: continuous-time lognormal-basis hazards with discrete-time CIF likelihood.

-    This loss models cause-specific CIF as a mixture of lognormal basis CDFs:
+    This implements a cause-specific continuous-time hazard model:

-        F_j(t) = sum_r w_{j,r} * Phi((log t - mu_r) / sigma)
+        \lambda_j(t) = \sum_r \alpha_{j,r} b_r(t)

-    Training uses *bin probability mass* (Delta CIF / interval mass). There is
-    **no censoring**: every sample is an observed event with a valid cause.
+    where b_r(t) is the lognormal PDF basis implied by a Normal on log-time.

-    Logits interface:
-        logits: (B, 1 + J*R)
-        logits[:, 0] -> w0 (survival mass / never-event)
-        logits[:, 1:] -> flattened (j,r) in row-major order: j then r
-        index = 1 + j*R + r
+    Training objective is IDENTICAL in structure to DiscreteTimeCIFNLLLoss,
+    but per-bin categorical probabilities are derived from integrated hazards.
+
+    Expected logits interface (preferred):
+        logits: (B, J*R)
+        reshaped to (B, J, R)
+    For convenience/compatibility, also accepts:
+        logits: (B, 1+J*R) and ignores the first column.

     Forward interface (must match):
-        forward(logits, target_events, dt, reduction)
+        forward(logits, target_events, dt, reduction) -> (nll, reg)
     """

     def __init__(
@@ -286,17 +289,20 @@ class LogNormalBasisHazardLoss(nn.Module):
         centers: Sequence[float],
         *,
         eps: float = 1e-8,
-        bandwidth_init: float = 0.5,
+        alpha_floor: float = 0.0,
+        bandwidth_init: float = 0.7,
         bandwidth_min: float = 1e-3,
         bandwidth_max: float = 10.0,
         lambda_sigma_reg: float = 0.0,
         sigma_reg_target: Optional[float] = None,
-        return_dict: bool = False,
+        lambda_reg: float = 0.0,
     ):
         super().__init__()
         if len(bin_edges) < 2:
-            raise ValueError("bin_edges must have length >= 2")
+            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
+        if float(bin_edges[0]) != 0.0:
+            raise ValueError("bin_edges[0] must equal 0")
         # allow last edge to be +inf
         for i in range(1, len(bin_edges)):
             prev = float(bin_edges[i - 1])
@@ -310,20 +316,19 @@ class LogNormalBasisHazardLoss(nn.Module):
             else:
                 if not (cur > prev):
                     raise ValueError("bin_edges must be strictly increasing")
-        if float(bin_edges[0]) < 0.0:
-            raise ValueError("bin_edges[0] must be >= 0")
         if len(centers) < 1:
             raise ValueError("centers must have length >= 1")
         self.eps = float(eps)
+        self.alpha_floor = float(alpha_floor)
         self.bandwidth_min = float(bandwidth_min)
         self.bandwidth_max = float(bandwidth_max)
-        self.bandwidth_init = float(bandwidth_init)
         self.lambda_sigma_reg = float(lambda_sigma_reg)
         self.sigma_reg_target = None if sigma_reg_target is None else float(
             sigma_reg_target)
+        self.bandwidth_init = float(bandwidth_init)
-        self.return_dict = bool(return_dict)
+        self.lambda_reg = float(lambda_reg)
         self.register_buffer(
             "bin_edges",
@@ -343,69 +348,32 @@ class LogNormalBasisHazardLoss(nn.Module):
     @staticmethod
     def _normal_cdf(z: torch.Tensor) -> torch.Tensor:
         # Stable normal CDF via erf.
         z = torch.clamp(z, -12.0, 12.0)
         return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))

     @staticmethod
     def _normal_sf(z: torch.Tensor) -> torch.Tensor:
         # Stable normal survival function via erfc.
         z = torch.clamp(z, -12.0, 12.0)
         return 0.5 * torch.erfc(z / math.sqrt(2.0))

-    def forward(
+    def _compute_delta_basis_all_bins(
         self,
-        logits: torch.Tensor,
-        target_events: torch.Tensor,
-        dt: torch.Tensor,
-        reduction: str = "mean",
-    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, Any]]:
-        if logits.ndim != 2:
-            raise ValueError(
-                f"logits must be 2D with shape (B, 1+J*R); got {tuple(logits.shape)}")
-        if target_events.ndim != 1 or dt.ndim != 1:
-            raise ValueError("target_events and dt must be 1D tensors")
-        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
-            raise ValueError(
-                "Batch size mismatch among logits, target_events, dt")
-        if reduction not in {"mean", "sum", "none"}:
-            raise ValueError("reduction must be one of {'mean','sum','none'}")
-        device = logits.device
-        dtype = logits.dtype
+        *,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """Compute ΔB[k,r] for bins k=1..n_bins (shape: (n_bins, R))."""
         bin_edges = self.bin_edges.to(device=device, dtype=dtype)
         centers = self.centers.to(device=device, dtype=dtype)
-        bsz = logits.shape[0]
-        r = int(centers.numel())
-        jr = int(logits.shape[1] - 1)
-        if jr <= 0:
-            raise ValueError(
-                "logits.shape[1] must be >= 2 (w0 + at least one (j,r) weight)")
-        if jr % r != 0:
-            raise ValueError(
-                f"(logits.shape[1]-1) must be divisible by R={r}; got {jr}")
-        j = jr // r
-        # 1) Stable global weights (includes w0).
-        w_all = torch.softmax(logits, dim=-1)  # (B, 1+J*R)
-        w0 = w_all[:, 0]
-        w = w_all[:, 1:].view(bsz, j, r)
-        # 2) Determine event bin index.
-        k = int(bin_edges.numel() - 1)
-        if k < 1:
+        n_bins = int(bin_edges.numel() - 1)
+        if n_bins < 1:
             raise ValueError("bin_edges must define at least one bin")
-        # v2: dt is always continuous time (float), map to bin via searchsorted.
-        dt_f = dt.to(device=device, dtype=dtype)
-        bin_idx = torch.searchsorted(bin_edges, dt_f, right=True) - 1
-        bin_idx = torch.clamp(bin_idx, 0, k - 1).to(torch.long)
+        left = bin_edges[:-1]  # (n_bins,)
+        right = bin_edges[1:]  # (n_bins,)
-        left = bin_edges[bin_idx]
-        right = bin_edges[bin_idx + 1]
         # 3) Stable log(t) clamp.
         if float(self.bin_edges[1]) > 0.0:
             t_min = float(self.bin_edges[1]) * 1e-6
         else:
@@ -413,19 +381,19 @@ class LogNormalBasisHazardLoss(nn.Module):
         t_min_t = torch.tensor(t_min, device=device, dtype=dtype)
         left_is_zero = left <= 0
         # For log() we still need a positive clamp, but we will treat CDF(left)=0 exactly
         # when left<=0 (instead of approximating via t_min).
         left_clamped = torch.clamp(left, min=t_min_t)
         log_left = torch.log(left_clamped)
         is_inf = torch.isinf(right)
         # right might be +inf for last bin; avoid log(+inf) by substituting a safe finite value.
-        right_safe = torch.where(is_inf, left_clamped,
-                                 torch.clamp(right, min=t_min_t))
+        right_safe = torch.where(
+            is_inf, left_clamped, torch.clamp(right, min=t_min_t))
         log_right = torch.log(right_safe)
-        sigma = torch.clamp(self.log_sigma.to(
-            device=device, dtype=dtype).exp(), self.bandwidth_min, self.bandwidth_max)
+        sigma = torch.clamp(
+            self.log_sigma.to(device=device, dtype=dtype).exp(),
+            self.bandwidth_min,
+            self.bandwidth_max,
+        )
         z_left = (log_left.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
         z_right = (log_right.unsqueeze(-1) - centers.unsqueeze(0)) / sigma
@@ -433,58 +401,152 @@ class LogNormalBasisHazardLoss(nn.Module):
         z_right = torch.clamp(z_right, -12.0, 12.0)
         cdf_left = self._normal_cdf(z_left)
         # Treat the first-bin left boundary exactly as 0 in CDF.
         if left_is_zero.any():
             cdf_left = torch.where(
                 left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left)
         cdf_right = self._normal_cdf(z_right)
         delta_finite = cdf_right - cdf_left
         # Last bin: ΔB = 1 - CDF(left) = SF(left), computed via erfc for stability.
         delta_last = self._normal_sf(z_left)
         # If left<=0, SF(left)=1 exactly.
         if left_is_zero.any():
             delta_last = torch.where(
                 left_is_zero.unsqueeze(-1), torch.ones_like(delta_last), delta_last)
         delta_basis = torch.where(
             is_inf.unsqueeze(-1), delta_last, delta_finite)
         delta_basis = torch.clamp(delta_basis, min=0.0)
+        return delta_basis
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        target_events: torch.Tensor,
+        dt: torch.Tensor,
+        reduction: str = "mean",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if logits.ndim not in {2, 3}:
+            raise ValueError(
+                f"logits must be 2D (B, J*R) (or (B, 1+J*R)) or 3D (B, J, R); got {tuple(logits.shape)}"
+            )
+        if target_events.ndim != 1 or dt.ndim != 1:
+            raise ValueError(
+                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
+            )
+        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
+            raise ValueError(
+                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
+            )
+        if reduction not in {"mean", "sum", "none"}:
+            raise ValueError("reduction must be one of {'mean','sum','none'}")
+        if not torch.all(dt > 0):
+            raise ValueError("dt must be strictly positive")
+        device = logits.device
+        dtype = logits.dtype
+        centers = self.centers.to(device=device, dtype=dtype)
+        r = int(centers.numel())
+        if r < 1:
+            raise ValueError("centers must have length >= 1")
+        if logits.ndim == 3:
+            if logits.shape[2] != r:
+                raise ValueError(
+                    f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
+                )
+            j = int(logits.shape[1])
+            if j < 1:
+                raise ValueError("Inferred number of causes J must be >= 1")
+            alpha = F.softplus(logits) + self.alpha_floor  # (B, J, R)
+        else:
+            d = int(logits.shape[1])
+            offset = 0
+            if d % r == 0:
+                jr = d
+            elif (d - 1) % r == 0:
+                offset = 1
+                jr = d - 1
+            else:
+                raise ValueError(
+                    f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}"
+                )
+            j = jr // r
+            if j < 1:
+                raise ValueError("Inferred number of causes J must be >= 1")
+            logits_used = logits[:, offset:]
+            alpha = F.softplus(logits_used).view(-1, j, r) + \
+                self.alpha_floor  # (B, J, R)
+        delta_basis = self._compute_delta_basis_all_bins(
+            device=device,
+            dtype=dtype,
+        )  # (n_bins, R)
+        n_bins = int(delta_basis.shape[0])
+        # H_{j,k} = sum_r alpha_{j,r} * ΔB_{k,r}
+        h_jk = torch.einsum("mjr,kr->mjk", alpha,
+                            delta_basis)  # (B, J, n_bins)
+        h_k = h_jk.sum(dim=1)  # (B, n_bins)
+        # Map continuous dt to discrete event bin index k* in {1..n_bins}.
+        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
+        time_bin = torch.bucketize(dt, bin_edges)
+        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(torch.long)
-        # 4) Gather per-sample cause weights and compute event mass.
         cause = target_events.to(device=device, dtype=torch.long)
         if (cause < 0).any() or (cause >= j).any():
             raise ValueError(f"target_events must be in [0, J-1] where J={j}")
-        b_idx = torch.arange(bsz, device=device)
-        w_cause = w[b_idx, cause, :]  # (B, R)
+        # Previous survival term: sum_{u<k*} H_u
+        bins = torch.arange(
+            1, n_bins + 1, device=device).unsqueeze(0)  # (1, n_bins)
+        mask_prev = bins < time_bin.unsqueeze(1)  # (B, n_bins)
+        loss_prev = (h_k * mask_prev.to(h_k.dtype)).sum(dim=1)  # (B,)
-        m = (w_cause * delta_basis).sum(dim=-1)
-        m = torch.clamp(m, min=self.eps)
-        nll_vec = -torch.log(m)
+        # Event term at k*: -log p_{k*}(cause)
+        b_idx = torch.arange(target_events.shape[0], device=device)
+        k0 = time_bin - 1  # (B,) index into 0..n_bins-1
+        h_event_total = h_k[b_idx, k0]
+        h_event_total = torch.clamp(h_event_total, min=self.eps)
+        h_event_cause = h_jk[b_idx, cause, k0]
+        h_event_cause = torch.clamp(h_event_cause, min=self.eps)
+        # log(1 - exp(-H)) stably
+        log1mexp = torch.log(-torch.expm1(-h_event_total))
+        loss_event = -log1mexp - \
+            torch.log(h_event_cause) + torch.log(h_event_total)
+        loss = loss_prev + loss_event
         if reduction == "mean":
-            nll: torch.Tensor = nll_vec.mean()
+            nll = loss.mean()
         elif reduction == "sum":
-            nll = nll_vec.sum()
+            nll = loss.sum()
         else:
-            nll = nll_vec
+            nll = loss
         reg = torch.zeros((), device=device, dtype=dtype)
+        if self.lambda_reg > 0.0:
+            # Regularize the within-bin cause competition via NLL on log ratios.
+            # ratio_j = H_{j,k*} / H_{k*}
+            h_event_all = h_jk[b_idx, :, k0]  # (B, J)
+            denom = torch.clamp(h_event_total, min=self.eps).unsqueeze(1)
+            ratio = torch.clamp(h_event_all / denom, min=self.eps)
+            log_ratio = torch.log(ratio)
+            reg = reg + self.lambda_reg * F.nll_loss(
+                log_ratio, cause, reduction="mean")
         sigma_penalty = torch.zeros((), device=device, dtype=dtype)
         if self.lambda_sigma_reg > 0.0:
             target = self.bandwidth_init if self.sigma_reg_target is None else self.sigma_reg_target
             sigma_penalty = (self.log_sigma.to(
                 device=device, dtype=dtype) - math.log(float(target))) ** 2
-            reg = sigma_penalty * float(self.lambda_sigma_reg)
+            reg = reg + sigma_penalty * self.lambda_sigma_reg
-        if not self.return_dict:
-            return nll, reg
-        return {
-            "nll": nll,
-            "reg": reg,
-            "nll_vec": nll_vec,
-            "sigma": sigma.detach(),
-            "avg_w0": w0.mean().detach(),
-            "min_delta_basis": delta_basis.min().detach(),
-            "mean_m": m.mean().detach(),
-            "sigma_penalty": sigma_penalty.detach(),
-            "bin_idx": bin_idx.detach(),
-        }
+        return nll, reg
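A minimal sketch of driving the refactored loss end to end, assuming LogNormalBasisBinnedHazardCIFNLLLoss is imported from losses; bin edges, centers, and the random tensors are illustrative:

import torch

J, R = 3, 5
crit = LogNormalBasisBinnedHazardCIFNLLLoss(
    bin_edges=[0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")],
    centers=[-1.4, -0.3, 0.5, 1.3, 2.3],   # mu_r in log-time
    bandwidth_init=0.7,
    lambda_sigma_reg=1e-4,
)

logits = torch.randn(8, J, R, requires_grad=True)  # head output (B, J, R)
target_events = torch.randint(0, J, (8,))          # observed cause per sample, no censoring
dt = torch.rand(8) * 10.0 + 0.01                   # continuous time-to-event, strictly > 0
nll, reg = crit(logits, target_events, dt, reduction="mean")
(nll + reg).backward()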


@@ -1,4 +1,4 @@
-from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisHazardLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisBinnedHazardCIFNLLLoss, get_valid_pairs_and_dt
 from model import DelphiFork, SapDelphi, SimpleHead
 from dataset import HealthDataset, health_collate_fn
 from tqdm import tqdm
@@ -22,7 +22,7 @@ class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
     loss_type: Literal['exponential', 'discrete_time_cif',
-                       'lognormal_basis_hazard'] = 'exponential'
+                       'lognormal_basis_binned_hazard_cif'] = 'exponential'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -34,7 +34,7 @@ class TrainConfig:
         default_factory=lambda: [0.0, 0.24, 0.72,
                                  1.61, 3.84, 10.0, 31.0, float('inf')]
     )
-    # LogNormalBasisHazardLoss specific
+    # LogNormal basis (shared by Route-3 binned hazard)
     lognormal_centers: Optional[Sequence[float]] = field(
         default_factory=list)  # mu_r in log-time
     loss_eps: float = 1e-8
@@ -73,7 +73,8 @@ def parse_args() -> TrainConfig:
     parser.add_argument(
         "--loss_type",
         type=str,
-        choices=['exponential', 'discrete_time_cif', 'lognormal_basis_hazard'],
+        choices=['exponential', 'discrete_time_cif',
+                 'lognormal_basis_binned_hazard_cif'],
         default='exponential',
         help="Type of loss function to use.")
     parser.add_argument("--age_encoder", type=str, choices=[
@@ -93,17 +94,17 @@ def parse_args() -> TrainConfig:
         type=float,
         nargs='*',
         default=None,
-        help="LogNormalBasisHazardLoss centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
+        help="LogNormal basis centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
     parser.add_argument("--loss_eps", type=float, default=1e-8,
-                        help="Epsilon for LogNormalBasisHazardLoss log clamp.")
+                        help="Epsilon for log clamps in lognormal-basis losses.")
     parser.add_argument("--bandwidth_init", type=float, default=0.7,
-                        help="Initial sigma for LogNormalBasisHazardLoss.")
+                        help="Initial sigma for lognormal-basis.")
     parser.add_argument("--bandwidth_min", type=float, default=1e-3,
-                        help="Minimum sigma clamp for LogNormalBasisHazardLoss.")
+                        help="Minimum sigma clamp for lognormal-basis.")
     parser.add_argument("--bandwidth_max", type=float, default=10.0,
-                        help="Maximum sigma clamp for LogNormalBasisHazardLoss.")
+                        help="Maximum sigma clamp for lognormal-basis.")
     parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
-                        help="Sigma regularization strength for LogNormalBasisHazardLoss.")
+                        help="Sigma regularization strength for lognormal-basis.")
     parser.add_argument("--sigma_reg_target", type=float, default=None,
                         help="Optional sigma target for regularization (otherwise uses bandwidth_init).")
     parser.add_argument("--rank", type=int, default=16,
@@ -261,12 +262,12 @@ class Trainer:
             ).to(self.device)
             # logits shape (M, K+1, n_bins+1)
             out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
-        elif cfg.loss_type == "lognormal_basis_hazard":
+        elif cfg.loss_type == "lognormal_basis_binned_hazard_cif":
             r = len(cfg.lognormal_centers)
             if r <= 0:
                 raise ValueError(
-                    "lognormal_centers must be non-empty for lognormal_basis_hazard")
-            self.criterion = LogNormalBasisHazardLoss(
+                    "lognormal_centers must be non-empty for lognormal_basis_binned_hazard_cif")
+            self.criterion = LogNormalBasisBinnedHazardCIFNLLLoss(
                 bin_edges=cfg.bin_edges,
                 centers=cfg.lognormal_centers,
                 eps=cfg.loss_eps,
@@ -275,9 +276,10 @@ class Trainer:
                 bandwidth_max=cfg.bandwidth_max,
                 lambda_sigma_reg=cfg.lambda_sigma_reg,
                 sigma_reg_target=cfg.sigma_reg_target,
+                lambda_reg=cfg.lambda_reg,
             ).to(self.device)
-            # logits shape (M, 1 + J*R)
-            out_dims = [1 + dataset.n_disease * r]
+            # Head emits (M, J, R) for Route-3.
+            out_dims = [dataset.n_disease, r]
         else:
             raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
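Putting the renamed pieces together, a hypothetical training invocation (the script name train.py and the center values are assumptions; the flags are the ones registered in parse_args above):

python train.py --loss_type lognormal_basis_binned_hazard_cif --lognormal_centers -1.4 -0.3 0.5 1.3 2.3 --bandwidth_init 0.7 --bandwidth_min 1e-3 --bandwidth_max 10.0 --lambda_sigma_reg 1e-4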