Refactor LogNormalBasisHazardLoss to LogNormalBasisBinnedHazardCIFNLLLoss and update related configurations
This commit is contained in:
32
train.py
32
train.py
@@ -1,4 +1,4 @@
|
||||
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisHazardLoss, get_valid_pairs_and_dt
|
||||
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisBinnedHazardCIFNLLLoss, get_valid_pairs_and_dt
|
||||
from model import DelphiFork, SapDelphi, SimpleHead
|
||||
from dataset import HealthDataset, health_collate_fn
|
||||
from tqdm import tqdm
|
||||
@@ -22,7 +22,7 @@ class TrainConfig:
|
||||
# Model Parameters
|
||||
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
|
||||
loss_type: Literal['exponential', 'discrete_time_cif',
|
||||
'lognormal_basis_hazard'] = 'exponential'
|
||||
'lognormal_basis_binned_hazard_cif'] = 'exponential'
|
||||
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
|
||||
full_cov: bool = False
|
||||
n_embd: int = 120
|
||||
@@ -34,7 +34,7 @@ class TrainConfig:
|
||||
default_factory=lambda: [0.0, 0.24, 0.72,
|
||||
1.61, 3.84, 10.0, 31.0, float('inf')]
|
||||
)
|
||||
# LogNormalBasisHazardLoss specific
|
||||
# LogNormal basis (shared by Route-3 binned hazard)
|
||||
lognormal_centers: Optional[Sequence[float]] = field(
|
||||
default_factory=list) # mu_r in log-time
|
||||
loss_eps: float = 1e-8
|
||||
@@ -73,7 +73,8 @@ def parse_args() -> TrainConfig:
|
||||
parser.add_argument(
|
||||
"--loss_type",
|
||||
type=str,
|
||||
choices=['exponential', 'discrete_time_cif', 'lognormal_basis_hazard'],
|
||||
choices=['exponential', 'discrete_time_cif',
|
||||
'lognormal_basis_binned_hazard_cif'],
|
||||
default='exponential',
|
||||
help="Type of loss function to use.")
|
||||
parser.add_argument("--age_encoder", type=str, choices=[
|
||||
@@ -93,17 +94,17 @@ def parse_args() -> TrainConfig:
|
||||
type=float,
|
||||
nargs='*',
|
||||
default=None,
|
||||
help="LogNormalBasisHazardLoss centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
|
||||
help="LogNormal basis centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
|
||||
parser.add_argument("--loss_eps", type=float, default=1e-8,
|
||||
help="Epsilon for LogNormalBasisHazardLoss log clamp.")
|
||||
help="Epsilon for log clamps in lognormal-basis losses.")
|
||||
parser.add_argument("--bandwidth_init", type=float, default=0.7,
|
||||
help="Initial sigma for LogNormalBasisHazardLoss.")
|
||||
help="Initial sigma for lognormal-basis.")
|
||||
parser.add_argument("--bandwidth_min", type=float, default=1e-3,
|
||||
help="Minimum sigma clamp for LogNormalBasisHazardLoss.")
|
||||
help="Minimum sigma clamp for lognormal-basis.")
|
||||
parser.add_argument("--bandwidth_max", type=float, default=10.0,
|
||||
help="Maximum sigma clamp for LogNormalBasisHazardLoss.")
|
||||
help="Maximum sigma clamp for lognormal-basis.")
|
||||
parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
|
||||
help="Sigma regularization strength for LogNormalBasisHazardLoss.")
|
||||
help="Sigma regularization strength for lognormal-basis.")
|
||||
parser.add_argument("--sigma_reg_target", type=float, default=None,
|
||||
help="Optional sigma target for regularization (otherwise uses bandwidth_init).")
|
||||
parser.add_argument("--rank", type=int, default=16,
|
||||
@@ -261,12 +262,12 @@ class Trainer:
|
||||
).to(self.device)
|
||||
# logits shape (M, K+1, n_bins+1)
|
||||
out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
|
||||
elif cfg.loss_type == "lognormal_basis_hazard":
|
||||
elif cfg.loss_type == "lognormal_basis_binned_hazard_cif":
|
||||
r = len(cfg.lognormal_centers)
|
||||
if r <= 0:
|
||||
raise ValueError(
|
||||
"lognormal_centers must be non-empty for lognormal_basis_hazard")
|
||||
self.criterion = LogNormalBasisHazardLoss(
|
||||
"lognormal_centers must be non-empty for lognormal_basis_binned_hazard_cif")
|
||||
self.criterion = LogNormalBasisBinnedHazardCIFNLLLoss(
|
||||
bin_edges=cfg.bin_edges,
|
||||
centers=cfg.lognormal_centers,
|
||||
eps=cfg.loss_eps,
|
||||
@@ -275,9 +276,10 @@ class Trainer:
|
||||
bandwidth_max=cfg.bandwidth_max,
|
||||
lambda_sigma_reg=cfg.lambda_sigma_reg,
|
||||
sigma_reg_target=cfg.sigma_reg_target,
|
||||
lambda_reg=cfg.lambda_reg,
|
||||
).to(self.device)
|
||||
# logits shape (M, 1 + J*R)
|
||||
out_dims = [1 + dataset.n_disease * r]
|
||||
# Head emits (M, J, R) for Route-3.
|
||||
out_dims = [dataset.n_disease, r]
|
||||
else:
|
||||
raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user