Refactor loss functions and model architecture: replace PiecewiseExponentialLoss and WeibullNLLLoss with DiscreteTimeCIFNLLLoss, move logit projection into a separate SimpleHead owned by the Trainer, and update argument parsing for the new loss type.

2026-01-09 18:31:38 +08:00
parent 880fd53a4b
commit 209dde2299
3 changed files with 172 additions and 349 deletions
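
Note: model.py is not part of this diff, so the new SimpleHead is not shown here. Below is a minimal sketch of what such a head might look like, assuming it is a single linear projection whose output is reshaped to the loss-specific out_dims; the actual implementation may differ.

import math
from typing import Sequence

import torch
import torch.nn as nn


class SimpleHead(nn.Module):
    """Hypothetical sketch: maps context vectors (M, n_embd) to logits (M, *out_dims)."""

    def __init__(self, n_embd: int, out_dims: Sequence[int]):
        super().__init__()
        self.out_dims = list(out_dims)
        # One linear projection producing prod(out_dims) values per context vector.
        self.proj = nn.Linear(n_embd, math.prod(self.out_dims))

    def forward(self, c: torch.Tensor) -> torch.Tensor:
        # c: (M, n_embd) gathered context vectors -> (M, *out_dims) logits.
        return self.proj(c).view(c.size(0), *self.out_dims)

With the defaults in this commit, out_dims is [n_disease] for the exponential loss and [n_disease + 1, len(bin_edges)] for the discrete-time CIF loss, matching the "(M, K+1, n_bins+1)" comment in the Trainer changes below.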

train.py

@@ -1,5 +1,5 @@
from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
from model import DelphiFork, SapDelphi
from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
from model import DelphiFork, SapDelphi, SimpleHead
from dataset import HealthDataset, health_collate_fn
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
@@ -22,8 +22,7 @@ from typing import Literal, Sequence
class TrainConfig:
# Model Parameters
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
loss_type: Literal['exponential', 'weibull',
'piecewise_exponential'] = 'weibull'
loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
full_cov: bool = False
n_embd: int = 120
@@ -32,7 +31,8 @@ class TrainConfig:
pdrop: float = 0.1
lambda_reg: float = 1e-4
bin_edges: Sequence[float] = field(
default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
default_factory=lambda: [0.0, 0.24, 0.72,
1.61, 3.84, 10.0, 31.0, float('inf')]
)
rank: int = 16
# SapDelphi specific
@@ -61,8 +61,12 @@ def parse_args() -> TrainConfig:
parser = argparse.ArgumentParser(description="Train Delphi Model")
parser.add_argument("--model_type", type=str, choices=[
'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
parser.add_argument("--loss_type", type=str, choices=[
'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.")
parser.add_argument(
"--loss_type",
type=str,
choices=['exponential', 'discrete_time_cif'],
default='exponential',
help="Type of loss function to use.")
parser.add_argument("--age_encoder", type=str, choices=[
'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
parser.add_argument("--n_embd", type=int, default=120,
@@ -193,18 +197,14 @@ class Trainer:
self.criterion = ExponentialNLLLoss(
lambda_reg=cfg.lambda_reg,
).to(self.device)
n_dim = 1
elif cfg.loss_type == "piecewise_exponential":
self.criterion = PiecewiseExponentialLoss(
out_dims = [dataset.n_disease]
elif cfg.loss_type == "discrete_time_cif":
self.criterion = DiscreteTimeCIFNLLLoss(
bin_edges=cfg.bin_edges,
lambda_reg=cfg.lambda_reg,
).to(self.device)
n_dim = len(cfg.bin_edges) - 1
elif cfg.loss_type == "weibull":
self.criterion = WeibullNLLLoss(
lambda_reg=cfg.lambda_reg,
).to(self.device)
n_dim = 2
# logits shape (M, K+1, n_bins+1)
out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
else:
raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
@@ -217,8 +217,6 @@ class Trainer:
n_layer=cfg.n_layer,
pdrop=cfg.pdrop,
age_encoder_type=cfg.age_encoder,
n_dim=n_dim,
rank=cfg.rank,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
@@ -232,8 +230,6 @@ class Trainer:
n_layer=cfg.n_layer,
pdrop=cfg.pdrop,
age_encoder_type=cfg.age_encoder,
n_dim=n_dim,
rank=cfg.rank,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
@@ -242,10 +238,25 @@ class Trainer:
).to(self.device)
else:
raise ValueError(f"Unsupported model type: {cfg.model_type}")
# Prediction head maps context vectors -> logits with the shape required by the loss.
self.head = SimpleHead(
n_embd=cfg.n_embd,
out_dims=out_dims,
).to(self.device)
print(f"Model initialized: {cfg.model_type}")
print(f"Number of trainable parameters: {get_num_params(self.model)}")
print(
f"Number of trainable parameters (backbone): {get_num_params(self.model)}")
print(
f"Number of trainable parameters (head): {get_num_params(self.head)}")
self._optim_params = (
list(self.model.parameters())
+ list(self.head.parameters())
)
self.optimizer = AdamW(
self.model.parameters(),
self._optim_params,
lr=cfg.max_lr,
weight_decay=cfg.weight_decay,
betas=(0.9, 0.99),
@@ -293,23 +304,11 @@ class Trainer:
best_val_score = float('inf')
patience_counter = 0
for epoch in range(1, self.cfg.max_epochs + 1):
model_for_logging = self.model.module if hasattr(
self.model, "module") else self.model
delta_scale = None
theta_proj = getattr(model_for_logging, "theta_proj", None)
if theta_proj is not None and hasattr(theta_proj, "delta_scale"):
try:
delta_scale = float(
theta_proj.delta_scale.detach().cpu().item())
except Exception:
delta_scale = None
self.model.train()
self.head.train()
total_train_pairs = 0
total_train_nll = 0.0
total_train_reg = 0.0
total_train_log_scale_sq = 0.0
total_train_log_shape_sq = 0.0
pbar = tqdm(self.train_loader,
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
batch_count = 0
@@ -334,25 +333,17 @@ class Trainer:
self.optimizer.zero_grad()
lr = self.compute_lr(self.global_step)
self.optimizer.param_groups[0]['lr'] = lr
logits = self.model(
h = self.model(
event_seq,
time_seq,
sexes,
cont_feats,
cate_feats,
b_prev=b_prev,
t_prev=t_prev,
)
if isinstance(self.criterion, WeibullNLLLoss):
eps = float(self.criterion.eps)
shapes = torch.nn.functional.softplus(logits[..., 0]) + eps
scales = torch.nn.functional.softplus(logits[..., 1]) + eps
log_scale_sq = (torch.log(scales + eps) ** 2).mean()
log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
else:
log_scale_sq = None
log_shape_sq = None
# Context vectors for selected previous events
c = h[b_prev, t_prev] # (M, D)
logits = self.head(c)
target_event = event_seq[b_next, t_next] - 2
nll_vec, reg = self.criterion(
@@ -367,10 +358,6 @@ class Trainer:
total_train_pairs += num_pairs
total_train_nll += nll_vec.sum().item()
total_train_reg += reg.item() * num_pairs
if log_scale_sq is not None:
total_train_log_scale_sq += log_scale_sq.item() * num_pairs
if log_shape_sq is not None:
total_train_log_shape_sq += log_shape_sq.item() * num_pairs
avg_train_nll = total_train_nll / total_train_pairs
avg_train_reg = total_train_reg / total_train_pairs
pbar.set_postfix({
@@ -380,8 +367,7 @@ class Trainer:
})
loss.backward()
if self.cfg.grad_clip > 0:
clip_grad_norm_(self.model.parameters(),
self.cfg.grad_clip)
clip_grad_norm_(self._optim_params, self.cfg.grad_clip)
self.optimizer.step()
self.global_step += 1
@@ -391,23 +377,12 @@ class Trainer:
train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
train_log_scale_sq = (
total_train_log_scale_sq / total_train_pairs
if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
else None
)
train_log_shape_sq = (
total_train_log_shape_sq / total_train_pairs
if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
else None
)
self.model.eval()
self.head.eval()
total_val_pairs = 0
total_val_nll = 0.0
total_val_reg = 0.0
total_val_log_scale_sq = 0.0
total_val_log_shape_sq = 0.0
with torch.no_grad():
val_pbar = tqdm(self.val_loader, desc="Validation")
for batch in val_pbar:
@@ -428,27 +403,16 @@ class Trainer:
continue
dt, b_prev, t_prev, b_next, t_next = res
num_pairs = dt.size(0)
logits = self.model(
h = self.model(
event_seq,
time_seq,
sexes,
cont_feats,
cate_feats,
b_prev=b_prev,
t_prev=t_prev
)
if isinstance(self.criterion, WeibullNLLLoss):
eps = float(self.criterion.eps)
shapes = torch.nn.functional.softplus(
logits[..., 0]) + eps
scales = torch.nn.functional.softplus(
logits[..., 1]) + eps
log_scale_sq = (torch.log(scales + eps) ** 2).mean()
log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
else:
log_scale_sq = None
log_shape_sq = None
c = h[b_prev, t_prev]
logits = self.head(c)
target_events = event_seq[b_next, t_next] - 2
nll, reg = self.criterion(
@@ -460,10 +424,6 @@ class Trainer:
batch_nll_sum = nll.sum().item()
total_val_nll += batch_nll_sum
total_val_reg += reg.item() * num_pairs
if log_scale_sq is not None:
total_val_log_scale_sq += log_scale_sq.item() * num_pairs
if log_shape_sq is not None:
total_val_log_shape_sq += log_shape_sq.item() * num_pairs
total_val_pairs += num_pairs
current_val_avg_nll = total_val_nll / \
@@ -478,16 +438,6 @@ class Trainer:
val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
val_log_scale_sq = (
total_val_log_scale_sq / total_val_pairs
if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
else None
)
val_log_shape_sq = (
total_val_log_shape_sq / total_val_pairs
if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
else None
)
history.append({
"epoch": epoch,
@@ -495,11 +445,6 @@ class Trainer:
"train_reg": train_reg,
"val_nll": val_nll,
"val_reg": val_reg,
"delta_scale": delta_scale,
"train_log_scale_sq": train_log_scale_sq,
"train_log_shape_sq": train_log_shape_sq,
"val_log_scale_sq": val_log_scale_sq,
"val_log_shape_sq": val_log_shape_sq,
})
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
@@ -507,18 +452,6 @@ class Trainer:
tqdm.write(f" Train Reg: {train_reg:.4f}")
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
tqdm.write(f" Val Reg: {val_reg:.4f}")
if delta_scale is not None:
tqdm.write(f" Delta scale: {delta_scale:.6g}")
if train_log_scale_sq is not None and train_log_shape_sq is not None:
tqdm.write(
f" Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}")
tqdm.write(
f" Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}")
if val_log_scale_sq is not None and val_log_shape_sq is not None:
tqdm.write(
f" Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}")
tqdm.write(
f" Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}")
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
json.dump(history, f, indent=4)
@@ -533,6 +466,7 @@ class Trainer:
"epoch": epoch,
"global_step": self.global_step,
"model_state_dict": self.model.state_dict(),
"head_state_dict": self.head.state_dict(),
"criterion_state_dict": self.criterion.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
}, self.best_path)
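
The losses module is likewise not shown in this commit. For orientation, here is a minimal, hypothetical sketch of a discrete-time CIF-style NLL, assuming the head's logits are read as a joint categorical distribution over (event type, time bin). The real DiscreteTimeCIFNLLLoss may use a different formulation (for example, cause-specific discrete hazards), and the class name, argument names, and regularizer below are assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F


class JointCategoricalCIFNLL(nn.Module):
    """Hypothetical sketch, not the actual DiscreteTimeCIFNLLLoss."""

    def __init__(self, bin_edges, lambda_reg: float = 0.0):
        super().__init__()
        # Interior edges are enough to assign each waiting time dt to a bin index.
        edges = torch.as_tensor(bin_edges, dtype=torch.float32)
        self.register_buffer("interior_edges", edges[1:-1])
        self.lambda_reg = lambda_reg

    def forward(self, logits: torch.Tensor, target_event: torch.Tensor, dt: torch.Tensor):
        # logits: (M, n_events, n_time_cols); target_event: (M,) long; dt: (M,) float
        M, n_events, n_time_cols = logits.shape
        target_bin = torch.bucketize(dt, self.interior_edges)  # (M,) time-bin index per pair
        flat = logits.reshape(M, n_events * n_time_cols)       # joint (event, bin) cells
        target = target_event * n_time_cols + target_bin       # index of the observed cell
        nll_vec = F.cross_entropy(flat, target, reduction="none")
        reg = self.lambda_reg * logits.pow(2).mean()            # simple L2 penalty on logits
        return nll_vec, reg

Under this reading, the criterion returns a per-pair NLL vector plus a scalar regularizer, which matches how the Trainer above sums nll_vec and scales reg by the number of valid pairs.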