# DeepHealth/train.py
# Training entry point for the Delphi models (DelphiFork / SapDelphi).

import json
import os
import time
import argparse
import math
from dataclasses import asdict, dataclass
from typing import Literal, Sequence
from pathlib import Path
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from dataset import HealthDataset, health_collate_fn
from model import DelphiFork, SapDelphi
from losses import ExponentialNLLLoss, WeibullNLLLoss, get_valid_pairs_and_dt
@dataclass
class TrainConfig:
    """Hyper-parameters and run settings for one training job.

    Every field has a matching ``--<field>`` command-line flag in
    ``parse_args``, and ``Trainer.save_config`` writes the whole object to
    ``train_config.json`` in the run directory.

    Raises:
        ValueError: from ``__post_init__`` when a numeric setting is out of
            range (see that method).
    """

    # Model Parameters
    model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
    loss_type: Literal['exponential', 'weibull'] = 'weibull'
    age_encoder: Literal['sinusoidal', 'learned'] = 'learned'
    # When True the dataset is built with covariate_list=None; otherwise a
    # fixed subset of covariates is requested (see Trainer.__init__).
    full_cov: bool = False
    n_embd: int = 120
    n_head: int = 12
    n_layer: int = 12
    pdrop: float = 0.1
    # Regularization weight handed to the loss function.
    lambda_reg: float = 1e-4
    # SapDelphi specific: pretrained embedding file (loaded frozen).
    pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
    # Data Parameters
    data_prefix: str = "ukb"
    # Fractions of the dataset used for train / validation; the remainder
    # is a held-out split the trainer does not touch.
    train_ratio: float = 0.7
    val_ratio: float = 0.15
    # Seed for the dataset split generator.
    random_seed: int = 42
    # Training Parameters
    batch_size: int = 128
    max_epochs: int = 200
    # Linear LR warmup length; early stopping is also suppressed until this
    # many epochs have run.
    warmup_epochs: int = 10
    patience: int = 10
    min_lr: float = 1e-5
    max_lr: float = 5e-4
    # Max gradient norm; a value <= 0 disables clipping.
    grad_clip: float = 1.0
    weight_decay: float = 1e-2
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    # EMA parameters: a value >= 1.0 disables the EMA model entirely.
    ema_decay: float = 0.999

    def __post_init__(self) -> None:
        """Reject settings that would otherwise fail later with an opaque error.

        Raises:
            ValueError: if a ratio is outside [0, 1], the ratios sum to more
                than 1, ``batch_size`` or ``max_epochs`` is below 1,
                ``warmup_epochs`` is negative, or ``ema_decay`` is negative.
        """
        for name in ("train_ratio", "val_ratio"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value}")
        # Small tolerance so e.g. 0.85 + 0.15 is not rejected by rounding.
        if self.train_ratio + self.val_ratio > 1.0 + 1e-9:
            raise ValueError(
                "train_ratio + val_ratio must not exceed 1, got "
                f"{self.train_ratio} + {self.val_ratio}"
            )
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.max_epochs < 1:
            raise ValueError(f"max_epochs must be >= 1, got {self.max_epochs}")
        if self.warmup_epochs < 0:
            raise ValueError(
                f"warmup_epochs must be >= 0, got {self.warmup_epochs}")
        if self.ema_decay < 0.0:
            raise ValueError(f"ema_decay must be >= 0, got {self.ema_decay}")
