import argparse
import json
import math
import os
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Literal, Sequence

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from dataset import HealthDataset, health_collate_fn
from losses import ExponentialNLLLoss, WeibullNLLLoss, get_valid_pairs_and_dt
from model import DelphiFork, SapDelphi


@dataclass
class TrainConfig:
    # Model Parameters
    model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
    loss_type: Literal['exponential', 'weibull'] = 'weibull'
    age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
    full_cov: bool = False
    n_embd: int = 120
    n_head: int = 12
    n_layer: int = 12
    pdrop: float = 0.1
    lambda_reg: float = 1e-4

    # SapDelphi specific
    pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"

    # Data Parameters
    data_prefix: str = "ukb"
    train_ratio: float = 0.7
    val_ratio: float = 0.15
    random_seed: int = 42

    # Training Parameters
    batch_size: int = 128
    max_epochs: int = 200
    warmup_epochs: int = 10
    patience: int = 10
    min_lr: float = 1e-5
    max_lr: float = 5e-4
    grad_clip: float = 1.0
    weight_decay: float = 1e-2
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    # EMA parameters
    ema_decay: float = 0.999


def parse_args() -> TrainConfig:
    parser = argparse.ArgumentParser(description="Train Delphi Model")
    parser.add_argument("--model_type", type=str, choices=['sap_delphi', 'delphi_fork'],
                        default='delphi_fork', help="Type of model to use.")
    parser.add_argument("--loss_type", type=str, choices=['exponential', 'weibull'],
                        default='weibull', help="Type of loss function to use.")
    parser.add_argument("--age_encoder", type=str, choices=['sinusoidal', 'mlp'],
                        default='sinusoidal', help="Type of age encoder to use.")
    parser.add_argument("--n_embd", type=int, default=120, help="Embedding dimension.")
    parser.add_argument("--n_head", type=int, default=12, help="Number of attention heads.")
    parser.add_argument("--n_layer", type=int, default=12, help="Number of transformer layers.")
    parser.add_argument("--pdrop", type=float, default=0.1, help="Dropout probability.")
    parser.add_argument("--lambda_reg", type=float, default=1e-4, help="Regularization weight.")
    parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
                        help="Path to pretrained embeddings for SapDelphi.")
    parser.add_argument("--data_prefix", type=str, default="ukb", help="Prefix for dataset files.")
    parser.add_argument("--full_cov", action='store_true', help="Use the full covariate set.")
    parser.add_argument("--train_ratio", type=float, default=0.7, help="Training data ratio.")
    parser.add_argument("--val_ratio", type=float, default=0.15, help="Validation data ratio.")
    parser.add_argument("--random_seed", type=int, default=42, help="Random seed for data splitting.")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size.")
    parser.add_argument("--max_epochs", type=int, default=200, help="Maximum number of epochs.")
    parser.add_argument("--warmup_epochs", type=int, default=10, help="Number of warmup epochs.")
    parser.add_argument("--patience", type=int, default=10, help="Early stopping patience.")
    parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate.")
    parser.add_argument("--max_lr", type=float, default=5e-4, help="Maximum learning rate.")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value.")
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay for optimizer.")
    parser.add_argument("--ema_decay", type=float, default=0.999, help="EMA decay rate.")
    parser.add_argument("--device", type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help="Device to use for training.")
    args = parser.parse_args()
    return TrainConfig(**vars(args))


def get_num_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


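# Trainer owns the dataset split, the model (plus an optional EMA copy), the
# optimizer, the learning-rate schedule, and the train/validation loop.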
class Trainer:
    def __init__(
        self,
        cfg: TrainConfig,
    ):
        self.cfg = cfg
        self.device = cfg.device
        self.global_step = 0

        if cfg.full_cov:
            cov_list = None
        else:
            cov_list = ["bmi", "smoking", "alcohol"]
        dataset = HealthDataset(
            data_prefix=cfg.data_prefix,
            covariate_list=cov_list,
        )
        print("Dataset loaded.")
        n_total = len(dataset)
        print(f"Total samples in dataset: {n_total}")
        print(f"Number of diseases: {dataset.n_disease}")
        print(f"Number of continuous covariates: {dataset.n_cont}")
        print(f"Number of categorical covariates: {dataset.n_cate}")

        n_train = int(n_total * cfg.train_ratio)
        n_val = int(n_total * cfg.val_ratio)
        self.train_data, self.val_data, _ = random_split(
            dataset,
            [n_train, n_val, n_total - n_train - n_val],
            generator=torch.Generator().manual_seed(cfg.random_seed),
        )
        self.train_loader = DataLoader(
            self.train_data,
            batch_size=cfg.batch_size,
            shuffle=True,
            collate_fn=health_collate_fn,
        )
        self.val_loader = DataLoader(
            self.val_data,
            batch_size=cfg.batch_size,
            shuffle=False,
            collate_fn=health_collate_fn,
        )

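        # Select the survival likelihood; n_dim is the number of distribution
        # parameters the model head must predict per disease (1 for the
        # exponential loss, 2 for the Weibull loss).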
        if cfg.loss_type == "exponential":
            self.criterion = ExponentialNLLLoss(
                lambda_reg=cfg.lambda_reg,
            ).to(self.device)
            n_dim = 1
        elif cfg.loss_type == "weibull":
            self.criterion = WeibullNLLLoss(
                lambda_reg=cfg.lambda_reg,
            ).to(self.device)
            n_dim = 2
        else:
            raise ValueError(f"Unsupported loss type: {cfg.loss_type}")

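        # Both model variants reserve the first two token ids for technical
        # (non-disease) tokens; disease targets are shifted by 2 accordingly in
        # the training loop.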
        if cfg.model_type == "delphi_fork":
            self.model = DelphiFork(
                n_disease=dataset.n_disease,
                n_tech_tokens=2,
                n_embd=cfg.n_embd,
                n_head=cfg.n_head,
                n_layer=cfg.n_layer,
                pdrop=cfg.pdrop,
                age_encoder_type=cfg.age_encoder,
                n_dim=n_dim,
                n_cont=dataset.n_cont,
                n_cate=dataset.n_cate,
                cate_dims=dataset.cate_dims,
            ).to(self.device)
        elif cfg.model_type == "sap_delphi":
            self.model = SapDelphi(
                n_disease=dataset.n_disease,
                n_tech_tokens=2,
                n_embd=cfg.n_embd,
                n_head=cfg.n_head,
                n_layer=cfg.n_layer,
                pdrop=cfg.pdrop,
                age_encoder_type=cfg.age_encoder,
                n_dim=n_dim,
                n_cont=dataset.n_cont,
                n_cate=dataset.n_cate,
                cate_dims=dataset.cate_dims,
                pretrained_emd_path=cfg.pretrained_emd_path,
                freeze_pretrained_emd=True,
            ).to(self.device)
        else:
            raise ValueError(f"Unsupported model type: {cfg.model_type}")
        print(f"Model initialized: {cfg.model_type}")
        print(f"Number of trainable parameters: {get_num_params(self.model)}")

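        # An exponential moving average (EMA) copy of the weights is kept; it is
        # updated after every optimizer step and used for validation and
        # checkpointing.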
        # Initialize EMA model
        self.ema_model = None
        if cfg.ema_decay < 1.0:
            if cfg.model_type == "delphi_fork":
                self.ema_model = DelphiFork(
                    n_disease=dataset.n_disease,
                    n_tech_tokens=2,
                    n_embd=cfg.n_embd,
                    n_head=cfg.n_head,
                    n_layer=cfg.n_layer,
                    pdrop=cfg.pdrop,
                    age_encoder_type=cfg.age_encoder,
                    n_dim=n_dim,
                    n_cont=dataset.n_cont,
                    n_cate=dataset.n_cate,
                    cate_dims=dataset.cate_dims,
                ).to(self.device)
            elif cfg.model_type == "sap_delphi":
                self.ema_model = SapDelphi(
                    n_disease=dataset.n_disease,
                    n_tech_tokens=2,
                    n_embd=cfg.n_embd,
                    n_head=cfg.n_head,
                    n_layer=cfg.n_layer,
                    pdrop=cfg.pdrop,
                    age_encoder_type=cfg.age_encoder,
                    n_dim=n_dim,
                    n_cont=dataset.n_cont,
                    n_cate=dataset.n_cate,
                    cate_dims=dataset.cate_dims,
                    pretrained_emd_path=cfg.pretrained_emd_path,
                    freeze_pretrained_emd=True,
                ).to(self.device)
            else:
                raise ValueError(f"Unsupported model type: {cfg.model_type}")
            self.ema_model.load_state_dict(self.model.state_dict())
            for param in self.ema_model.parameters():
                param.requires_grad = False
            print("EMA model initialized.")

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=cfg.max_lr,
            weight_decay=cfg.weight_decay,
            betas=(0.9, 0.99),
        )
        self.total_steps = len(self.train_loader) * cfg.max_epochs
        print(f"Total optimization steps: {self.total_steps}")

        # Create a unique, timestamped output directory under runs/.
        while True:
            cov_suffix = "fullcov" if cfg.full_cov else "partcov"
            name = f"{cfg.model_type}_{cfg.loss_type}_{cfg.age_encoder}_{cov_suffix}"
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            model_dir = os.path.join("runs", f"{name}_{timestamp}")
            if not os.path.exists(model_dir):
                self.out_dir = model_dir
                os.makedirs(model_dir)
                break
            time.sleep(1)
        print(f"Output directory: {self.out_dir}")
        self.best_path = os.path.join(self.out_dir, "best_model.pt")
        self.save_config()

    def save_config(self):
        cfg_path = os.path.join(self.out_dir, "train_config.json")
        with open(cfg_path, 'w') as f:
            json.dump(asdict(self.cfg), f, indent=4)
        print(f"Configuration saved to {cfg_path}")

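    # EMA update applied after every optimizer step:
    #   ema_param <- decay * ema_param + (1 - decay) * param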
    def update_ema(self):
        if self.ema_model is None:
            return
        decay = self.cfg.ema_decay
        with torch.no_grad():
            model_params = dict(self.model.named_parameters())
            ema_params = dict(self.ema_model.named_parameters())
            for name in model_params.keys():
                ema_params[name].data.mul_(decay).add_(
                    model_params[name].data, alpha=1 - decay)

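    # Learning-rate schedule: linear warmup for the first warmup_epochs, then
    # cosine decay from max_lr down to min_lr over the remaining epochs.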
    def compute_lr(self, current_step: int) -> float:
        cfg = self.cfg
        warmup_steps = cfg.warmup_epochs * len(self.train_loader)
        if current_step < warmup_steps:
            lr = cfg.max_lr * (current_step / warmup_steps)
        else:
            decay_steps = (cfg.max_epochs - cfg.warmup_epochs) * len(self.train_loader)
            progress = (current_step - warmup_steps) / decay_steps
            lr = cfg.min_lr + 0.5 * (cfg.max_lr - cfg.min_lr) * (1 + math.cos(math.pi * progress))
        return lr

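    # One epoch = a full pass over train_loader, then a validation pass with the
    # EMA weights (falling back to the raw model if EMA is disabled); early
    # stopping monitors the validation NLL.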
    def train(self) -> None:
        history = []
        best_val_score = float('inf')
        patience_counter = 0

        for epoch in range(1, self.cfg.max_epochs + 1):
            self.model.train()
            running_nll = 0.0
            running_reg = 0.0
            pbar = tqdm(self.train_loader,
                        desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
            batch_count = 0
            for batch in pbar:
                (
                    event_seq,
                    time_seq,
                    cont_feats,
                    cate_feats,
                    sexes,
                ) = batch
                event_seq = event_seq.to(self.device)
                time_seq = time_seq.to(self.device)
                cont_feats = cont_feats.to(self.device)
                cate_feats = cate_feats.to(self.device)
                sexes = sexes.to(self.device)

                # Valid (previous -> next) event pairs and their waiting times dt;
                # b_*/t_* index the batch and sequence positions of each pair.
                res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
                if res is None:
                    continue
                dt, b_prev, t_prev, b_next, t_next = res

                self.optimizer.zero_grad()
                lr = self.compute_lr(self.global_step)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr

                logits = self.model(
                    event_seq,
                    time_seq,
                    sexes,
                    cont_feats,
                    cate_feats,
                    b_prev=b_prev,
                    t_prev=t_prev,
                )
                # Shift past the two technical token ids so disease ids start at 0.
                target_event = event_seq[b_next, t_next] - 2
                nll_vec, reg = self.criterion(
                    logits,
                    target_event,
                    dt,
                    reduction="none",
                )
                # Drop non-finite losses so a single bad pair cannot poison the step.
                finite_mask = torch.isfinite(nll_vec)
                if not finite_mask.any():
                    continue
                nll_vec = nll_vec[finite_mask]
                nll = nll_vec.mean()

                loss = nll + reg
                batch_count += 1
                running_nll += nll.item()
                running_reg += reg.item()
                pbar.set_postfix({
                    "lr": lr,
                    "NLL": running_nll / batch_count,
                    "Reg": running_reg / batch_count,
                })
                loss.backward()
                if self.cfg.grad_clip > 0:
                    clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
                self.optimizer.step()
                self.update_ema()
                self.global_step += 1

            if batch_count == 0:
                print("No valid batches in this epoch, skipping validation.")
                continue

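            # Epoch-level training averages, then the validation pass.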
            train_nll = running_nll / batch_count
            train_reg = running_reg / batch_count

            # Validate with the EMA weights when available, otherwise the raw model.
            eval_model = self.ema_model if self.ema_model is not None else self.model
            eval_model.eval()
            total_val_pairs = 0
            total_val_nll = 0.0
            total_val_reg = 0.0
            with torch.no_grad():
                val_pbar = tqdm(self.val_loader, desc="Validation")
                for batch in val_pbar:
                    (
                        event_seq,
                        time_seq,
                        cont_feats,
                        cate_feats,
                        sexes,
                    ) = batch
                    event_seq = event_seq.to(self.device)
                    time_seq = time_seq.to(self.device)
                    cont_feats = cont_feats.to(self.device)
                    cate_feats = cate_feats.to(self.device)
                    sexes = sexes.to(self.device)
                    res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
                    if res is None:
                        continue
                    dt, b_prev, t_prev, b_next, t_next = res
                    num_pairs = dt.size(0)
                    logits = eval_model(
                        event_seq,
                        time_seq,
                        sexes,
                        cont_feats,
                        cate_feats,
                        b_prev=b_prev,
                        t_prev=t_prev,
                    )
                    target_events = event_seq[b_next, t_next] - 2
                    nll, reg = self.criterion(
                        logits,
                        target_events,
                        dt,
                        reduction="none",
                    )
                    total_val_nll += nll.sum().item()
                    total_val_reg += reg.item() * num_pairs
                    total_val_pairs += num_pairs

                    current_val_avg_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
                    current_val_avg_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
                    val_pbar.set_postfix({
                        "NLL": f"{current_val_avg_nll:.4f}",
                        "Reg": f"{current_val_avg_reg:.4f}",
                    })

            val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
            val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0

            history.append({
                "epoch": epoch,
                "train_nll": train_nll,
                "train_reg": train_reg,
                "val_nll": val_nll,
                "val_reg": val_reg,
            })

            tqdm.write(f"\nEpoch {epoch}/{self.cfg.max_epochs} Stats:")
            tqdm.write(f"  Train NLL: {train_nll:.4f}")
            tqdm.write(f"  Val NLL: {val_nll:.4f} ← PRIMARY METRIC")

            with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
                json.dump(history, f, indent=4)

            # Check for improvement
            if val_nll < best_val_score:
                best_val_score = val_nll
                patience_counter = 0
                tqdm.write("  ✓ New best validation score. Saving checkpoint.")
                torch.save({
                    "epoch": epoch,
                    "global_step": self.global_step,
                    "model_state_dict": eval_model.state_dict(),
                    "criterion_state_dict": self.criterion.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                }, self.best_path)
            else:
                patience_counter += 1
                if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience:
                    tqdm.write(
                        f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
                    return
                tqdm.write(
                    f"  No improvement (patience: {patience_counter}/{self.cfg.patience})")

        tqdm.write("\n🎉 Training complete!")


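# Example invocation (flag values are illustrative; assumes this file is saved
# as train.py):
#   python train.py --model_type sap_delphi --loss_type weibull --age_encoder mlp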
if __name__ == "__main__":
    cfg = parse_args()
    trainer = Trainer(cfg)
    trainer.train()