From 3390bc025e4db2b24504d8678f01bc1ffdeb1713 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Sat, 18 Oct 2025 10:05:37 +0800
Subject: [PATCH] feat: Add iteration-based training scripts (single and multi-GPU)

---
 train_iter.py          | 218 ++++++++++++++++++++++++++++++++++++
 train_iter_multigpu.py | 247 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 465 insertions(+)
 create mode 100644 train_iter.py
 create mode 100644 train_iter_multigpu.py

diff --git a/train_iter.py b/train_iter.py
new file mode 100644
index 0000000..fa04ecb
--- /dev/null
+++ b/train_iter.py
@@ -0,0 +1,218 @@
+import torch
+import torch.nn as nn
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+import numpy as np
+import math
+import tqdm
+import matplotlib.pyplot as plt
+import json
+import itertools
+
+from models import TimeAwareGPT2, CombinedLoss
+from utils import PatientEventDataset
+
+# --- Configuration ---
+class TrainConfig:
+    # Data parameters
+    train_data_path = 'ukb_real_train.bin'
+    val_data_path = 'ukb_real_val.bin'
+    block_length = 48  # Sequence length
+
+    # Model parameters
+    n_embd = 120
+    n_layer = 12
+    n_head = 12
+    pdrop = 0.1
+    token_pdrop = 0.1
+
+    # Training parameters
+    max_iter = 200000
+    batch_size = 128
+    lr_initial = 6e-4
+    lr_final = 6e-5
+    weight_decay = 2e-1
+    warmup_iter = 1000
+
+    # Loss parameters
+    # 0 = padding, 1 = "no event"
+    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # Example ignored token IDs
+
+    # System parameters
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# --- Main Training Script ---
+def main():
+    config = TrainConfig()
+
+    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.pt"
+
+    # --- 0. Save Configuration ---
+    config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.json"
+    config_dict = {k: getattr(config, k) for k in vars(TrainConfig) if not k.startswith('__')}
+    with open(config_filename, 'w') as f:
+        json.dump(config_dict, f, indent=4)
+    print(f"Configuration saved to {config_filename}")
+
+    # --- 1. Data Loading ---
+    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
+    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+
+    # Infer vocab_size from the data (max label + 1)
+    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
+    print(f"Inferred vocabulary size: {vocab_size}")
+
+    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
+    val_dataset = PatientEventDataset(val_data_arr, config.block_length)
+
+    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
+    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
+    train_iter_loader = iter(itertools.cycle(train_loader))
+
+    # --- 2. Model, Optimizer, and Loss Initialization ---
+    print(f"Initializing model on {config.device}...")
+    model = TimeAwareGPT2(
+        vocab_size=vocab_size,
+        n_embd=config.n_embd,
+        n_layer=config.n_layer,
+        n_head=config.n_head,
+        pdrop=config.pdrop,
+        token_pdrop=config.token_pdrop
+    ).to(config.device)
+
+    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
+
+    loss_fn = CombinedLoss(config.ignored_token_ids)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))
+
+    # --- 3. Training Loop ---
+
+    # Lists to store losses
+    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
+
+    print("Starting training...")
+    pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training")
+    for iter_num in pbar:
+        # --- Learning Rate Scheduling ---
+        if iter_num < config.warmup_iter:
+            lr = config.lr_initial
+        else:
+            progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
+            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
+
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
+        # --- Training Step ---
+        model.train()
+
+        event_seq, time_seq = next(train_iter_loader)
+        event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
+
+        # Prepare inputs and targets
+        input_events = event_seq[:, :-1]
+        input_times = time_seq[:, :-1]
+        target_events = event_seq[:, 1:]
+        target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
+
+        # Forward pass
+        logits = model(input_events, input_times)
+        loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+        loss = loss_ce + loss_survival
+
+        # Backward pass and optimization
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        train_losses_ce.append(loss_ce.item())
+        train_losses_surv.append(loss_survival.item())
+        train_losses_total.append(loss.item())
+
+        pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
+
+    print("\nTraining finished.")
+
+    # --- 4. Final Validation ---
+    print("Running final validation...")
+    model.eval()
+    val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
+    val_steps = 0
+
+    with torch.no_grad():
+        pbar_val = tqdm.tqdm(val_loader, desc="Final Validation")
+        for event_seq, time_seq in pbar_val:
+            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
+
+            input_events = event_seq[:, :-1]
+            input_times = time_seq[:, :-1]
+            target_events = event_seq[:, 1:]
+            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
+
+            logits = model(input_events, input_times)
+            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+
+            val_loss_ce_acc += loss_ce.item()
+            val_loss_surv_acc += loss_survival.item()
+            val_steps += 1
+            pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
+
+    avg_val_loss_ce = val_loss_ce_acc / val_steps
+    avg_val_loss_surv = val_loss_surv_acc / val_steps
+    total_val_loss = avg_val_loss_ce + avg_val_loss_surv
+
+    print(f"Final Validation Summary: \n"
+          f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")
+
+    # --- 5. Save Model ---
+    print(f"Saving final model to {model_filename}")
+    torch.save(model.state_dict(), model_filename)
+
+    # --- 6. Save and Plot Losses ---
+    losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.txt"
+    with open(losses_filename, 'w') as f:
+        f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
+        for i in range(len(train_losses_total)):
+            f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
+    print(f"\nLosses saved to {losses_filename}")
+
+    # Plot and Save Loss Curves
+    iterations = range(1, len(train_losses_total) + 1)
+
+    plt.figure(figsize=(18, 5))
+
+    # Plot CE Loss
+    plt.subplot(1, 3, 1)
+    plt.plot(iterations, train_losses_ce, label='Train CE')
+    plt.title('Cross-Entropy Loss')
+    plt.xlabel('Iterations')
+    plt.ylabel('Loss')
+    plt.legend()
+    plt.grid(True)
+
+    # Plot Survival Loss
+    plt.subplot(1, 3, 2)
+    plt.plot(iterations, train_losses_surv, label='Train Survival')
+    plt.title('Survival Loss')
+    plt.xlabel('Iterations')
+    plt.ylabel('Loss')
+    plt.legend()
+    plt.grid(True)
+
+    # Plot Total Loss
+    plt.subplot(1, 3, 3)
+    plt.plot(iterations, train_losses_total, label='Train Total')
+    plt.title('Total Loss')
+    plt.xlabel('Iterations')
+    plt.ylabel('Loss')
+    plt.legend()
+    plt.grid(True)
+
+    plt.tight_layout()
+    plt.savefig('loss_curves_iter.png')
+    print("\nLoss curves saved to loss_curves_iter.png")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/train_iter_multigpu.py b/train_iter_multigpu.py
new file mode 100644
index 0000000..8c9119b
--- /dev/null
+++ b/train_iter_multigpu.py
@@ -0,0 +1,247 @@
+import torch
+import torch.nn as nn
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+import numpy as np
+import math
+import tqdm
+import matplotlib.pyplot as plt
+import json
+import itertools
+import os
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+
+from models import TimeAwareGPT2, CombinedLoss
+from utils import PatientEventDataset
+
+# --- Configuration ---
+class TrainConfig:
+    # Data parameters
+    train_data_path = 'ukb_real_train.bin'
+    val_data_path = 'ukb_real_val.bin'
+    block_length = 48  # Sequence length
+
+    # Model parameters
+    n_embd = 120
+    n_layer = 12
+    n_head = 12
+    pdrop = 0.1
+    token_pdrop = 0.1
+
+    # Training parameters
+    max_iter = 200000
+    batch_size = 128  # Per GPU
+    lr_initial = 6e-4
+    lr_final = 6e-5
+    weight_decay = 2e-1
+    warmup_iter = 1000
+
+    # Loss parameters
+    # 0 = padding, 1 = "no event"
+    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # Example ignored token IDs
+
+    # System parameters
+    device = 'cuda'
+
+# --- DDP Setup ---
+def setup_ddp():
+    """Initializes the distributed data parallel environment."""
+    dist.init_process_group(backend='nccl')
+    rank = dist.get_rank()
+    local_rank = int(os.environ['LOCAL_RANK'])
+    torch.cuda.set_device(local_rank)
+    return rank, local_rank
+
+def cleanup_ddp():
+    """Cleans up the distributed data parallel environment."""
+    dist.destroy_process_group()
+
+# --- Main Training Script ---
+def main():
+    rank, local_rank = setup_ddp()
+    is_main_process = (rank == 0)
+
+    config = TrainConfig()
+    config.device = f'cuda:{local_rank}'
+
+    if is_main_process:
+        model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.pt"
+        # --- 0. Save Configuration ---
+        config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.json"
+        config_dict = {k: getattr(config, k) for k in vars(TrainConfig) if not k.startswith('__')}
+        with open(config_filename, 'w') as f:
+            json.dump(config_dict, f, indent=4)
+        print(f"Configuration saved to {config_filename}")
+
+    # --- 1. Data Loading ---
+    if is_main_process:
+        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
+    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+
+    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
+    if is_main_process:
+        print(f"Inferred vocabulary size: {vocab_size}")
+
+    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
+    val_dataset = PatientEventDataset(val_data_arr, config.block_length)
+
+    train_sampler = DistributedSampler(train_dataset)
+    val_sampler = DistributedSampler(val_dataset, shuffle=False)
+
+    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
+    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, sampler=val_sampler, num_workers=4, pin_memory=True)
+    train_iter_loader = iter(itertools.cycle(train_loader))
+
+    # --- 2. Model, Optimizer, and Loss Initialization ---
+    if is_main_process:
+        print(f"Initializing model on {config.device}...")
+    model = TimeAwareGPT2(
+        vocab_size=vocab_size,
+        n_embd=config.n_embd,
+        n_layer=config.n_layer,
+        n_head=config.n_head,
+        pdrop=config.pdrop,
+        token_pdrop=config.token_pdrop
+    ).to(config.device)
+    model = DDP(model, device_ids=[local_rank])
+
+    if is_main_process:
+        print(f"Model initialized with {model.module.get_num_params():.2f}M trainable parameters.")
+
+    loss_fn = CombinedLoss(config.ignored_token_ids)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
+
+    # --- 3. Training Loop ---
+    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
+
+    if is_main_process:
+        print("Starting training...")
+    pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training", disable=not is_main_process)
+    for iter_num in pbar:
+        # --- Learning Rate Scheduling ---
+        if iter_num < config.warmup_iter:
+            lr = config.lr_initial
+        else:
+            progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
+            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
+
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
+        # --- Training Step ---
+        model.train()
+
+        event_seq, time_seq = next(train_iter_loader)
+        event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
+
+        input_events = event_seq[:, :-1]
+        input_times = time_seq[:, :-1]
+        target_events = event_seq[:, 1:]
+        target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
+
+        logits = model(input_events, input_times)
+        loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+        loss = loss_ce + loss_survival
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        if is_main_process:
+            train_losses_ce.append(loss_ce.item())
+            train_losses_surv.append(loss_survival.item())
+            train_losses_total.append(loss.item())
+            pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
+
+    if is_main_process:
+        print("\nTraining finished.")
+
+    # --- 4. Final Validation ---
+    if is_main_process:
+        print("Running final validation...")
+    model.eval()
+    val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
+    val_steps = 0
+
+    with torch.no_grad():
+        pbar_val = tqdm.tqdm(val_loader, desc="Final Validation", disable=not is_main_process)
+        for event_seq, time_seq in pbar_val:
+            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
+
+            input_events = event_seq[:, :-1]
+            input_times = time_seq[:, :-1]
+            target_events = event_seq[:, 1:]
+            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
+
+            logits = model(input_events, input_times)
+            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+
+            val_loss_ce_acc += loss_ce.item()
+            val_loss_surv_acc += loss_survival.item()
+            val_steps += 1
+            if is_main_process:
+                pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
+
+    avg_val_loss_ce = val_loss_ce_acc / val_steps
+    avg_val_loss_surv = val_loss_surv_acc / val_steps
+    total_val_loss = avg_val_loss_ce + avg_val_loss_surv
+
+    if is_main_process:
+        print(f"Final Validation Summary: \n"
+              f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")
+
+        # --- 5. Save Model ---
+        print(f"Saving final model to {model_filename}")
+        torch.save(model.module.state_dict(), model_filename)
+
+        # --- 6. Save and Plot Losses ---
+        losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.txt"
+        with open(losses_filename, 'w') as f:
+            f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
+            for i in range(len(train_losses_total)):
+                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
+        print(f"\nLosses saved to {losses_filename}")
+
+        # Plot and Save Loss Curves
+        iterations = range(1, len(train_losses_total) + 1)
+
+        plt.figure(figsize=(18, 5))
+
+        # Plot CE Loss
+        plt.subplot(1, 3, 1)
+        plt.plot(iterations, train_losses_ce, label='Train CE')
+        plt.title('Cross-Entropy Loss')
+        plt.xlabel('Iterations')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        # Plot Survival Loss
+        plt.subplot(1, 3, 2)
+        plt.plot(iterations, train_losses_surv, label='Train Survival')
+        plt.title('Survival Loss')
+        plt.xlabel('Iterations')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        # Plot Total Loss
+        plt.subplot(1, 3, 3)
+        plt.plot(iterations, train_losses_total, label='Train Total')
+        plt.title('Total Loss')
+        plt.xlabel('Iterations')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        plt.tight_layout()
+        plt.savefig('loss_curves_iter_multigpu.png')
+        print("\nLoss curves saved to loss_curves_iter_multigpu.png")
+
+    cleanup_ddp()
+
+if __name__ == '__main__':
+    main()
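
Launch sketch, assuming the ukb_real_*.bin files and the models/utils modules imported above are available in the working directory; the GPU count is only an example. train_iter.py runs as a plain Python script, while train_iter_multigpu.py expects to be launched by torchrun, which supplies the LOCAL_RANK environment variable read by setup_ddp():

    python train_iter.py
    torchrun --standalone --nproc_per_node=4 train_iter_multigpu.py

Since batch_size = 128 in the multi-GPU config is per process, the effective global batch size scales with the number of GPUs (e.g. 512 with 4 GPUs).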