Refactor loss functions and model architecture: replace PiecewiseExponentialLoss and WeibullNLLLoss with DiscreteTimeCIFNLLLoss, route predictions through a separate SimpleHead in the Trainer, and update argument parsing for the new loss type.
train.py
@@ -1,5 +1,5 @@
-from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
-from model import DelphiFork, SapDelphi
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
+from model import DelphiFork, SapDelphi, SimpleHead
 from dataset import HealthDataset, health_collate_fn
 from tqdm import tqdm
 from torch.nn.utils import clip_grad_norm_
@@ -22,8 +22,7 @@ from typing import Literal, Sequence
 class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-    loss_type: Literal['exponential', 'weibull',
-                       'piecewise_exponential'] = 'weibull'
+    loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -32,7 +31,8 @@ class TrainConfig:
     pdrop: float = 0.1
     lambda_reg: float = 1e-4
     bin_edges: Sequence[float] = field(
-        default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
+        default_factory=lambda: [0.0, 0.24, 0.72,
+                                 1.61, 3.84, 10.0, 31.0, float('inf')]
     )
     rank: int = 16
     # SapDelphi specific
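Note: the new default appends float('inf') as a final edge, so the eight edges define seven time bins, the last of which absorbs arbitrarily long waiting times, and len(cfg.bin_edges) matches the n_bins+1 logit columns referenced further down. A small, hypothetical illustration of mapping waiting times to these bins (the real discretisation lives inside DiscreteTimeCIFNLLLoss and may differ):

import torch

bin_edges = torch.tensor([0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')])
n_bins = len(bin_edges) - 1                     # 7 intervals; the last is [31, inf)
dt = torch.tensor([0.1, 2.0, 50.0])             # example waiting times
bin_idx = torch.bucketize(dt, bin_edges[1:-1])  # tensor([0, 3, 6])
print(n_bins, bin_idx)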
@@ -61,8 +61,12 @@ def parse_args() -> TrainConfig:
     parser = argparse.ArgumentParser(description="Train Delphi Model")
     parser.add_argument("--model_type", type=str, choices=[
                         'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
-    parser.add_argument("--loss_type", type=str, choices=[
-                        'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.")
+    parser.add_argument(
+        "--loss_type",
+        type=str,
+        choices=['exponential', 'discrete_time_cif'],
+        default='exponential',
+        help="Type of loss function to use.")
     parser.add_argument("--age_encoder", type=str, choices=[
                         'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
     parser.add_argument("--n_embd", type=int, default=120,
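With the updated parser, the new loss is selected via --loss_type discrete_time_cif (on the command line, something like: python train.py --loss_type discrete_time_cif). A self-contained snippet reproducing just that flag to show the accepted choices and default:

import argparse

parser = argparse.ArgumentParser(description="Train Delphi Model")
parser.add_argument(
    "--loss_type",
    type=str,
    choices=['exponential', 'discrete_time_cif'],
    default='exponential',
    help="Type of loss function to use.")
args = parser.parse_args(['--loss_type', 'discrete_time_cif'])
print(args.loss_type)  # discrete_time_cif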
@@ -193,18 +197,14 @@ class Trainer:
             self.criterion = ExponentialNLLLoss(
                 lambda_reg=cfg.lambda_reg,
             ).to(self.device)
-            n_dim = 1
-        elif cfg.loss_type == "piecewise_exponential":
-            self.criterion = PiecewiseExponentialLoss(
+            out_dims = [dataset.n_disease]
+        elif cfg.loss_type == "discrete_time_cif":
+            self.criterion = DiscreteTimeCIFNLLLoss(
                 bin_edges=cfg.bin_edges,
                 lambda_reg=cfg.lambda_reg,
             ).to(self.device)
-            n_dim = len(cfg.bin_edges) - 1
-        elif cfg.loss_type == "weibull":
-            self.criterion = WeibullNLLLoss(
-                lambda_reg=cfg.lambda_reg,
-            ).to(self.device)
-            n_dim = 2
+            # logits shape (M, K+1, n_bins+1)
+            out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
         else:
             raise ValueError(f"Unsupported loss type: {cfg.loss_type}")

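For orientation, a minimal sketch of what a discrete-time negative log-likelihood over a (K+1) x (n_bins+1) logit grid could look like. This is an illustration only, not the repo's losses.DiscreteTimeCIFNLLLoss: the constructor arguments and the (nll, reg) return pair mirror how the Trainer uses the criterion, but the internals (in particular how censoring, the CIF parameterisation, and the extra logit column are handled) are assumptions.

import torch
import torch.nn as nn


class DiscreteTimeCIFNLLSketch(nn.Module):
    """Illustrative stand-in: categorical NLL over the (event type, time bin) grid."""

    def __init__(self, bin_edges, lambda_reg=0.0):
        super().__init__()
        self.register_buffer(
            "bin_edges", torch.as_tensor(list(bin_edges), dtype=torch.float32))
        self.lambda_reg = lambda_reg

    def forward(self, logits, target_event, dt):
        # logits: (M, K+1, n_bins+1); target_event: (M,) long; dt: (M,) float
        m, _, n_cols = logits.shape
        bin_idx = torch.bucketize(dt, self.bin_edges[1:-1])   # 0 .. n_bins-1
        log_p = torch.log_softmax(logits.reshape(m, -1), dim=-1)
        flat = target_event * n_cols + bin_idx                # observed (type, bin) cell
        nll = -log_p.gather(1, flat.unsqueeze(1)).squeeze(1)  # per-pair NLL, shape (M,)
        reg = self.lambda_reg * logits.pow(2).mean()          # simple L2 logit penalty
        return nll, reg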
@@ -217,8 +217,6 @@ class Trainer:
                 n_layer=cfg.n_layer,
                 pdrop=cfg.pdrop,
                 age_encoder_type=cfg.age_encoder,
-                n_dim=n_dim,
-                rank=cfg.rank,
                 n_cont=dataset.n_cont,
                 n_cate=dataset.n_cate,
                 cate_dims=dataset.cate_dims,
@@ -232,8 +230,6 @@ class Trainer:
                 n_layer=cfg.n_layer,
                 pdrop=cfg.pdrop,
                 age_encoder_type=cfg.age_encoder,
-                n_dim=n_dim,
-                rank=cfg.rank,
                 n_cont=dataset.n_cont,
                 n_cate=dataset.n_cate,
                 cate_dims=dataset.cate_dims,
@@ -242,10 +238,25 @@ class Trainer:
             ).to(self.device)
         else:
             raise ValueError(f"Unsupported model type: {cfg.model_type}")
+
+        # Prediction head maps context vectors -> logits with the shape required by the loss.
+        self.head = SimpleHead(
+            n_embd=cfg.n_embd,
+            out_dims=out_dims,
+        ).to(self.device)
+
         print(f"Model initialized: {cfg.model_type}")
-        print(f"Number of trainable parameters: {get_num_params(self.model)}")
+        print(
+            f"Number of trainable parameters (backbone): {get_num_params(self.model)}")
+        print(
+            f"Number of trainable parameters (head): {get_num_params(self.head)}")
+
+        self._optim_params = (
+            list(self.model.parameters())
+            + list(self.head.parameters())
+        )
         self.optimizer = AdamW(
-            self.model.parameters(),
+            self._optim_params,
             lr=cfg.max_lr,
             weight_decay=cfg.weight_decay,
             betas=(0.9, 0.99),
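The Trainer now only gives the head its input width (n_embd) and the logit shape the loss expects (out_dims). A plausible minimal form of such a head, shown purely as a sketch; the repo's model.SimpleHead may add extra layers or normalisation:

import math
import torch.nn as nn


class SimpleHeadSketch(nn.Module):
    """Sketch: linear projection from backbone width to prod(out_dims), reshaped."""

    def __init__(self, n_embd, out_dims):
        super().__init__()
        self.out_dims = tuple(out_dims)
        self.proj = nn.Linear(n_embd, math.prod(self.out_dims))

    def forward(self, c):
        # c: (M, n_embd) -> (M, *out_dims), e.g. (M, K+1, n_bins+1) for the CIF loss
        return self.proj(c).view(c.shape[0], *self.out_dims)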
@@ -293,23 +304,11 @@ class Trainer:
         best_val_score = float('inf')
         patience_counter = 0
         for epoch in range(1, self.cfg.max_epochs + 1):
-            model_for_logging = self.model.module if hasattr(
-                self.model, "module") else self.model
-            delta_scale = None
-            theta_proj = getattr(model_for_logging, "theta_proj", None)
-            if theta_proj is not None and hasattr(theta_proj, "delta_scale"):
-                try:
-                    delta_scale = float(
-                        theta_proj.delta_scale.detach().cpu().item())
-                except Exception:
-                    delta_scale = None
-
             self.model.train()
+            self.head.train()
             total_train_pairs = 0
             total_train_nll = 0.0
             total_train_reg = 0.0
-            total_train_log_scale_sq = 0.0
-            total_train_log_shape_sq = 0.0
             pbar = tqdm(self.train_loader,
                         desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
             batch_count = 0
@@ -334,25 +333,17 @@ class Trainer:
                 self.optimizer.zero_grad()
                 lr = self.compute_lr(self.global_step)
                 self.optimizer.param_groups[0]['lr'] = lr
-                logits = self.model(
+                h = self.model(
                     event_seq,
                     time_seq,
                     sexes,
                     cont_feats,
                     cate_feats,
                     b_prev=b_prev,
                     t_prev=t_prev,
                 )
-
-                if isinstance(self.criterion, WeibullNLLLoss):
-                    eps = float(self.criterion.eps)
-                    shapes = torch.nn.functional.softplus(logits[..., 0]) + eps
-                    scales = torch.nn.functional.softplus(logits[..., 1]) + eps
-                    log_scale_sq = (torch.log(scales + eps) ** 2).mean()
-                    log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
-                else:
-                    log_scale_sq = None
-                    log_shape_sq = None
-
+                # Context vectors for selected previous events
+                c = h[b_prev, t_prev]  # (M, D)
+                logits = self.head(c)
                 target_event = event_seq[b_next, t_next] - 2
                 nll_vec, reg = self.criterion(
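The backbone now returns the full sequence of context vectors, and the Trainer picks one per training pair with paired index tensors before applying the head. A tiny, self-contained illustration of that indexing (shapes are arbitrary):

import torch

h = torch.randn(4, 10, 120)           # (B, T, D) backbone outputs
b_prev = torch.tensor([0, 0, 2, 3])   # sequence index of each pair's previous event
t_prev = torch.tensor([1, 5, 0, 7])   # position of each pair's previous event
c = h[b_prev, t_prev]                 # (M, D): one context vector per pair
print(c.shape)                        # torch.Size([4, 120])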
@@ -367,10 +358,6 @@ class Trainer:
                 total_train_pairs += num_pairs
                 total_train_nll += nll_vec.sum().item()
                 total_train_reg += reg.item() * num_pairs
-                if log_scale_sq is not None:
-                    total_train_log_scale_sq += log_scale_sq.item() * num_pairs
-                if log_shape_sq is not None:
-                    total_train_log_shape_sq += log_shape_sq.item() * num_pairs
                 avg_train_nll = total_train_nll / total_train_pairs
                 avg_train_reg = total_train_reg / total_train_pairs
                 pbar.set_postfix({
@@ -380,8 +367,7 @@ class Trainer:
                 })
                 loss.backward()
                 if self.cfg.grad_clip > 0:
-                    clip_grad_norm_(self.model.parameters(),
-                                    self.cfg.grad_clip)
+                    clip_grad_norm_(self._optim_params, self.cfg.grad_clip)
                 self.optimizer.step()
                 self.global_step += 1

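Gradient clipping now runs over the combined backbone-plus-head parameter list, so a single global norm constraint covers both modules. A small stand-alone example of the same pattern with dummy modules (sizes are arbitrary):

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

backbone, head = nn.Linear(8, 8), nn.Linear(8, 3)
params = list(backbone.parameters()) + list(head.parameters())
head(backbone(torch.randn(4, 8))).sum().backward()
total_norm = clip_grad_norm_(params, max_norm=1.0)  # one global norm for both modules
print(float(total_norm))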
@@ -391,23 +377,12 @@ class Trainer:

             train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
             train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
-            train_log_scale_sq = (
-                total_train_log_scale_sq / total_train_pairs
-                if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )
-            train_log_shape_sq = (
-                total_train_log_shape_sq / total_train_pairs
-                if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )

             self.model.eval()
+            self.head.eval()
             total_val_pairs = 0
             total_val_nll = 0.0
             total_val_reg = 0.0
-            total_val_log_scale_sq = 0.0
-            total_val_log_shape_sq = 0.0
             with torch.no_grad():
                 val_pbar = tqdm(self.val_loader, desc="Validation")
                 for batch in val_pbar:
@@ -428,27 +403,16 @@ class Trainer:
                         continue
                     dt, b_prev, t_prev, b_next, t_next = res
                     num_pairs = dt.size(0)
-                    logits = self.model(
+                    h = self.model(
                         event_seq,
                         time_seq,
                         sexes,
                         cont_feats,
                         cate_feats,
                         b_prev=b_prev,
                         t_prev=t_prev
                     )
-
-                    if isinstance(self.criterion, WeibullNLLLoss):
-                        eps = float(self.criterion.eps)
-                        shapes = torch.nn.functional.softplus(
-                            logits[..., 0]) + eps
-                        scales = torch.nn.functional.softplus(
-                            logits[..., 1]) + eps
-                        log_scale_sq = (torch.log(scales + eps) ** 2).mean()
-                        log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
-                    else:
-                        log_scale_sq = None
-                        log_shape_sq = None
-
+                    c = h[b_prev, t_prev]
+                    logits = self.head(c)
                     target_events = event_seq[b_next, t_next] - 2
                     nll, reg = self.criterion(
@@ -460,10 +424,6 @@ class Trainer:
                     batch_nll_sum = nll.sum().item()
                     total_val_nll += batch_nll_sum
                     total_val_reg += reg.item() * num_pairs
-                    if log_scale_sq is not None:
-                        total_val_log_scale_sq += log_scale_sq.item() * num_pairs
-                    if log_shape_sq is not None:
-                        total_val_log_shape_sq += log_shape_sq.item() * num_pairs
                     total_val_pairs += num_pairs

                     current_val_avg_nll = total_val_nll / \
@@ -478,16 +438,6 @@ class Trainer:

             val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
             val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
-            val_log_scale_sq = (
-                total_val_log_scale_sq / total_val_pairs
-                if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )
-            val_log_shape_sq = (
-                total_val_log_shape_sq / total_val_pairs
-                if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )

             history.append({
                 "epoch": epoch,
@@ -495,11 +445,6 @@ class Trainer:
                 "train_reg": train_reg,
                 "val_nll": val_nll,
                 "val_reg": val_reg,
-                "delta_scale": delta_scale,
-                "train_log_scale_sq": train_log_scale_sq,
-                "train_log_shape_sq": train_log_shape_sq,
-                "val_log_scale_sq": val_log_scale_sq,
-                "val_log_shape_sq": val_log_shape_sq,
             })

             tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
@@ -507,18 +452,6 @@ class Trainer:
             tqdm.write(f" Train Reg: {train_reg:.4f}")
             tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
             tqdm.write(f" Val Reg: {val_reg:.4f}")
-            if delta_scale is not None:
-                tqdm.write(f" Delta scale: {delta_scale:.6g}")
-            if train_log_scale_sq is not None and train_log_shape_sq is not None:
-                tqdm.write(
-                    f" Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}")
-                tqdm.write(
-                    f" Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}")
-            if val_log_scale_sq is not None and val_log_shape_sq is not None:
-                tqdm.write(
-                    f" Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}")
-                tqdm.write(
-                    f" Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}")

             with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
                 json.dump(history, f, indent=4)
@@ -533,6 +466,7 @@ class Trainer:
                 "epoch": epoch,
                 "global_step": self.global_step,
                 "model_state_dict": self.model.state_dict(),
+                "head_state_dict": self.head.state_dict(),
                 "criterion_state_dict": self.criterion.state_dict(),
                 "optimizer_state_dict": self.optimizer.state_dict(),
             }, self.best_path)
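Checkpoints now also carry the head's weights under head_state_dict. A round-trip sketch with dummy modules, where only the key names mirror the checkpoint written above (the criterion entry is a stand-in; the real criterion is an nn.Module, and the path is a placeholder for self.best_path):

import torch
import torch.nn as nn

model, head = nn.Linear(8, 8), nn.Linear(8, 3)
optimizer = torch.optim.AdamW(
    list(model.parameters()) + list(head.parameters()), lr=1e-3)
torch.save({
    "epoch": 1,
    "global_step": 100,
    "model_state_dict": model.state_dict(),
    "head_state_dict": head.state_dict(),
    "criterion_state_dict": {},
    "optimizer_state_dict": optimizer.state_dict(),
}, "best_sketch.pt")

ckpt = torch.load("best_sketch.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
head.load_state_dict(ckpt["head_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])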