def parse_args(argv: "Sequence[str] | None" = None) -> TrainConfig:
    """Build a ``TrainConfig`` from command-line flags.

    Flag defaults are read from a default ``TrainConfig`` instance, so the
    dataclass is the single source of truth and the two cannot drift apart.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        A ``TrainConfig`` with one field set per flag.
    """
    defaults = TrainConfig()
    parser = argparse.ArgumentParser(description="Train Delphi Model")
    parser.add_argument("--model_type", type=str, choices=['sap_delphi', 'delphi_fork'],
                        default=defaults.model_type, help="Type of model to use.")
    parser.add_argument("--loss_type", type=str, choices=['exponential', 'weibull'],
                        default=defaults.loss_type, help="Type of loss function to use.")
    parser.add_argument("--age_encoder", type=str, choices=['sinusoidal', 'learned'],
                        default=defaults.age_encoder, help="Type of age encoder to use.")
    parser.add_argument("--n_embd", type=int, default=defaults.n_embd,
                        help="Embedding dimension.")
    parser.add_argument("--n_head", type=int, default=defaults.n_head,
                        help="Number of attention heads.")
    parser.add_argument("--n_layer", type=int, default=defaults.n_layer,
                        help="Number of transformer layers.")
    parser.add_argument("--pdrop", type=float, default=defaults.pdrop,
                        help="Dropout probability.")
    parser.add_argument("--lambda_reg", type=float, default=defaults.lambda_reg,
                        help="Regularization weight.")
    parser.add_argument("--pretrained_emd_path", type=str, default=defaults.pretrained_emd_path,
                        help="Path to pretrained embeddings for SapDelphi.")
    parser.add_argument("--data_prefix", type=str, default=defaults.data_prefix,
                        help="Prefix for dataset files.")
    parser.add_argument("--full_cov", action='store_true',
                        help="Whether to use full covariates.")
    parser.add_argument("--train_ratio", type=float, default=defaults.train_ratio,
                        help="Training data ratio.")
    parser.add_argument("--val_ratio", type=float, default=defaults.val_ratio,
                        help="Validation data ratio.")
    parser.add_argument("--random_seed", type=int, default=defaults.random_seed,
                        help="Random seed for data splitting.")
    parser.add_argument("--batch_size", type=int, default=defaults.batch_size,
                        help="Batch size.")
    parser.add_argument("--max_epochs", type=int, default=defaults.max_epochs,
                        help="Maximum number of epochs.")
    parser.add_argument("--warmup_epochs", type=int, default=defaults.warmup_epochs,
                        help="Number of warmup epochs.")
    parser.add_argument("--patience", type=int, default=defaults.patience,
                        help="Early stopping patience.")
    parser.add_argument("--min_lr", type=float, default=defaults.min_lr,
                        help="Minimum learning rate.")
    parser.add_argument("--max_lr", type=float, default=defaults.max_lr,
                        help="Maximum learning rate.")
    parser.add_argument("--grad_clip", type=float, default=defaults.grad_clip,
                        help="Gradient clipping value.")
    parser.add_argument("--weight_decay", type=float, default=defaults.weight_decay,
                        help="Weight decay for optimizer.")
    parser.add_argument("--ema_decay", type=float, default=defaults.ema_decay,
                        help="EMA decay rate.")
    parser.add_argument("--device", type=str, default=defaults.device,
                        help="Device to use for training.")
    args = parser.parse_args(argv)
    # Every flag's dest matches a TrainConfig field name one-to-one.
    return TrainConfig(**vars(args))
