DeepHealth/train.py

import argparse
import json
import math
import os
import time
from dataclasses import asdict, dataclass, field
from typing import Literal, Optional, Sequence

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from dataset import HealthDataset, health_collate_fn
from losses import (
    DiscreteTimeCIFNLLLoss,
    ExponentialNLLLoss,
    LogNormalBasisBinnedHazardCIFNLLLoss,
    get_valid_pairs_and_dt,
)
from model import DelphiFork, SapDelphi, SimpleHead
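
# Example invocation (illustrative only; all flags are optional and default
# to the values defined in TrainConfig below):
#   python train.py --model_type delphi_fork --loss_type discrete_time_cif \
#       --age_encoder sinusoidal --batch_size 128 --max_epochs 200
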
@dataclass
class TrainConfig:
# Model Parameters
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
loss_type: Literal['exponential', 'discrete_time_cif',
'lognormal_basis_binned_hazard_cif'] = 'exponential'
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
full_cov: bool = False
n_embd: int = 120
n_head: int = 12
n_layer: int = 12
pdrop: float = 0.1
lambda_reg: float = 1e-4
bin_edges: Sequence[float] = field(
default_factory=lambda: [0.0, 0.24, 0.72,
1.61, 3.84, 10.0, 31.0, float('inf')]
)
# LogNormal basis (shared by Route-3 binned hazard)
lognormal_centers: Optional[Sequence[float]] = field(
default_factory=list) # mu_r in log-time
loss_eps: float = 1e-8
bandwidth_init: float = 0.7
bandwidth_min: float = 1e-3
bandwidth_max: float = 10.0
lambda_sigma_reg: float = 1e-4
sigma_reg_target: Optional[float] = None
rank: int = 16
# SapDelphi specific
pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
# Data Parameters
data_prefix: str = "ukb"
train_ratio: float = 0.7
val_ratio: float = 0.15
random_seed: int = 42
# Training Parameters
batch_size: int = 128
max_epochs: int = 200
warmup_epochs: int = 10
patience: int = 10
min_lr: float = 1e-5
max_lr: float = 5e-4
grad_clip: float = 1.0
weight_decay: float = 1e-2
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
num_workers: int = 0
prefetch_factor: int = 2
persistent_workers: bool = False
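
# TrainConfig can also be built programmatically instead of through the CLI
# (illustrative sketch; the override values are arbitrary examples):
#   cfg = TrainConfig(model_type='sap_delphi', loss_type='exponential', n_embd=144)
#   Trainer(cfg).train()
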
def parse_args() -> TrainConfig:
parser = argparse.ArgumentParser(description="Train Delphi Model")
parser.add_argument("--model_type", type=str, choices=[
'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
parser.add_argument(
"--loss_type",
type=str,
choices=['exponential', 'discrete_time_cif',
'lognormal_basis_binned_hazard_cif'],
default='exponential',
help="Type of loss function to use.")
parser.add_argument("--age_encoder", type=str, choices=[
'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
parser.add_argument("--n_embd", type=int, default=120,
help="Embedding dimension.")
parser.add_argument("--n_head", type=int, default=12,
help="Number of attention heads.")
parser.add_argument("--n_layer", type=int, default=12,
help="Number of transformer layers.")
parser.add_argument("--pdrop", type=float, default=0.1,
help="Dropout probability.")
parser.add_argument("--lambda_reg", type=float,
default=1e-4, help="Regularization weight.")
parser.add_argument(
"--lognormal_centers",
type=float,
nargs='*',
default=None,
help="LogNormal basis centers (mu_r) in log-time; provide as space-separated floats. If omitted, centers are derived from bin_edges.")
parser.add_argument("--loss_eps", type=float, default=1e-8,
help="Epsilon for log clamps in lognormal-basis losses.")
parser.add_argument("--bandwidth_init", type=float, default=0.7,
help="Initial sigma for lognormal-basis.")
parser.add_argument("--bandwidth_min", type=float, default=1e-3,
help="Minimum sigma clamp for lognormal-basis.")
parser.add_argument("--bandwidth_max", type=float, default=10.0,
help="Maximum sigma clamp for lognormal-basis.")
parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
help="Sigma regularization strength for lognormal-basis.")
parser.add_argument("--sigma_reg_target", type=float, default=None,
help="Optional sigma target for regularization (otherwise uses bandwidth_init).")
parser.add_argument("--rank", type=int, default=16,
help="Rank for low-rank parameterization (if applicable).")
parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
help="Path to pretrained embeddings for SapDelphi.")
parser.add_argument("--data_prefix", type=str,
default="ukb", help="Prefix for dataset files.")
parser.add_argument("--full_cov", action='store_true',
help="Whether to use full covariates.")
parser.add_argument("--train_ratio", type=float,
default=0.7, help="Training data ratio.")
parser.add_argument("--val_ratio", type=float,
default=0.15, help="Validation data ratio.")
parser.add_argument("--random_seed", type=int, default=42,
help="Random seed for data splitting.")
parser.add_argument("--batch_size", type=int,
default=128, help="Batch size.")
parser.add_argument("--max_epochs", type=int, default=200,
help="Maximum number of epochs.")
parser.add_argument("--warmup_epochs", type=int,
default=10, help="Number of warmup epochs.")
parser.add_argument("--patience", type=int, default=10,
help="Early stopping patience.")
parser.add_argument("--min_lr", type=float, default=1e-5,
help="Minimum learning rate.")
parser.add_argument("--max_lr", type=float, default=5e-4,
help="Maximum learning rate.")
parser.add_argument("--grad_clip", type=float,
default=1.0, help="Gradient clipping value.")
parser.add_argument("--weight_decay", type=float,
default=1e-2, help="Weight decay for optimizer.")
parser.add_argument("--num_workers", type=int, default=0,
help="DataLoader workers (0 is safest on Windows).")
parser.add_argument("--prefetch_factor", type=int, default=2,
help="DataLoader prefetch factor (only used when num_workers>0).")
parser.add_argument("--persistent_workers", action='store_true',
help="Keep DataLoader workers alive between epochs (only if num_workers>0).")
parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
help="Device to use for training.")
args = parser.parse_args()
cfg = TrainConfig(**vars(args))
# If lognormal_centers are not provided, derive a reasonable default from bin_edges.
if cfg.lognormal_centers is None or len(cfg.lognormal_centers) == 0:
edges = [float(x) for x in cfg.bin_edges]
finite = [e for e in edges if math.isfinite(e)]
if len(finite) < 2:
raise ValueError(
"bin_edges must contain at least two finite edges to derive lognormal_centers")
e1 = finite[1]
t_min = (e1 * 1e-3) if e1 > 0 else 1e-12
# Build one center per bin (including the +inf last bin if present).
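        # Worked example with the default bin_edges
        # [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, inf]:
        #   t_min = 0.24 * 1e-3 = 2.4e-4
        #   bin [0.0, 0.24)  -> mid = sqrt(2.4e-4 * 0.24) ~ 7.6e-3, center = log(mid) ~ -4.88
        #   bin [0.24, 0.72) -> mid = sqrt(0.24 * 0.72)   ~ 0.416,  center ~ -0.88
        #   bin [31.0, inf)  -> mid = 2 * 31.0 = 62.0,              center = log(62.0) ~ 4.13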
centers: list[float] = []
for i in range(1, len(edges)):
left = float(edges[i - 1])
right = float(edges[i])
if i == 1 and left <= 0.0:
left_pos = t_min
else:
left_pos = max(left, t_min)
if math.isinf(right):
mid = max(left_pos * 2.0, left_pos + 1e-6)
else:
right_pos = max(right, t_min)
mid = math.sqrt(left_pos * right_pos)
centers.append(math.log(max(mid, t_min)))
cfg.lognormal_centers = centers
return cfg
def get_num_params(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class Trainer:
def __init__(
self,
cfg: TrainConfig,
):
self.cfg = cfg
self.device = cfg.device
self.global_step = 0
use_cuda = str(self.device).startswith(
"cuda") and torch.cuda.is_available()
if use_cuda:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
try:
torch.set_float32_matmul_precision("high")
except Exception:
pass
if cfg.full_cov:
cov_list = None
else:
cov_list = ["bmi", "smoking", "alcohol"]
dataset = HealthDataset(
data_prefix=cfg.data_prefix,
covariate_list=cov_list,
)
print("Dataset loaded.")
n_total = len(dataset)
print(f"Total samples in dataset: {n_total}")
print(f"Number of diseases: {dataset.n_disease}")
print(f"Number of continuous covariates: {dataset.n_cont}")
print(f"Number of categorical covariates: {dataset.n_cate}")
self.train_data, self.val_data, _ = random_split(
dataset,
[
int(n_total * cfg.train_ratio),
int(n_total * cfg.val_ratio),
n_total - int(n_total * cfg.train_ratio) -
int(n_total * cfg.val_ratio),
],
generator=torch.Generator().manual_seed(cfg.random_seed),
)
pin_memory = use_cuda
loader_kwargs = dict(
collate_fn=health_collate_fn,
pin_memory=pin_memory,
)
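        # prefetch_factor and persistent_workers are only valid for
        # multiprocessing DataLoaders, so they are passed only when
        # num_workers > 0.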
if cfg.num_workers > 0:
loader_kwargs["num_workers"] = cfg.num_workers
loader_kwargs["prefetch_factor"] = cfg.prefetch_factor
loader_kwargs["persistent_workers"] = cfg.persistent_workers
self.train_loader = DataLoader(
self.train_data,
batch_size=cfg.batch_size,
shuffle=True,
**loader_kwargs,
)
self.val_loader = DataLoader(
self.val_data,
batch_size=cfg.batch_size,
shuffle=False,
**loader_kwargs,
)
if cfg.loss_type == "exponential":
self.criterion = ExponentialNLLLoss(
lambda_reg=cfg.lambda_reg,
).to(self.device)
out_dims = [dataset.n_disease]
elif cfg.loss_type == "discrete_time_cif":
self.criterion = DiscreteTimeCIFNLLLoss(
bin_edges=cfg.bin_edges,
lambda_reg=cfg.lambda_reg,
).to(self.device)
# logits shape (M, K+1, n_bins+1)
out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
elif cfg.loss_type == "lognormal_basis_binned_hazard_cif":
r = len(cfg.lognormal_centers)
if r <= 0:
raise ValueError(
"lognormal_centers must be non-empty for lognormal_basis_binned_hazard_cif")
self.criterion = LogNormalBasisBinnedHazardCIFNLLLoss(
bin_edges=cfg.bin_edges,
centers=cfg.lognormal_centers,
eps=cfg.loss_eps,
bandwidth_init=cfg.bandwidth_init,
bandwidth_min=cfg.bandwidth_min,
bandwidth_max=cfg.bandwidth_max,
lambda_sigma_reg=cfg.lambda_sigma_reg,
sigma_reg_target=cfg.sigma_reg_target,
lambda_reg=cfg.lambda_reg,
).to(self.device)
# Head emits (M, J, R) for Route-3.
out_dims = [dataset.n_disease, r]
else:
raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
if cfg.model_type == "delphi_fork":
self.model = DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd,
n_head=cfg.n_head,
n_layer=cfg.n_layer,
pdrop=cfg.pdrop,
age_encoder_type=cfg.age_encoder,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
).to(self.device)
elif cfg.model_type == "sap_delphi":
self.model = SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd,
n_head=cfg.n_head,
n_layer=cfg.n_layer,
pdrop=cfg.pdrop,
age_encoder_type=cfg.age_encoder,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
pretrained_weights_path=cfg.pretrained_emd_path,
freeze_embeddings=True,
).to(self.device)
else:
raise ValueError(f"Unsupported model type: {cfg.model_type}")
# Prediction head maps context vectors -> logits with the shape required by the loss.
self.head = SimpleHead(
n_embd=cfg.n_embd,
out_dims=out_dims,
).to(self.device)
print(f"Model initialized: {cfg.model_type}")
print(
f"Number of trainable parameters (backbone): {get_num_params(self.model)}")
print(
f"Number of trainable parameters (head): {get_num_params(self.head)}")
self._optim_params = (
list(self.model.parameters())
+ list(self.head.parameters())
)
self.optimizer = AdamW(
self._optim_params,
lr=cfg.max_lr,
weight_decay=cfg.weight_decay,
betas=(0.9, 0.99),
)
self.total_steps = (len(self.train_loader) *
cfg.max_epochs)
print(f"Total optimization steps: {self.total_steps}")
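        # Create a unique, timestamped run directory; if it already exists
        # (e.g. two runs launched within the same second), wait and retry
        # with a fresh timestamp.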
while True:
cov_suffix = "fullcov" if cfg.full_cov else "partcov"
name = f"{cfg.model_type}_{cfg.loss_type}_{cfg.age_encoder}_{cov_suffix}"
timestamp = time.strftime("%Y%m%d-%H%M%S")
model_dir = os.path.join("runs", f"{name}_{timestamp}")
if not os.path.exists(model_dir):
self.out_dir = model_dir
os.makedirs(model_dir)
break
time.sleep(1)
print(f"Output directory: {self.out_dir}")
self.best_path = os.path.join(self.out_dir, "best_model.pt")
self.global_step = 0
self.save_config()
def save_config(self):
cfg_path = os.path.join(self.out_dir, "train_config.json")
with open(cfg_path, 'w') as f:
json.dump(asdict(self.cfg), f, indent=4)
print(f"Configuration saved to {cfg_path}")
def compute_lr(self, current_step: int) -> float:
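        """Linear warmup to max_lr, then cosine decay towards min_lr.

        Illustrative example: with warmup_epochs=10, max_epochs=200 and 100
        batches per epoch, the LR ramps linearly from 0 to max_lr over the
        first 1,000 steps and then follows
        min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * progress))
        over the remaining 19,000 steps.
        """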
cfg = self.cfg
if current_step < cfg.warmup_epochs * len(self.train_loader):
lr = cfg.max_lr * (current_step /
(cfg.warmup_epochs * len(self.train_loader)))
else:
denom = (cfg.max_epochs - cfg.warmup_epochs) * \
len(self.train_loader)
progress = (current_step - cfg.warmup_epochs *
len(self.train_loader)) / denom
lr = cfg.min_lr + 0.5 * \
(cfg.max_lr - cfg.min_lr) * (1 + math.cos(math.pi * progress))
return lr
def train(self) -> None:
history = []
best_val_score = float('inf')
patience_counter = 0
for epoch in range(1, self.cfg.max_epochs + 1):
self.model.train()
self.head.train()
total_train_pairs = 0
total_train_nll = 0.0
total_train_reg = 0.0
pbar = tqdm(self.train_loader,
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
batch_count = 0
for batch in pbar:
(
event_seq,
time_seq,
cont_feats,
cate_feats,
sexes,
) = batch
event_seq = event_seq.to(self.device, non_blocking=True)
time_seq = time_seq.to(self.device, non_blocking=True)
cont_feats = cont_feats.to(self.device, non_blocking=True)
cate_feats = cate_feats.to(self.device, non_blocking=True)
sexes = sexes.to(self.device, non_blocking=True)
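                # Extract (previous event, next event) index pairs and their
                # waiting times dt; the third argument is the technical-token
                # offset (matching n_tech_tokens=2 in the model). None means
                # the batch contains no valid pair.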
res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
if res is None:
continue
dt, b_prev, t_prev, b_next, t_next = res
num_pairs = dt.size(0)
self.optimizer.zero_grad()
lr = self.compute_lr(self.global_step)
self.optimizer.param_groups[0]['lr'] = lr
h = self.model(
event_seq,
time_seq,
sexes,
cont_feats,
cate_feats,
)
# Context vectors for selected previous events
c = h[b_prev, t_prev] # (M, D)
logits = self.head(c)
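                # Shift token IDs down by the 2 technical tokens so disease
                # targets are 0-indexed for the loss.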
target_event = event_seq[b_next, t_next] - 2
nll_vec, reg = self.criterion(
logits,
target_event,
dt,
reduction="none",
)
nll = nll_vec.mean()
loss = nll + reg
batch_count += 1
total_train_pairs += num_pairs
total_train_nll += nll_vec.sum().item()
total_train_reg += reg.item() * num_pairs
avg_train_nll = total_train_nll / total_train_pairs
avg_train_reg = total_train_reg / total_train_pairs
                pbar.set_postfix({
                    "lr": f"{lr:.2e}",
                    "NLL": f"{avg_train_nll:.4f}",
                    "Reg": f"{avg_train_reg:.4f}",
                })
loss.backward()
if self.cfg.grad_clip > 0:
clip_grad_norm_(self._optim_params, self.cfg.grad_clip)
self.optimizer.step()
self.global_step += 1
if batch_count == 0:
print("No valid batches in this epoch, skipping validation.")
continue
train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
self.model.eval()
self.head.eval()
total_val_pairs = 0
total_val_nll = 0.0
total_val_reg = 0.0
with torch.no_grad():
val_pbar = tqdm(self.val_loader, desc="Validation")
for batch in val_pbar:
(
event_seq,
time_seq,
cont_feats,
cate_feats,
sexes,
) = batch
event_seq = event_seq.to(self.device, non_blocking=True)
time_seq = time_seq.to(self.device, non_blocking=True)
cont_feats = cont_feats.to(self.device, non_blocking=True)
cate_feats = cate_feats.to(self.device, non_blocking=True)
sexes = sexes.to(self.device, non_blocking=True)
res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
if res is None:
continue
dt, b_prev, t_prev, b_next, t_next = res
num_pairs = dt.size(0)
h = self.model(
event_seq,
time_seq,
sexes,
cont_feats,
cate_feats,
)
c = h[b_prev, t_prev]
logits = self.head(c)
target_events = event_seq[b_next, t_next] - 2
                    nll_vec, reg = self.criterion(
                        logits,
                        target_events,
                        dt,
                        reduction="none",
                    )
                    batch_nll_sum = nll_vec.sum().item()
                    total_val_nll += batch_nll_sum
total_val_reg += reg.item() * num_pairs
total_val_pairs += num_pairs
current_val_avg_nll = total_val_nll / \
total_val_pairs if total_val_pairs > 0 else 0.0
current_val_avg_reg = total_val_reg / \
total_val_pairs if total_val_pairs > 0 else 0.0
val_pbar.set_postfix({
"NLL": f"{current_val_avg_nll:.4f}",
"Reg": f"{current_val_avg_reg:.4f}",
})
val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
history.append({
"epoch": epoch,
"train_nll": train_nll,
"train_reg": train_reg,
"val_nll": val_nll,
"val_reg": val_reg,
})
            tqdm.write(f"\nEpoch {epoch}/{self.cfg.max_epochs} Stats:")
tqdm.write(f" Train NLL: {train_nll:.4f}")
tqdm.write(f" Train Reg: {train_reg:.4f}")
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
tqdm.write(f" Val Reg: {val_reg:.4f}")
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
json.dump(history, f, indent=4)
# Check for improvement
if val_nll < best_val_score:
best_val_score = val_nll
patience_counter = 0
tqdm.write(" ✓ New best validation score. Saving checkpoint.")
torch.save({
"epoch": epoch,
"global_step": self.global_step,
"model_state_dict": self.model.state_dict(),
"head_state_dict": self.head.state_dict(),
"criterion_state_dict": self.criterion.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
}, self.best_path)
else:
patience_counter += 1
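                # Patience accumulates from the start, but early stopping can
                # only trigger once warmup_epochs have elapsed.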
if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience:
tqdm.write(
f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
return
tqdm.write(
f" No improvement (patience: {patience_counter}/{self.cfg.patience})")
tqdm.write("\n🎉 Training complete!")
if __name__ == "__main__":
cfg = parse_args()
trainer = Trainer(cfg)
trainer.train()