DeepHealth/train_multigpu.py
Jiarui Li d760c45baf feat: Add multi-GPU training and improve config/ignore
Add train_multigpu.py for distributed data parallel training.

Update train.py to save the training configuration to a JSON file.

Generalize .gitignore to exclude all *.pt checkpoint files.

Delete obsolete train_dpp.py file.
2025-10-17 14:09:34 +08:00
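Run with one process per GPU via a launcher that sets LOCAL_RANK (command assumed):
torchrun --nproc_per_node=<num_gpus> train_multigpu.py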


import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48  # Sequence length

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_epochs = 10
    early_stopping_patience = 10

    # Loss parameters
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

    # System parameters (will be updated by DDP setup)
    device = 'cuda'
# --- DDP Setup ---
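# Each process binds to one GPU, identified by the LOCAL_RANK environment
# variable supplied by the launcher; NCCL handles inter-GPU communication.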
def setup_ddp():
"""Initializes the distributed data parallel environment."""
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
return rank, local_rank
def cleanup_ddp():
"""Cleans up the distributed data parallel environment."""
dist.destroy_process_group()
# --- Main Training Script ---
def main():
    rank, local_rank = setup_ddp()
    is_main_process = (rank == 0)

    config = TrainConfig()
    config.device = f'cuda:{local_rank}'

    if is_main_process:
        model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
        checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"

        # --- 0. Save Configuration ---
        config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
        # vars(config) only contains instance attributes (just `device` here), so
        # start from the class-level defaults and then apply instance overrides.
        config_dict = {k: v for k, v in vars(TrainConfig).items() if not k.startswith('__')}
        config_dict.update(vars(config))
        with open(config_filename, 'w') as f:
            json.dump(config_dict, f, indent=4)
        print(f"Configuration saved to {config_filename}")
    # --- 1. Data Loading ---
    if is_main_process:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")

    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main_process:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)
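    # DistributedSampler gives each rank a disjoint shard of the data; the train
    # sampler is reshuffled every epoch via set_epoch() in the training loop.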
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, sampler=val_sampler, num_workers=4, pin_memory=True)
    # --- 2. Model, Optimizer, and Loss Initialization ---
    if is_main_process:
        print(f"Initializing model on {config.device}...")

    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)
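    # DDP wraps the local model so gradients are averaged across ranks during
    # backward(); the underlying module remains accessible as model.module.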
    model = DDP(model, device_ids=[local_rank])

    if is_main_process:
        print(f"Model initialized with {model.module.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if is_main_process:
        print("Starting training...")
    for epoch in range(config.max_epoch):
        train_sampler.set_epoch(epoch)  # Important for shuffling

        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]", disable=not is_main_process)
        for event_seq, time_seq in pbar:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
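            # Shift by one position: each step predicts the next event token and
            # the waiting time (time delta) until that event.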
            input_events, input_times = event_seq[:, :-1], time_seq[:, :-1]
            target_events, target_wait_times = event_seq[:, 1:], (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1

            if is_main_process:
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
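        # Per-rank epoch averages; only rank 0 writes these to the losses file.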
        avg_train_loss_ce = train_loss_ce_acc / train_steps
        avg_train_loss_surv = train_loss_surv_acc / train_steps
        train_losses_ce.append(avg_train_loss_ce)
        train_losses_surv.append(avg_train_loss_surv)
        train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)

        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]", disable=not is_main_process)
            for event_seq, time_seq in pbar_val:
                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

                input_events, input_times = event_seq[:, :-1], time_seq[:, :-1]
                target_events, target_wait_times = event_seq[:, 1:], (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1

                if is_main_process:
                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
        avg_val_loss_ce = val_loss_ce_acc / val_steps
        avg_val_loss_surv = val_loss_surv_acc / val_steps

        # Average the validation losses across ranks so every process sees the
        # same value and takes the same early-stopping branch (otherwise only
        # rank 0 would break out of the loop, leaving the other ranks waiting).
        val_metrics = torch.tensor([avg_val_loss_ce, avg_val_loss_surv], device=config.device)
        dist.all_reduce(val_metrics, op=dist.ReduceOp.SUM)
        val_metrics /= dist.get_world_size()
        avg_val_loss_ce, avg_val_loss_surv = val_metrics[0].item(), val_metrics[1].item()
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        val_losses_ce.append(avg_val_loss_ce)
        val_losses_surv.append(avg_val_loss_surv)
        val_losses_total.append(total_val_loss)

        if is_main_process:
            print(f"Epoch {epoch+1} Summary: \n"
                  f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f"  Learning Rate: {lr:.6f}")

        # Track the best model and early-stopping patience on every rank; only
        # rank 0 writes the checkpoint to disk.
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            patience_counter = 0
            if is_main_process:
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                torch.save(model.module.state_dict(), checkpoint_filename)
        else:
            if epoch >= config.warmup_epochs:
                patience_counter += 1
            if is_main_process:
                print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")

        if patience_counter >= config.early_stopping_patience:
            if is_main_process:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break
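    # Post-training reporting runs on rank 0 only: reload the best checkpoint,
    # save the final model, dump the per-epoch losses, and plot the curves.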
    if is_main_process:
        if best_val_loss != float('inf'):
            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
            # Load the best weights into the module before saving the final model
            model.module.load_state_dict(torch.load(checkpoint_filename))
            print(f"Saving final best model to {model_filename}")
            torch.save(model.module.state_dict(), model_filename)
        else:
            print("\nTraining finished. No best model to save as validation loss never improved.")

        losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
        with open(losses_filename, 'w') as f:
            f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
            for i in range(len(train_losses_total)):
                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
        print(f"\nLosses saved to {losses_filename}")

        num_epochs = len(train_losses_total)
        epochs = range(1, num_epochs + 1)

        plt.figure(figsize=(18, 5))

        plt.subplot(1, 3, 1)
        plt.plot(epochs, train_losses_ce, label='Train CE')
        plt.plot(epochs, val_losses_ce, label='Val CE')
        plt.title('Cross-Entropy Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.subplot(1, 3, 2)
        plt.plot(epochs, train_losses_surv, label='Train Survival')
        plt.plot(epochs, val_losses_surv, label='Val Survival')
        plt.title('Survival Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.subplot(1, 3, 3)
        plt.plot(epochs, train_losses_total, label='Train Total')
        plt.plot(epochs, val_losses_total, label='Val Total')
        plt.title('Total Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig('loss_curves.png')
        print("\nLoss curves saved to loss_curves.png")

    cleanup_ddp()
if __name__ == '__main__':
    main()