feat: Add iteration-based training scripts (single and multi-GPU)
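
train_iter.py trains TimeAwareGPT2 with CombinedLoss (cross-entropy + survival)
for a fixed number of iterations on a single device, holding the learning rate
constant during warmup and then decaying it with a cosine schedule.
train_iter_multigpu.py is the DistributedDataParallel variant of the same loop
(NCCL backend, per-GPU batch size, rank-0 logging, checkpointing, and loss plotting).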

2025-10-18 10:05:37 +08:00
parent a832a45c62
commit 3390bc025e
2 changed files with 465 additions and 0 deletions

train_iter.py Normal file

@@ -0,0 +1,218 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 48 # Sequence length
# Model parameters
n_embd = 120
n_layer = 12
n_head = 12
pdrop = 0.1
token_pdrop = 0.1
# Training parameters
max_iter = 200000
batch_size = 128
lr_initial = 6e-4
lr_final = 6e-5
weight_decay = 2e-1
warmup_iter = 1000
# Loss parameters
# 0 = padding, 1 = "no event"
ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs
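    # These IDs are presumably masked out of the loss terms by CombinedLoss; note that
    # token 1 ("no event") is not in the list and therefore still contributes to the loss.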
# System parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Main Training Script ---
def main():
config = TrainConfig()
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.pt"
# --- 0. Save Configuration ---
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.json"
    # vars(config) is empty here because the settings are class attributes, so read them via the class
    config_dict = {k: getattr(config, k) for k in vars(TrainConfig) if not k.startswith('__')}
with open(config_filename, 'w') as f:
json.dump(config_dict, f, indent=4)
print(f"Configuration saved to {config_filename}")
# --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
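    # Each record is assumed to be a (patient_id, timestamp, event_token) triple of uint32;
    # the third column holds the event token, from which the vocabulary size is inferred below.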
# Infer vocab_size from the data (max label + 1)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
print(f"Inferred vocabulary size: {vocab_size}")
train_dataset = PatientEventDataset(train_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
train_iter_loader = iter(itertools.cycle(train_loader))
# --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(config.device)
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))
# --- 3. Training Loop ---
# Lists to store losses
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
print("Starting training...")
pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training")
for iter_num in pbar:
# --- Learning Rate Scheduling ---
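        # Hold the learning rate at lr_initial for the first warmup_iter steps, then decay
        # it to lr_final with a cosine schedule over the remaining iterations.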
if iter_num < config.warmup_iter:
lr = config.lr_initial
else:
progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# --- Training Step ---
model.train()
event_seq, time_seq = next(train_iter_loader)
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
# Prepare inputs and targets
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
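        # Next-event prediction: targets are the sequence shifted by one position, and the
        # regression target is the waiting time between consecutive events. CombinedLoss
        # returns a cross-entropy term over the logits and a survival term over these wait times.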
# Forward pass
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_losses_ce.append(loss_ce.item())
train_losses_surv.append(loss_survival.item())
train_losses_total.append(loss.item())
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
print("\nTraining finished.")
# --- 4. Final Validation ---
print("Running final validation...")
model.eval()
val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
val_steps = 0
with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc="Final Validation")
for event_seq, time_seq in pbar_val:
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_loss_ce_acc += loss_ce.item()
val_loss_surv_acc += loss_survival.item()
val_steps += 1
pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
avg_val_loss_ce = val_loss_ce_acc / val_steps
avg_val_loss_surv = val_loss_surv_acc / val_steps
total_val_loss = avg_val_loss_ce + avg_val_loss_surv
print(f"Final Validation Summary: \n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")
# --- 5. Save Model ---
print(f"Saving final model to {model_filename}")
torch.save(model.state_dict(), model_filename)
# --- 6. Save and Plot Losses ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.txt"
with open(losses_filename, 'w') as f:
f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# Plot and Save Loss Curves
iterations = range(1, len(train_losses_total) + 1)
plt.figure(figsize=(18, 5))
# Plot CE Loss
plt.subplot(1, 3, 1)
plt.plot(iterations, train_losses_ce, label='Train CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Survival Loss
plt.subplot(1, 3, 2)
plt.plot(iterations, train_losses_surv, label='Train Survival')
plt.title('Survival Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Total Loss
plt.subplot(1, 3, 3)
plt.plot(iterations, train_losses_total, label='Train Total')
plt.title('Total Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves_iter.png')
print("\nLoss curves saved to loss_curves_iter.png")
if __name__ == '__main__':
main()
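# Typical usage (single GPU or CPU): python train_iter.py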

train_iter_multigpu.py Normal file

@@ -0,0 +1,247 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 48 # Sequence length
# Model parameters
n_embd = 120
n_layer = 12
n_head = 12
pdrop = 0.1
token_pdrop = 0.1
# Training parameters
max_iter = 200000
batch_size = 128 # Per GPU
lr_initial = 6e-4
lr_final = 6e-5
weight_decay = 2e-1
warmup_iter = 1000
# Loss parameters
# 0 = padding, 1 = "no event"
ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs
# System parameters
device = 'cuda'
# --- DDP Setup ---
def setup_ddp():
"""Initializes the distributed data parallel environment."""
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
local_rank = int(os.environ['LOCAL_RANK'])
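    # RANK/LOCAL_RANK/WORLD_SIZE are provided by the torchrun (or torch.distributed.launch)
    # launcher; LOCAL_RANK selects which GPU this process drives.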
torch.cuda.set_device(local_rank)
return rank, local_rank
def cleanup_ddp():
"""Cleans up the distributed data parallel environment."""
dist.destroy_process_group()
# --- Main Training Script ---
def main():
rank, local_rank = setup_ddp()
is_main_process = (rank == 0)
config = TrainConfig()
config.device = f'cuda:{local_rank}'
if is_main_process:
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.pt"
# --- 0. Save Configuration ---
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.json"
        # vars(config) would miss the class-level settings, so read them via the class
        config_dict = {k: getattr(config, k) for k in vars(TrainConfig) if not k.startswith('__')}
with open(config_filename, 'w') as f:
json.dump(config_dict, f, indent=4)
print(f"Configuration saved to {config_filename}")
# --- 1. Data Loading ---
if is_main_process:
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
if is_main_process:
print(f"Inferred vocabulary size: {vocab_size}")
train_dataset = PatientEventDataset(train_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
train_sampler = DistributedSampler(train_dataset)
val_sampler = DistributedSampler(val_dataset, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, sampler=val_sampler, num_workers=4, pin_memory=True)
train_iter_loader = iter(itertools.cycle(train_loader))
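    # Note: the DistributedSampler's epoch is never advanced (no set_epoch call), so every pass
    # through the cycled loader reuses the same per-rank shuffle order.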
# --- 2. Model, Optimizer, and Loss Initialization ---
if is_main_process:
print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(config.device)
model = DDP(model, device_ids=[local_rank])
if is_main_process:
print(f"Model initialized with {model.module.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
# --- 3. Training Loop ---
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
if is_main_process:
print("Starting training...")
pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training", disable=not is_main_process)
for iter_num in pbar:
# --- Learning Rate Scheduling ---
if iter_num < config.warmup_iter:
lr = config.lr_initial
else:
progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# --- Training Step ---
model.train()
event_seq, time_seq = next(train_iter_loader)
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
optimizer.zero_grad()
loss.backward()
optimizer.step()
if is_main_process:
train_losses_ce.append(loss_ce.item())
train_losses_surv.append(loss_survival.item())
train_losses_total.append(loss.item())
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
if is_main_process:
print("\nTraining finished.")
# --- 4. Final Validation ---
if is_main_process:
print("Running final validation...")
model.eval()
val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
val_steps = 0
with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc="Final Validation", disable=not is_main_process)
for event_seq, time_seq in pbar_val:
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_loss_ce_acc += loss_ce.item()
val_loss_surv_acc += loss_survival.item()
val_steps += 1
if is_main_process:
pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
avg_val_loss_ce = val_loss_ce_acc / val_steps
avg_val_loss_surv = val_loss_surv_acc / val_steps
total_val_loss = avg_val_loss_ce + avg_val_loss_surv
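    # Validation losses are accumulated per rank and are not all-reduced, so the summary below
    # reflects only rank 0's shard of the validation set.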
if is_main_process:
print(f"Final Validation Summary: \n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")
# --- 5. Save Model ---
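    # Saving the checkpoint and the loss curves is meaningful only on rank 0: model_filename and
    # the per-iteration loss histories are defined/populated only when is_main_process is True.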
print(f"Saving final model to {model_filename}")
torch.save(model.module.state_dict(), model_filename)
# --- 6. Save and Plot Losses ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.txt"
with open(losses_filename, 'w') as f:
f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# Plot and Save Loss Curves
iterations = range(1, len(train_losses_total) + 1)
plt.figure(figsize=(18, 5))
# Plot CE Loss
plt.subplot(1, 3, 1)
plt.plot(iterations, train_losses_ce, label='Train CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Survival Loss
plt.subplot(1, 3, 2)
plt.plot(iterations, train_losses_surv, label='Train Survival')
plt.title('Survival Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Total Loss
plt.subplot(1, 3, 3)
plt.plot(iterations, train_losses_total, label='Train Total')
plt.title('Total Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves_iter_multigpu.png')
print("\nLoss curves saved to loss_curves_iter_multigpu.png")
cleanup_ddp()
if __name__ == '__main__':
main()
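# Example launch on a single node (GPU count is illustrative):
#   torchrun --nproc_per_node=4 train_iter_multigpu.py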