Files
DeepHealth/train.py

235 lines
8.7 KiB
Python

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 48 # Sequence length
# Model parameters
n_embd = 120
n_layer = 12
n_head = 12
pdrop = 0.1
token_pdrop = 0.1
# Training parameters
max_epoch = 200
batch_size = 128
lr_initial = 6e-4
lr_final = 6e-5
warmup_epochs = 10
early_stopping_patience = 10
# Loss parameters
# 0 = padding, 1 = "no event"
ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs
# System parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Main Training Script ---
def main():
config = TrainConfig()
# --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
# Infer vocab_size from the data (max label + 1)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
print(f"Inferred vocabulary size: {vocab_size}")
train_dataset = PatientEventDataset(train_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
# --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(config.device)
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = AdamW(model.parameters(), lr=config.lr_initial)
# --- 3. Training Loop ---
best_val_loss = float('inf')
patience_counter = 0
# Lists to store losses
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
val_losses_ce, val_losses_surv, val_losses_total = [], [], []
print("Starting training...")
for epoch in range(config.max_epoch):
# --- Learning Rate Scheduling ---
if epoch < config.warmup_epochs:
lr = config.lr_initial
else:
progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# --- Training Phase ---
model.train()
train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
train_steps = 0
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
for event_seq, time_seq in pbar:
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
# Prepare inputs and targets
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
# Forward pass
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss_ce_acc += loss_ce.item()
train_loss_surv_acc += loss_survival.item()
train_steps += 1
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
avg_train_loss_ce = train_loss_ce_acc / train_steps
avg_train_loss_surv = train_loss_surv_acc / train_steps
train_losses_ce.append(avg_train_loss_ce)
train_losses_surv.append(avg_train_loss_surv)
train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
# --- Validation Phase ---
model.eval()
val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
val_steps = 0
with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
for event_seq, time_seq in pbar_val:
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_loss_ce_acc += loss_ce.item()
val_loss_surv_acc += loss_survival.item()
val_steps += 1
pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
avg_val_loss_ce = val_loss_ce_acc / val_steps
avg_val_loss_surv = val_loss_surv_acc / val_steps
total_val_loss = avg_val_loss_ce + avg_val_loss_surv
val_losses_ce.append(avg_val_loss_ce)
val_losses_surv.append(avg_val_loss_surv)
val_losses_total.append(total_val_loss)
print(f"Epoch {epoch+1} Summary: \n"
f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
f" Learning Rate: {lr:.6f}")
# --- Early Stopping Check ---
if total_val_loss < best_val_loss:
best_val_loss = total_val_loss
patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
torch.save(model.state_dict(), 'best_model_checkpoint.pt')
else:
if epoch >= config.warmup_epochs:
patience_counter += 1
print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
if patience_counter >= config.early_stopping_patience:
print("\nEarly stopping triggered due to no improvement in validation loss.")
break
# --- Save Best Model at the End ---
if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
model.load_state_dict(torch.load('best_model_checkpoint.pt'))
print("Saving final best model to best_model.pt")
torch.save(model.state_dict(), 'best_model.pt')
else:
print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Plot and Save Loss Curves ---
num_epochs = len(train_losses_total)
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(18, 5))
# Plot CE Loss
plt.subplot(1, 3, 1)
plt.plot(epochs, train_losses_ce, label='Train CE')
plt.plot(epochs, val_losses_ce, label='Val CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Survival Loss
plt.subplot(1, 3, 2)
plt.plot(epochs, train_losses_surv, label='Train Survival')
plt.plot(epochs, val_losses_surv, label='Val Survival')
plt.title('Survival Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Total Loss
plt.subplot(1, 3, 3)
plt.plot(epochs, train_losses_total, label='Train Total')
plt.plot(epochs, val_losses_total, label='Val Total')
plt.title('Total Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves.png')
print("\nLoss curves saved to loss_curves.png")
if __name__ == '__main__':
main()