Refactor LogNormalBasisHazardLoss to LogNormalBasisBinnedHazardCIFNLLLoss and update related configurations
@@ -35,7 +35,7 @@ DEFAULT_DEATH_CAUSE_ID = 1256
 class ModelSpec:
     name: str
     model_type: str  # delphi_fork | sap_delphi
-    loss_type: str  # exponential | discrete_time_cif | lognormal_basis_hazard
+    loss_type: str  # exponential | discrete_time_cif | lognormal_basis_binned_hazard_cif
     full_cov: bool
     checkpoint_path: str
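For orientation, here is how a spec would be filled in under the renamed loss type. A hypothetical example: every value below is invented except the loss_type string, which comes from this commit:

    # Hypothetical example; only loss_type is taken from the diff.
    spec = ModelSpec(
        name="delphi_fork_route3",
        model_type="delphi_fork",
        loss_type="lognormal_basis_binned_hazard_cif",  # formerly "lognormal_basis_hazard"
        full_cov=False,
        checkpoint_path="checkpoints/route3/best.pt",
    )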
@@ -425,24 +425,24 @@ def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
     return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
 
 
-def cifs_from_lognormal_basis_logits(
+def cifs_from_lognormal_basis_binned_hazard_logits(
     logits: torch.Tensor,
+    *,
     centers: Sequence[float],
     sigma: torch.Tensor,
+    bin_edges: Sequence[float],
     taus: Sequence[float],
-    *,
-    bin_edges: Optional[Sequence[float]] = None,
     eps: float = 1e-8,
     alpha_floor: float = 0.0,
     return_survival: bool = False,
 ) -> torch.Tensor:
-    """Convert LogNormalBasisHazardLoss logits -> CIFs at taus.
+    """Convert Route-3 binned hazard logits -> CIFs at taus.
 
-    logits: (B, 1 + K*R) where K is number of diseases (causes) and R is number of basis functions.
     centers: length R, in log-time.
     sigma: scalar tensor (already clamped) in log-time units.
-    taus: horizons in the same units as training bin_edges (years).
+    logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
+    taus are expected to align with finite bin edges.
     """
-    if logits.ndim != 2:
-        raise ValueError("Expected logits shape (B, 1+K*R)")
+    if logits.ndim not in {2, 3}:
+        raise ValueError("logits must be 2D or 3D")
     if sigma.ndim != 0:
         raise ValueError("sigma must be a scalar tensor")
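Since everything after logits is now keyword-only and bin_edges is required (no longer Optional), a call to the renamed helper looks like the following minimal sketch; all shapes and numbers are invented for illustration:

    # Minimal usage sketch with invented values (assumes torch is imported).
    logits = torch.randn(4, 2, 3)   # (B=4, J=2 causes, R=3 basis functions)
    cif = cifs_from_lognormal_basis_binned_hazard_logits(
        logits,
        centers=[0.0, 1.0, 2.0],                        # mu_r in log-time, length R
        sigma=torch.tensor(0.7),                        # scalar log-time bandwidth
        bin_edges=[0.0, 1.0, 5.0, 10.0, float("inf")],  # last edge may be infinite
        taus=[1.0, 5.0, 10.0],                          # each must match a finite edge
    )
    print(cif.shape)  # (4, 2, 3): one CIF per sample, cause, and horizon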
@@ -452,37 +452,111 @@ def cifs_from_lognormal_basis_logits(
     centers_t = torch.tensor([float(x)
                               for x in centers], device=device, dtype=dtype)
     r = int(centers_t.numel())
-    jr = int(logits.shape[1] - 1)
-    if jr <= 0 or (jr % r) != 0:
-        raise ValueError("logits.shape[1]-1 must be divisible by R")
-    k = jr // r
+    if r <= 0:
+        raise ValueError("centers must be non-empty")
+
+    offset = 0
+    if logits.ndim == 3:
+        j = int(logits.shape[1])
+        if int(logits.shape[2]) != r:
+            raise ValueError(
+                f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
+            )
+    else:
+        d = int(logits.shape[1])
+        if d % r == 0:
+            jr = d
+        elif (d - 1) % r == 0:
+            offset = 1
+            jr = d - 1
+        else:
+            raise ValueError(
+                f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}")
+
+        j = jr // r
+
+    if j <= 0:
+        raise ValueError("Inferred J must be >= 1")
+
+    edges = [float(x) for x in bin_edges]
+    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
+    n_bins = len(finite_edges)
+    if n_bins <= 0:
+        raise ValueError("bin_edges must contain at least one finite edge")
+
+    # Build finite bins [edges[k-1], edges[k]) for k=1..n_bins
+    left = torch.tensor(edges[:n_bins], device=device, dtype=dtype)
+    right = torch.tensor(edges[1:1 + n_bins], device=device, dtype=dtype)
 
     # Stable t_min clamp (aligns with training loss rule).
     t_min = 1e-12
-    if bin_edges is not None:
-        edges = [float(x) for x in bin_edges]
-        if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
-            t_min = edges[1] * 1e-6
+    if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
+        t_min = edges[1] * 1e-6
     t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
 
     taus_t = torch.tensor([float(x) for x in taus], device=device, dtype=dtype)
     taus_t = torch.clamp(taus_t, min=t_min_t)
-    log_tau = torch.log(taus_t)  # (H,)
+    left_is_zero = left <= 0
+    left_clamped = torch.clamp(left, min=t_min_t)
+    log_left = torch.log(left_clamped)
+    right_clamped = torch.clamp(right, min=t_min_t)
+    log_right = torch.log(right_clamped)
 
-    # (H,R)
-    z = (log_tau.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma
-    cdf = _normal_cdf_stable(z)
+    sigma_c = sigma.to(device=device, dtype=dtype)
+    z_left = (log_left.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
+    z_right = (log_right.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
 
-    w_all = torch.softmax(logits, dim=-1)
-    w = w_all[:, 1:].view(logits.size(0), k, r)  # (B,K,R)
+    cdf_left = _normal_cdf_stable(z_left)
+    if left_is_zero.any():
+        cdf_left = torch.where(left_is_zero.unsqueeze(-1),
+                               torch.zeros_like(cdf_left), cdf_left)
+    cdf_right = _normal_cdf_stable(z_right)
+    delta_basis = torch.clamp(cdf_right - cdf_left, min=0.0)  # (n_bins, R)
 
-    cif = torch.einsum("bkr,hr->bkh", w, cdf)  # (B,K,H)
+    if logits.ndim == 3:
+        alpha = F.softplus(logits) + float(alpha_floor)  # (B,J,R)
+    else:
+        logits_used = logits[:, offset:]
+        alpha = (F.softplus(logits_used) + float(alpha_floor)
+                 ).view(logits.size(0), j, r)  # (B,J,R)
+
+    h_jk = torch.einsum("bjr,kr->bjk", alpha, delta_basis)  # (B,J,n_bins)
+    h_k = h_jk.sum(dim=1)  # (B,n_bins)
+
+    h_k = torch.clamp(h_k, min=eps)
+    h_jk = torch.clamp(h_jk, min=eps)
+
+    p_comp = torch.exp(-h_k)  # (B,n_bins)
+    one_minus = -torch.expm1(-h_k)  # (B,n_bins) = 1-exp(-H)
+    ratio = h_jk / torch.clamp(h_k.unsqueeze(1), min=eps)
+    p_event = one_minus.unsqueeze(1) * ratio  # (B,J,n_bins)
+
+    ones = torch.ones((alpha.size(0), 1), device=device, dtype=dtype)
+    cum = torch.cumprod(p_comp, dim=1)  # survival after each bin
+    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # survival before each bin
+
+    cif_bins = torch.cumsum(s_prev.unsqueeze(
+        1) * p_event, dim=2)  # (B,J,n_bins)
+
+    finite_edges_arr = np.asarray(finite_edges, dtype=float)
+    tau_to_idx: List[int] = []
+    for tau in taus:
+        tau_f = float(tau)
+        if not math.isfinite(tau_f):
+            raise ValueError("taus must be finite for discrete-time CIF")
+        diffs = np.abs(finite_edges_arr - tau_f)
+        idx0 = int(np.argmin(diffs))
+        if diffs[idx0] > 1e-6:
+            raise ValueError(
+                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})"
+            )
+        tau_to_idx.append(idx0)
+
+    idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long)
+    cif = cif_bins.index_select(dim=2, index=idx)  # (B,J,H)
 
     if not return_survival:
         return cif
 
-    survival = 1.0 - cif.sum(dim=1)  # (B,H)
-    survival = torch.clamp(survival, min=0.0, max=1.0)
+    survival = cum.index_select(dim=1, index=idx)  # (B,H)
     return cif, survival
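The rewritten body drops the old softmax-mixture readout (CIF as a weighted sum of log-normal CDFs at tau) for a discrete-time competing-risks recursion: per-bin cause hazards h_jk give a per-bin survival factor exp(-sum_j h_jk), the event mass 1 - exp(-H) is split across causes in proportion to h_jk / H, and the CIF accumulates survival-weighted event mass bin by bin. A self-contained sketch of just that recursion, with toy hazards standing in for the basis expansion:

    import torch

    # Toy per-bin cause hazards h_jk: B=1 sample, J=2 causes, n_bins=3.
    h_jk = torch.tensor([[[0.10, 0.20, 0.30],
                          [0.05, 0.05, 0.10]]])
    h_k = h_jk.sum(dim=1)                    # total hazard per bin, (B, n_bins)
    p_comp = torch.exp(-h_k)                 # P(no event in bin | alive at entry)
    p_event = -torch.expm1(-h_k).unsqueeze(1) * (h_jk / h_k.unsqueeze(1))  # split by cause
    cum = torch.cumprod(p_comp, dim=1)       # survival after each bin
    s_prev = torch.cat([torch.ones(1, 1), cum[:, :-1]], dim=1)  # survival entering each bin
    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * p_event, dim=2)  # (B, J, n_bins)

    # The identity behind the diff's new survival output: cause CIFs and
    # survival telescope to exactly 1 at every bin edge.
    assert torch.allclose(cif_bins.sum(dim=1) + cum, torch.ones_like(cum))

This telescoping identity is why the new code can return survival = cum.index_select(...) directly rather than clamping 1 - cif.sum(dim=1): up to the eps clamps on h_jk and h_k, the two agree by construction.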
@@ -1269,20 +1343,46 @@ def instantiate_model_and_head(
     elif loss_type == "discrete_time_cif":
         bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
         out_dims = [dataset.n_disease + 1, len(bin_edges)]
-    elif loss_type == "lognormal_basis_hazard":
+    elif loss_type == "lognormal_basis_binned_hazard_cif":
         centers = cfg.get("lognormal_centers", None)
         if centers is None:
             centers = cfg.get("centers", None)
         if not isinstance(centers, list) or len(centers) == 0:
             raise ValueError(
-                "lognormal_basis_hazard requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
+                "lognormal_basis_binned_hazard_cif requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
             )
-        out_dims = [1 + dataset.n_disease * len(centers)]
+        r = len(centers)
+        desired_total = int(dataset.n_disease) * int(r)
+        legacy_total = 1 + desired_total
+
+        # Prefer the new shape (K,R) but keep compatibility with older checkpoints
+        # that used a single flattened dimension (1 + K*R).
+        out_dims = [int(dataset.n_disease), int(r)]
+        if checkpoint_path:
+            try:
+                ckpt = torch.load(checkpoint_path, map_location="cpu")
+                head_sd = ckpt.get("head_state_dict", {})
+                w = head_sd.get("net.2.weight", None)
+                if isinstance(w, torch.Tensor) and w.ndim == 2:
+                    out_features = int(w.shape[0])
+                    if out_features == legacy_total:
+                        out_dims = [legacy_total]
+                    elif out_features == desired_total:
+                        out_dims = [int(dataset.n_disease), int(r)]
+                    else:
+                        raise ValueError(
+                            f"Checkpoint head out_features={out_features} does not match expected {desired_total} (K*R) or {legacy_total} (1+K*R)"
+                        )
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to infer head output dims from checkpoint={checkpoint_path}: {e}"
+                )
         loss_params["centers"] = centers
         loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
         loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
         loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
         loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
         loss_params["alpha_floor"] = float(cfg.get("alpha_floor", 0.0))
     else:
         raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
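For reference, the configuration keys this branch reads from train_config.json, written as a Python dict; the defaults match the cfg.get(...) calls in the diff, and lognormal_centers (the example values here are invented) is the only key with no default:

    # Sketch of the expected train_config.json contents for this loss type.
    cfg = {
        "lognormal_centers": [0.0, 0.7, 1.4, 2.1],  # mu_r in log-time; required
        "bandwidth_min": 1e-3,
        "bandwidth_max": 10.0,
        "bandwidth_init": 0.7,
        "loss_eps": 1e-8,
        "alpha_floor": 0.0,
    }

Note that the checkpoint probe keys on head_sd["net.2.weight"], so the compatibility shim assumes the head's final Linear layer sits at index 2 of a Sequential.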
@@ -1399,18 +1499,20 @@ def predict_cifs_for_model(
         cif_full, survival = cifs_from_discrete_time_logits(
             # (B,K,H), (B,H)
             logits, bin_edges, eval_horizons, return_survival=True)
-    elif loss_type == "lognormal_basis_hazard":
+    elif loss_type == "lognormal_basis_binned_hazard_cif":
         centers = loss_params.get("centers", None)
         sigma = loss_params.get("sigma", None)
         if centers is None or sigma is None:
             raise ValueError(
-                "lognormal_basis_hazard requires loss_params['centers'] and loss_params['sigma']")
-        cif_full, survival = cifs_from_lognormal_basis_logits(
+                "lognormal_basis_binned_hazard_cif requires loss_params['centers'] and loss_params['sigma']")
+        cif_full, survival = cifs_from_lognormal_basis_binned_hazard_logits(
             logits,
             centers=centers,
             sigma=sigma,
-            taus=eval_horizons,
+            bin_edges=bin_edges,
+            taus=eval_horizons,
             eps=float(loss_params.get("loss_eps", 1e-8)),
             alpha_floor=float(loss_params.get("alpha_floor", 0.0)),
             return_survival=True,
         )
     else:
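Because the helper now derives survival from the same per-bin recursion as the CIFs, the returned tuple is internally consistent. A quick check along the lines below (variable names reuse the surrounding scope and are otherwise hypothetical) would confirm it:

    # Hypothetical post-hoc check: cause CIFs plus survival ~ 1 at every horizon.
    assert survival.shape == (logits.size(0), len(eval_horizons))
    assert torch.allclose(cif_full.sum(dim=1) + survival,
                          torch.ones_like(survival), atol=1e-5)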
@@ -1927,7 +2029,7 @@ def main() -> int:
     backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
     head.load_state_dict(ckpt["head_state_dict"], strict=True)
 
-    if loss_type == "lognormal_basis_hazard":
+    if loss_type == "lognormal_basis_binned_hazard_cif":
         crit_state = ckpt.get("criterion_state_dict", {})
         log_sigma = crit_state.get("log_sigma", None)
         if isinstance(log_sigma, torch.Tensor):
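The hunk is truncated at the isinstance check. Given that training stores a learnable log_sigma in criterion_state_dict, and that the docstring above describes sigma as "already clamped" with bandwidth_min/bandwidth_max available in loss_params, the continuation presumably recovers sigma along these lines; this is an assumption about off-screen code, not part of the diff:

    # ASSUMED continuation (the diff cuts off here); mirrors the config keys above.
    if isinstance(log_sigma, torch.Tensor):
        sigma = torch.clamp(
            log_sigma.exp(),                   # sigma = exp(log_sigma)
            min=loss_params["bandwidth_min"],  # assumed clamp, matching training
            max=loss_params["bandwidth_max"],
        ).detach()
        loss_params["sigma"] = sigma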