Refactor LogNormalBasisHazardLoss to LogNormalBasisBinnedHazardCIFNLLLoss and update related configurations

This commit is contained in:
2026-01-13 21:11:38 +08:00
parent 1df02d85d7
commit f16596ed58
3 changed files with 320 additions and 154 deletions

View File

@@ -1,4 +1,4 @@
-from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisHazardLoss, get_valid_pairs_and_dt
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisBinnedHazardCIFNLLLoss, get_valid_pairs_and_dt
 from model import DelphiFork, SapDelphi, SimpleHead
 from dataset import HealthDataset, health_collate_fn
 from tqdm import tqdm
@@ -22,7 +22,7 @@ class TrainConfig:
 # Model Parameters
 model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
 loss_type: Literal['exponential', 'discrete_time_cif',
-'lognormal_basis_hazard'] = 'exponential'
+'lognormal_basis_binned_hazard_cif'] = 'exponential'
 age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
 full_cov: bool = False
 n_embd: int = 120
@@ -34,7 +34,7 @@ class TrainConfig:
 default_factory=lambda: [0.0, 0.24, 0.72,
 1.61, 3.84, 10.0, 31.0, float('inf')]
 )
-# LogNormalBasisHazardLoss specific
+# LogNormal basis (shared by Route-3 binned hazard)
 lognormal_centers: Optional[Sequence[float]] = field(
 default_factory=list) # mu_r in log-time
 loss_eps: float = 1e-8
@@ -73,7 +73,8 @@ def parse_args() -> TrainConfig:
 parser.add_argument(
 "--loss_type",
 type=str,
-choices=['exponential', 'discrete_time_cif', 'lognormal_basis_hazard'],
+choices=['exponential', 'discrete_time_cif',
+'lognormal_basis_binned_hazard_cif'],
 default='exponential',
 help="Type of loss function to use.")
 parser.add_argument("--age_encoder", type=str, choices=[
@@ -93,17 +94,17 @@ def parse_args() -> TrainConfig:
 type=float,
 nargs='*',
 default=None,
-help="LogNormalBasisHazardLoss centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
+help="LogNormal basis centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
 parser.add_argument("--loss_eps", type=float, default=1e-8,
-help="Epsilon for LogNormalBasisHazardLoss log clamp.")
+help="Epsilon for log clamps in lognormal-basis losses.")
 parser.add_argument("--bandwidth_init", type=float, default=0.7,
-help="Initial sigma for LogNormalBasisHazardLoss.")
+help="Initial sigma for lognormal-basis.")
 parser.add_argument("--bandwidth_min", type=float, default=1e-3,
-help="Minimum sigma clamp for LogNormalBasisHazardLoss.")
+help="Minimum sigma clamp for lognormal-basis.")
 parser.add_argument("--bandwidth_max", type=float, default=10.0,
-help="Maximum sigma clamp for LogNormalBasisHazardLoss.")
+help="Maximum sigma clamp for lognormal-basis.")
 parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
-help="Sigma regularization strength for LogNormalBasisHazardLoss.")
+help="Sigma regularization strength for lognormal-basis.")
 parser.add_argument("--sigma_reg_target", type=float, default=None,
 help="Optional sigma target for regularization (otherwise uses bandwidth_init).")
parser.add_argument("--rank", type=int, default=16,
@@ -261,12 +262,12 @@ class Trainer:
 ).to(self.device)
 # logits shape (M, K+1, n_bins+1)
 out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
-elif cfg.loss_type == "lognormal_basis_hazard":
+elif cfg.loss_type == "lognormal_basis_binned_hazard_cif":
 r = len(cfg.lognormal_centers)
 if r <= 0:
 raise ValueError(
-"lognormal_centers must be non-empty for lognormal_basis_hazard")
-self.criterion = LogNormalBasisHazardLoss(
+"lognormal_centers must be non-empty for lognormal_basis_binned_hazard_cif")
+self.criterion = LogNormalBasisBinnedHazardCIFNLLLoss(
 bin_edges=cfg.bin_edges,
 centers=cfg.lognormal_centers,
 eps=cfg.loss_eps,
@@ -275,9 +276,10 @@ class Trainer:
 bandwidth_max=cfg.bandwidth_max,
 lambda_sigma_reg=cfg.lambda_sigma_reg,
 sigma_reg_target=cfg.sigma_reg_target,
+lambda_reg=cfg.lambda_reg,
 ).to(self.device)
-# logits shape (M, 1 + J*R)
-out_dims = [1 + dataset.n_disease * r]
+# Head emits (M, J, R) for Route-3.
+out_dims = [dataset.n_disease, r]
 else:
 raise ValueError(f"Unsupported loss type: {cfg.loss_type}")