def get_num_params(model: nn.Module) -> int:
    """Return how many scalar values live in *model*'s trainable parameters.

    Parameters with ``requires_grad=False`` (e.g. frozen embeddings) are
    excluded from the count.
    """
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total
class Trainer:
    """Train a DelphiFork / SapDelphi model with an optional EMA copy.

    Construction loads the dataset, splits it into train / validation /
    unused held-out subsets, builds the loss, the model, the optional EMA
    model and the AdamW optimizer, and creates a fresh run directory under
    ``runs/`` holding the config, per-epoch history and best checkpoint.
    """

    def __init__(
        self,
        cfg: "TrainConfig",
    ):
        self.cfg = cfg
        self.device = cfg.device
        self.global_step = 0

        # full_cov=True passes covariate_list=None (presumably "all
        # covariates"); otherwise only this fixed subset is requested.
        cov_list = None if cfg.full_cov else ["bmi", "smoking", "alcohol"]
        dataset = HealthDataset(
            data_prefix=cfg.data_prefix,
            covariate_list=cov_list,
        )
        print("Dataset loaded.")
        n_total = len(dataset)
        print(f"Total samples in dataset: {n_total}")
        print(f"Number of diseases: {dataset.n_disease}")
        print(f"Number of continuous covariates: {dataset.n_cont}")
        print(f"Number of categorical covariates: {dataset.n_cate}")

        n_train = int(n_total * cfg.train_ratio)
        n_val = int(n_total * cfg.val_ratio)
        # The third (remainder) split is held out and not used by training.
        self.train_data, self.val_data, _ = random_split(
            dataset,
            [n_train, n_val, n_total - n_train - n_val],
            generator=torch.Generator().manual_seed(cfg.random_seed),
        )
        self.train_loader = DataLoader(
            self.train_data,
            batch_size=cfg.batch_size,
            shuffle=True,
            collate_fn=health_collate_fn,
        )
        self.val_loader = DataLoader(
            self.val_data,
            batch_size=cfg.batch_size,
            shuffle=False,
            collate_fn=health_collate_fn,
        )

        # n_dim is forwarded to the model: 1 for the exponential loss, 2 for
        # the Weibull loss (presumably distribution parameters per disease).
        if cfg.loss_type == "exponential":
            self.criterion = ExponentialNLLLoss(
                lambda_reg=cfg.lambda_reg,
            ).to(self.device)
            n_dim = 1
        elif cfg.loss_type == "weibull":
            self.criterion = WeibullNLLLoss(
                lambda_reg=cfg.lambda_reg,
            ).to(self.device)
            n_dim = 2
        else:
            raise ValueError(f"Unsupported loss type: {cfg.loss_type}")

        self.model = self._build_model(dataset, n_dim)
        print(f"Model initialized: {cfg.model_type}")
        print(f"Number of trainable parameters: {get_num_params(self.model)}")

        # EMA is enabled only for ema_decay < 1.0. When disabled,
        # self.ema_model stays None and the live model is used for
        # validation and checkpointing (see _eval_model).
        self.ema_model = None
        if cfg.ema_decay < 1.0:
            self.ema_model = self._build_model(dataset, n_dim)
            self.ema_model.load_state_dict(self.model.state_dict())
            for param in self.ema_model.parameters():
                param.requires_grad = False
            print("EMA model initialized.")

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=cfg.max_lr,
            weight_decay=cfg.weight_decay,
            betas=(0.9, 0.99),
        )
        self.total_steps = len(self.train_loader) * cfg.max_epochs
        print(f"Total optimization steps: {self.total_steps}")

        self.out_dir = self._make_run_dir()
        print(f"Output directory: {self.out_dir}")
        self.best_path = os.path.join(self.out_dir, "best_model.pt")
        self.save_config()

    def _build_model(self, dataset, n_dim: int) -> nn.Module:
        """Instantiate the model chosen by ``cfg.model_type`` on ``self.device``.

        Raises:
            ValueError: if ``cfg.model_type`` is not a supported model.
        """
        cfg = self.cfg
        common = dict(
            n_disease=dataset.n_disease,
            n_embd=cfg.n_embd,
            n_head=cfg.n_head,
            n_layer=cfg.n_layer,
            pdrop=cfg.pdrop,
            age_encoder_type=cfg.age_encoder,
            n_dim=n_dim,
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
        )
        if cfg.model_type == "delphi_fork":
            model = DelphiFork(**common)
        elif cfg.model_type == "sap_delphi":
            model = SapDelphi(
                **common,
                pretrained_emd_path=cfg.pretrained_emd_path,
                freeze_pretrained_emd=True,
            )
        else:
            raise ValueError(f"Unsupported model type: {cfg.model_type}")
        return model.to(self.device)

    def _make_run_dir(self) -> str:
        """Create and return a new ``runs/<name>_<timestamp>`` directory."""
        cfg = self.cfg
        cov_suffix = "fullcov" if cfg.full_cov else "partcov"
        name = f"{cfg.model_type}_{cfg.loss_type}_{cfg.age_encoder}_{cov_suffix}"
        while True:
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            model_dir = os.path.join("runs", f"{name}_{timestamp}")
            try:
                # exist_ok=False turns check-and-create into one step, so two
                # runs started in the same second cannot share a directory.
                os.makedirs(model_dir, exist_ok=False)
            except FileExistsError:
                # Name taken: wait for the timestamp to advance and retry.
                time.sleep(1)
                continue
            return model_dir

    def _eval_model(self) -> nn.Module:
        """Return the weights used for validation and checkpoints.

        This is the EMA model when EMA is enabled, otherwise the live model.
        """
        return self.ema_model if self.ema_model is not None else self.model

    def save_config(self):
        """Write the training configuration to ``<out_dir>/train_config.json``."""
        cfg_path = os.path.join(self.out_dir, "train_config.json")
        with open(cfg_path, 'w') as f:
            json.dump(asdict(self.cfg), f, indent=4)
        print(f"Configuration saved to {cfg_path}")

    def update_ema(self):
        """Blend the live parameters into the EMA copy; no-op when EMA is off.

        For each parameter: ``ema = decay * ema + (1 - decay) * param``.
        """
        if self.ema_model is None:
            return
        decay = self.cfg.ema_decay
        with torch.no_grad():
            ema_params = dict(self.ema_model.named_parameters())
            for name, param in self.model.named_parameters():
                ema_params[name].mul_(decay).add_(param, alpha=1 - decay)

    def compute_lr(self, current_step: int) -> float:
        """Return the learning rate for optimizer step ``current_step`` (0-based).

        The rate warms up linearly from 0 to ``max_lr`` over
        ``warmup_epochs`` epochs. It then follows a cosine decay from
        ``max_lr`` to ``min_lr`` over the remaining epochs.
        """
        cfg = self.cfg
        steps_per_epoch = len(self.train_loader)
        warmup_steps = cfg.warmup_epochs * steps_per_epoch
        if current_step < warmup_steps:
            return cfg.max_lr * (current_step / warmup_steps)
        decay_steps = (cfg.max_epochs - cfg.warmup_epochs) * steps_per_epoch
        progress = (current_step - warmup_steps) / decay_steps
        return cfg.min_lr + 0.5 * \
            (cfg.max_lr - cfg.min_lr) * (1 + math.cos(math.pi * progress))

    def _batch_to_device(self, batch):
        """Move every tensor of a collated batch to ``self.device``.

        The batch is expected to be ``(event_seq, time_seq, cont_feats,
        cate_feats, sexes)``. Callers unpack exactly five items.
        """
        return tuple(item.to(self.device) for item in batch)

    def _train_epoch(self, epoch: int) -> "tuple[float, float] | None":
        """Run one optimization epoch over ``self.train_loader``.

        Returns:
            ``(mean_nll, mean_reg)`` averaged over batches that produced at
            least one finite NLL term, or ``None`` if no batch did.
        """
        self.model.train()
        running_nll = 0.0
        running_reg = 0.0
        batch_count = 0
        pbar = tqdm(self.train_loader,
                    desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
        for batch in pbar:
            event_seq, time_seq, cont_feats, cate_feats, sexes = \
                self._batch_to_device(batch)
            # The literal 2 is passed to get_valid_pairs_and_dt and subtracted
            # from event ids to form targets; presumably ids 0-1 are reserved
            # tokens -- confirm against the dataset/losses modules.
            res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
            if res is None:
                continue
            dt, b_prev, t_prev, b_next, t_next = res

            self.optimizer.zero_grad()
            lr = self.compute_lr(self.global_step)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            logits = self.model(
                event_seq,
                time_seq,
                sexes,
                cont_feats,
                cate_feats,
                b_prev=b_prev,
                t_prev=t_prev,
            )
            target_event = event_seq[b_next, t_next] - 2
            nll_vec, reg = self.criterion(
                logits,
                target_event,
                dt,
                reduction="none",
            )
            # Drop non-finite per-pair terms so one bad pair cannot poison
            # the whole step; skip the batch if nothing finite remains.
            finite_mask = torch.isfinite(nll_vec)
            if not finite_mask.any():
                continue
            nll = nll_vec[finite_mask].mean()
            loss = nll + reg

            batch_count += 1
            running_nll += nll.item()
            running_reg += reg.item()
            pbar.set_postfix({
                "lr": lr,
                "NLL": running_nll / batch_count,
                "Reg": running_reg / batch_count,
            })

            loss.backward()
            if self.cfg.grad_clip > 0:
                clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
            self.optimizer.step()
            self.update_ema()
            self.global_step += 1

        if batch_count == 0:
            return None
        return running_nll / batch_count, running_reg / batch_count

    def _validate(self) -> "tuple[float, float, int]":
        """Evaluate ``_eval_model()`` on the validation split.

        Non-finite per-pair NLL terms are excluded, the same way training
        excludes them, so a single NaN cannot make the metric NaN.

        Returns:
            ``(nll, reg, n_dropped)``: mean NLL per finite pair, the batch
            regularizer weighted by finite pair count, and the number of
            non-finite pairs excluded. Both means are 0.0 when no pair
            was available.
        """
        model = self._eval_model()
        model.eval()
        total_val_pairs = 0
        total_val_nll = 0.0
        total_val_reg = 0.0
        n_dropped = 0
        with torch.no_grad():
            val_pbar = tqdm(self.val_loader, desc="Validation")
            for batch in val_pbar:
                event_seq, time_seq, cont_feats, cate_feats, sexes = \
                    self._batch_to_device(batch)
                res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
                if res is None:
                    continue
                dt, b_prev, t_prev, b_next, t_next = res
                logits = model(
                    event_seq,
                    time_seq,
                    sexes,
                    cont_feats,
                    cate_feats,
                    b_prev=b_prev,
                    t_prev=t_prev
                )
                target_events = event_seq[b_next, t_next] - 2
                nll_vec, reg = self.criterion(
                    logits,
                    target_events,
                    dt,
                    reduction="none",
                )
                finite_mask = torch.isfinite(nll_vec)
                num_pairs = int(finite_mask.sum().item())
                n_dropped += nll_vec.numel() - num_pairs
                total_val_nll += nll_vec[finite_mask].sum().item()
                total_val_reg += reg.item() * num_pairs
                total_val_pairs += num_pairs
                current_val_avg_nll = total_val_nll / \
                    total_val_pairs if total_val_pairs > 0 else 0.0
                current_val_avg_reg = total_val_reg / \
                    total_val_pairs if total_val_pairs > 0 else 0.0
                val_pbar.set_postfix({
                    "NLL": f"{current_val_avg_nll:.4f}",
                    "Reg": f"{current_val_avg_reg:.4f}",
                })
        val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
        val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
        return val_nll, val_reg, n_dropped

    def train(self) -> None:
        """Train with per-epoch validation, checkpointing and early stopping.

        After each epoch the validation NLL (lower is better) is compared
        with the best value so far. On improvement, ``_eval_model()`` weights
        are saved to ``self.best_path``. Training stops early once
        ``patience`` epochs pass without improvement, but never before
        ``warmup_epochs`` epochs have run.
        """
        history = []
        best_val_score = float('inf')
        patience_counter = 0
        for epoch in range(1, self.cfg.max_epochs + 1):
            train_stats = self._train_epoch(epoch)
            if train_stats is None:
                print("No valid batches in this epoch, skipping validation.")
                continue
            train_nll, train_reg = train_stats
            val_nll, val_reg, n_dropped = self._validate()

            history.append({
                "epoch": epoch,
                "train_nll": train_nll,
                "train_reg": train_reg,
                "val_nll": val_nll,
                "val_reg": val_reg,
            })
            # epoch is already 1-based (range starts at 1).
            tqdm.write(f"\nEpoch {epoch}/{self.cfg.max_epochs} Stats:")
            tqdm.write(f" Train NLL: {train_nll:.4f}")
            tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
            if n_dropped:
                tqdm.write(f" Excluded {n_dropped} non-finite validation pairs")
            with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
                json.dump(history, f, indent=4)

            # Check for improvement
            if val_nll < best_val_score:
                best_val_score = val_nll
                patience_counter = 0
                tqdm.write(" ✓ New best validation score. Saving checkpoint.")
                torch.save({
                    "epoch": epoch,
                    "global_step": self.global_step,
                    "model_state_dict": self._eval_model().state_dict(),
                    "criterion_state_dict": self.criterion.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                }, self.best_path)
            else:
                patience_counter += 1
                if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience:
                    tqdm.write(
                        f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
                    return
                tqdm.write(
                    f" No improvement (patience: {patience_counter}/{self.cfg.patience})")
        tqdm.write("\n🎉 Training complete!")
if __name__ == "__main__":
    # Script entry point: CLI flags -> TrainConfig -> Trainer -> training run.
    Trainer(parse_args()).train()