From d760c45baf4621de2040965e3a6a4baf734f53b4 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Fri, 17 Oct 2025 14:09:34 +0800 Subject: [PATCH] feat: Add multi-GPU training and improve config/ignore Add train_multigpu.py for distributed data parallel training. Update train.py to save the training configuration to a JSON file. Generalize .gitignore to exclude all *.pt checkpoint files. Delete obsolete train_dpp.py file. --- .gitignore | 2 +- train.py | 8 + train_dpp.py | 400 ---------------------------------------------- train_multigpu.py | 273 +++++++++++++++++++++++++++++++ 4 files changed, 282 insertions(+), 401 deletions(-) delete mode 100644 train_dpp.py create mode 100644 train_multigpu.py diff --git a/.gitignore b/.gitignore index 13cceec..2dc07a7 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ __pycache__/ # Model checkpoints -best_model_checkpoint.pt +*.pt # Large data files ukb_delphi.txt diff --git a/train.py b/train.py index 53a9837..edd59e2 100644 --- a/train.py +++ b/train.py @@ -6,6 +6,7 @@ import numpy as np import math import tqdm import matplotlib.pyplot as plt +import json from models import TimeAwareGPT2, CombinedLoss from utils import PatientEventDataset @@ -47,6 +48,13 @@ def main(): model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt" checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt" + # --- 0. Save Configuration --- + config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json" + config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')} + with open(config_filename, 'w') as f: + json.dump(config_dict, f, indent=4) + print(f"Configuration saved to {config_filename}") + # --- 1. 
Data Loading --- print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) diff --git a/train_dpp.py b/train_dpp.py deleted file mode 100644 index bb43e2f..0000000 --- a/train_dpp.py +++ /dev/null @@ -1,400 +0,0 @@ -# train.py (DDP-ready) -import os -import math -import argparse -import numpy as np -import tqdm -import matplotlib.pyplot as plt - -import torch -import torch.nn as nn -import torch.distributed as dist -from torch.optim import Adam -from torch.utils.data import DataLoader, DistributedSampler - -from models import TimeAwareGPT2, CombinedLoss -from utils import PatientEventDataset - - -# --- Configuration --- -class TrainConfig: - # Data parameters - train_data_path = 'ukb_real_train.bin' - val_data_path = 'ukb_real_val.bin' - block_length = 24 # Sequence length - - # Model parameters - n_embd = 256 - n_layer = 8 - n_head = 8 - pdrop = 0.1 - token_pdrop = 0.1 - - # Training parameters - max_epoch = 200 - batch_size = 128 - lr_initial = 6e-4 - lr_final = 6e-5 - warmup_epochs = 10 - early_stopping_patience = 5 - - # Loss parameters - # 0 = padding, 1 = "no event" - ignored_token_ids = [0, 1] - - # System parameters (device 将在 main() 内按 local_rank 动态设置) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - -def setup_distributed(backend: str = "nccl"): - """ - 如果由 torchrun 启动且 WORLD_SIZE>1,则初始化分布式。 - 返回 (is_distributed, world_size, rank, local_rank) - """ - world_size = int(os.environ.get("WORLD_SIZE", "1")) - is_distributed = world_size > 1 - if is_distributed: - if not dist.is_initialized(): - dist.init_process_group(backend=backend, init_method="env://") - rank = dist.get_rank() - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - else: - rank = 0 - local_rank = 0 - return is_distributed, world_size, rank, local_rank - - -def cleanup_distributed(): - if dist.is_available() and dist.is_initialized(): - dist.destroy_process_group() - - -def all_reduce_mean(value: float, device, world_size: int): - """ - value 是 Python float(本进程的和/均值),返回所有进程平均后的 float。 - """ - tensor = torch.tensor([value], dtype=torch.float32, device=device) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - tensor /= world_size - return float(tensor.item()) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo", "mpi"]) - parser.add_argument("--seed", type=int, default=42) - args = parser.parse_args() - - # 分布式初始化 - is_dist, world_size, rank, local_rank = setup_distributed(args.backend) - - # 基本环境 - torch.manual_seed(args.seed + rank) - np.random.seed(args.seed + rank) - torch.backends.cudnn.benchmark = True - - config = TrainConfig() - device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") - config.device = device - - is_main = (rank == 0) - - # --- 1. 
Data Loading --- - if is_main: - print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") - train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) - val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) - - # Infer vocab_size from the data (max label + 1) - vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1 - if is_main: - print(f"Inferred vocabulary size: {vocab_size}") - - train_dataset = PatientEventDataset(train_data_arr, config.block_length) - val_dataset = PatientEventDataset(val_data_arr, config.block_length) - - # 分布式采样器 - if is_dist: - train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False) - val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False) - else: - train_sampler = None - val_sampler = None - - train_loader = DataLoader( - train_dataset, - batch_size=config.batch_size, - shuffle=(train_sampler is None), - sampler=train_sampler, - num_workers=4, - pin_memory=True, - drop_last=False, - persistent_workers=True if 4 > 0 else False, - ) - val_loader = DataLoader( - val_dataset, - batch_size=config.batch_size, - shuffle=False, - sampler=val_sampler, - num_workers=4, - pin_memory=True, - drop_last=False, - persistent_workers=True if 4 > 0 else False, - ) - - # --- 2. Model, Optimizer, and Loss Initialization --- - if is_main: - print(f"Initializing model on {config.device}...") - model = TimeAwareGPT2( - vocab_size=vocab_size, - n_embd=config.n_embd, - n_layer=config.n_layer, - n_head=config.n_head, - pdrop=config.pdrop, - token_pdrop=config.token_pdrop - ).to(device) - - if is_main and hasattr(model, "get_num_params"): - print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.") - - loss_fn = CombinedLoss(config.ignored_token_ids) - optimizer = Adam(model.parameters(), lr=config.lr_initial) - - # DDP 包装 - if is_dist: - model = nn.parallel.DistributedDataParallel( - model, - device_ids=[local_rank] if device.type == "cuda" else None, - output_device=local_rank if device.type == "cuda" else None, - find_unused_parameters=False, - ) - - # --- 3. 
Training Loop --- - best_val_loss = float('inf') - patience_counter = 0 - - # 只在主进程收集与画图 - train_losses_ce, train_losses_surv, train_losses_total = [], [], [] - val_losses_ce, val_losses_surv, val_losses_total = [], [], [] - - if is_main: - print("Starting training...") - - stop_training = False - - for epoch in range(config.max_epoch): - # 设置 epoch 给分布式采样器,确保跨 epoch shuffle - if is_dist: - train_sampler.set_epoch(epoch) - - # --- Learning Rate Scheduling --- - if epoch < config.warmup_epochs: - lr = config.lr_initial - else: - progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs) - lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress)) - for param_group in optimizer.param_groups: - param_group['lr'] = lr - - # --- Training Phase --- - if is_main: - pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]") - else: - pbar = train_loader # 非主进程禁用 tqdm - - model.train() - train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0 - train_steps = 0 - - for batch in pbar: - event_seq, time_seq = batch - event_seq = event_seq.to(device, non_blocking=True) - time_seq = time_seq.to(device, non_blocking=True) - - # Prepare inputs and targets - input_events = event_seq[:, :-1] - input_times = time_seq[:, :-1] - target_events = event_seq[:, 1:] - target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float() - - # Forward pass - logits = model(input_events, input_times) - loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) - loss = loss_ce + loss_survival - - # Backward pass and optimization - optimizer.zero_grad(set_to_none=True) - loss.backward() - optimizer.step() - - train_loss_ce_acc += float(loss_ce.item()) - train_loss_surv_acc += float(loss_survival.item()) - train_steps += 1 - - if is_main and isinstance(pbar, tqdm.tqdm): - pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'}) - - # 进程内均值 - avg_train_loss_ce_local = train_loss_ce_acc / max(train_steps, 1) - avg_train_loss_surv_local = train_loss_surv_acc / max(train_steps, 1) - - # 所有进程平均 - if is_dist: - avg_train_loss_ce = all_reduce_mean(avg_train_loss_ce_local, device, world_size) - avg_train_loss_surv = all_reduce_mean(avg_train_loss_surv_local, device, world_size) - else: - avg_train_loss_ce = avg_train_loss_ce_local - avg_train_loss_surv = avg_train_loss_surv_local - - if is_main: - train_losses_ce.append(avg_train_loss_ce) - train_losses_surv.append(avg_train_loss_surv) - train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv) - - # --- Validation Phase --- - if is_main: - pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]") - else: - pbar_val = val_loader - - model.eval() - val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0 - val_steps = 0 - - with torch.no_grad(): - for batch in pbar_val: - event_seq, time_seq = batch - event_seq = event_seq.to(device, non_blocking=True) - time_seq = time_seq.to(device, non_blocking=True) - - input_events = event_seq[:, :-1] - input_times = time_seq[:, :-1] - target_events = event_seq[:, 1:] - target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float() - - logits = model(input_events, input_times) - loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) - - val_loss_ce_acc += float(loss_ce.item()) - val_loss_surv_acc += float(loss_survival.item()) - val_steps += 1 - - if is_main and isinstance(pbar_val, tqdm.tqdm): - pbar_val.set_postfix({'loss_ce': 
f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'}) - - avg_val_loss_ce_local = val_loss_ce_acc / max(val_steps, 1) - avg_val_loss_surv_local = val_loss_surv_acc / max(val_steps, 1) - - if is_dist: - avg_val_loss_ce = all_reduce_mean(avg_val_loss_ce_local, device, world_size) - avg_val_loss_surv = all_reduce_mean(avg_val_loss_surv_local, device, world_size) - else: - avg_val_loss_ce = avg_val_loss_ce_local - avg_val_loss_surv = avg_val_loss_surv_local - - total_val_loss = avg_val_loss_ce + avg_val_loss_surv - - # 主进程打印与记录 - if is_main: - print(f"Epoch {epoch+1} Summary: \n" - f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n" - f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n" - f" Learning Rate: {lr:.6f}") - val_losses_ce.append(avg_val_loss_ce) - val_losses_surv.append(avg_val_loss_surv) - val_losses_total.append(total_val_loss) - - # --- Early Stopping Check (基于聚合后的 total_val_loss) --- - improved = False - if is_main: - if total_val_loss < best_val_loss: - best_val_loss = total_val_loss - patience_counter = 0 - improved = True - print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...") - # DDP: 保存 module.state_dict() - state_dict = model.module.state_dict() if isinstance(model, nn.parallel.DistributedDataParallel) else model.state_dict() - torch.save(state_dict, 'best_model_checkpoint.pt') - else: - if epoch >= config.warmup_epochs: - patience_counter += 1 - print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}") - stop_training = patience_counter >= config.early_stopping_patience - - # 把 improved/stop 广播到所有进程,确保一致退出 - if is_dist: - flag_tensor = torch.tensor([1 if stop_training else 0], device=device, dtype=torch.int32) - dist.broadcast(flag_tensor, src=0) - stop_training = bool(int(flag_tensor.item())) - - if stop_training: - if is_main: - print("\nEarly stopping triggered due to no improvement in validation loss.") - break - - # --- Save Best Model at the End (只主进程) --- - if is_main: - if best_val_loss != float('inf'): - print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") - # 为了易用,这里在主进程上重新构建单卡模型加载权重再保存 - model_single = TimeAwareGPT2( - vocab_size=vocab_size, - n_embd=config.n_embd, - n_layer=config.n_layer, - n_head=config.n_head, - pdrop=config.pdrop, - token_pdrop=config.token_pdrop - ).to('cpu') - model_single.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu')) - print("Saving final best model to best_model.pt") - torch.save(model_single.state_dict(), 'best_model.pt') - else: - print("\nTraining finished. 
No best model to save as validation loss never improved.") - - # --- Plot and Save Loss Curves --- - num_epochs = len(train_losses_total) - if num_epochs > 0: - epochs = range(1, num_epochs + 1) - plt.figure(figsize=(18, 5)) - - # Plot CE Loss - plt.subplot(1, 3, 1) - plt.plot(epochs, train_losses_ce, label='Train CE') - plt.plot(epochs, val_losses_ce, label='Val CE') - plt.title('Cross-Entropy Loss') - plt.xlabel('Epochs') - plt.ylabel('Loss') - plt.legend() - plt.grid(True) - - # Plot Survival Loss - plt.subplot(1, 3, 2) - plt.plot(epochs, train_losses_surv, label='Train Survival') - plt.plot(epochs, val_losses_surv, label='Val Survival') - plt.title('Survival Loss') - plt.xlabel('Epochs') - plt.ylabel('Loss') - plt.legend() - plt.grid(True) - - # Plot Total Loss - plt.subplot(1, 3, 3) - plt.plot(epochs, train_losses_total, label='Train Total') - plt.plot(epochs, val_losses_total, label='Val Total') - plt.title('Total Loss') - plt.xlabel('Epochs') - plt.ylabel('Loss') - plt.legend() - plt.grid(True) - - plt.tight_layout() - plt.savefig('loss_curves.png') - print("\nLoss curves saved to loss_curves.png") - - # 清理分布式 - cleanup_distributed() - - -if __name__ == '__main__': - main() diff --git a/train_multigpu.py b/train_multigpu.py new file mode 100644 index 0000000..0bfaf7b --- /dev/null +++ b/train_multigpu.py @@ -0,0 +1,273 @@ +import torch +import torch.nn as nn +from torch.optim import AdamW +from torch.utils.data import DataLoader +import numpy as np +import math +import tqdm +import matplotlib.pyplot as plt +import json +import os +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler + +from models import TimeAwareGPT2, CombinedLoss +from utils import PatientEventDataset + +# --- Configuration --- +class TrainConfig: + # Data parameters + train_data_path = 'ukb_real_train.bin' + val_data_path = 'ukb_real_val.bin' + block_length = 48 # Sequence length + + # Model parameters + n_embd = 120 + n_layer = 12 + n_head = 12 + pdrop = 0.1 + token_pdrop = 0.1 + + # Training parameters + max_epoch = 200 + batch_size = 128 + lr_initial = 6e-4 + lr_final = 6e-5 + weight_decay = 2e-1 + warmup_epochs = 10 + early_stopping_patience = 10 + + # Loss parameters + ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + + # System parameters (will be updated by DDP setup) + device = 'cuda' + +# --- DDP Setup --- +def setup_ddp(): + """Initializes the distributed data parallel environment.""" + dist.init_process_group(backend='nccl') + rank = dist.get_rank() + local_rank = int(os.environ['LOCAL_RANK']) + torch.cuda.set_device(local_rank) + return rank, local_rank + +def cleanup_ddp(): + """Cleans up the distributed data parallel environment.""" + dist.destroy_process_group() + +# --- Main Training Script --- +def main(): + rank, local_rank = setup_ddp() + is_main_process = (rank == 0) + + config = TrainConfig() + config.device = f'cuda:{local_rank}' + + if is_main_process: + model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt" + checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt" + + # --- 0. 
Save Configuration ---
+        config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
+        config_dict = {k: getattr(config, k) for k in dir(config) if not k.startswith('_')}  # include class-level defaults; vars(config) only sees instance attributes
+        with open(config_filename, 'w') as f:
+            json.dump(config_dict, f, indent=4)
+        print(f"Configuration saved to {config_filename}")
+
+    # --- 1. Data Loading ---
+    if is_main_process:
+        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
+    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+
+    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
+    if is_main_process:
+        print(f"Inferred vocabulary size: {vocab_size}")
+
+    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
+    val_dataset = PatientEventDataset(val_data_arr, config.block_length)
+
+    train_sampler = DistributedSampler(train_dataset)
+    val_sampler = DistributedSampler(val_dataset, shuffle=False)
+
+    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
+    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, sampler=val_sampler, num_workers=4, pin_memory=True)
+
+    # --- 2. Model, Optimizer, and Loss Initialization ---
+    if is_main_process:
+        print(f"Initializing model on {config.device}...")
+    model = TimeAwareGPT2(
+        vocab_size=vocab_size,
+        n_embd=config.n_embd,
+        n_layer=config.n_layer,
+        n_head=config.n_head,
+        pdrop=config.pdrop,
+        token_pdrop=config.token_pdrop
+    ).to(config.device)
+    model = DDP(model, device_ids=[local_rank])
+
+    if is_main_process:
+        print(f"Model initialized with {model.module.get_num_params():.2f}M trainable parameters.")
+
+    loss_fn = CombinedLoss(config.ignored_token_ids)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
+
+    # --- 3.
Training Loop --- + best_val_loss = float('inf') + patience_counter = 0 + + train_losses_ce, train_losses_surv, train_losses_total = [], [], [] + val_losses_ce, val_losses_surv, val_losses_total = [], [], [] + + if is_main_process: + print("Starting training...") + for epoch in range(config.max_epoch): + train_sampler.set_epoch(epoch) # Important for shuffling + + if epoch < config.warmup_epochs: + lr = config.lr_initial + else: + progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs) + lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress)) + + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + model.train() + train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0 + train_steps = 0 + + pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]", disable=not is_main_process) + for event_seq, time_seq in pbar: + event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device) + + input_events, input_times = event_seq[:, :-1], time_seq[:, :-1] + target_events, target_wait_times = event_seq[:, 1:], (time_seq[:, 1:] - time_seq[:, :-1]).float() + + logits = model(input_events, input_times) + loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) + loss = loss_ce + loss_survival + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_loss_ce_acc += loss_ce.item() + train_loss_surv_acc += loss_survival.item() + train_steps += 1 + if is_main_process: + pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'}) + + avg_train_loss_ce = train_loss_ce_acc / train_steps + avg_train_loss_surv = train_loss_surv_acc / train_steps + train_losses_ce.append(avg_train_loss_ce) + train_losses_surv.append(avg_train_loss_surv) + train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv) + + model.eval() + val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0 + val_steps = 0 + + with torch.no_grad(): + pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]", disable=not is_main_process) + for event_seq, time_seq in pbar_val: + event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device) + + input_events, input_times = event_seq[:, :-1], time_seq[:, :-1] + target_events, target_wait_times = event_seq[:, 1:], (time_seq[:, 1:] - time_seq[:, :-1]).float() + + logits = model(input_events, input_times) + loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) + + val_loss_ce_acc += loss_ce.item() + val_loss_surv_acc += loss_survival.item() + val_steps += 1 + if is_main_process: + pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'}) + + avg_val_loss_ce = val_loss_ce_acc / val_steps + avg_val_loss_surv = val_loss_surv_acc / val_steps + total_val_loss = avg_val_loss_ce + avg_val_loss_surv + val_losses_ce.append(avg_val_loss_ce) + val_losses_surv.append(avg_val_loss_surv) + val_losses_total.append(total_val_loss) + + if is_main_process: + print(f"Epoch {epoch+1} Summary: \n" + f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n" + f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n" + f" Learning Rate: {lr:.6f}") + + if total_val_loss < best_val_loss: + best_val_loss = total_val_loss + patience_counter = 0 + print(f"Validation loss improved to {best_val_loss:.4f}. 
Saving checkpoint...") + torch.save(model.module.state_dict(), checkpoint_filename) + else: + if epoch >= config.warmup_epochs: + patience_counter += 1 + print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}") + + if patience_counter >= config.early_stopping_patience: + if is_main_process: + print("\nEarly stopping triggered due to no improvement in validation loss.") + break + + if is_main_process: + if best_val_loss != float('inf'): + print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") + # Load the best weights into the module before saving the final model + model.module.load_state_dict(torch.load(checkpoint_filename)) + print(f"Saving final best model to {model_filename}") + torch.save(model.module.state_dict(), model_filename) + else: + print("\nTraining finished. No best model to save as validation loss never improved.") + + losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt" + with open(losses_filename, 'w') as f: + f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n") + for i in range(len(train_losses_total)): + f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n") + print(f"\nLosses saved to {losses_filename}") + + num_epochs = len(train_losses_total) + epochs = range(1, num_epochs + 1) + plt.figure(figsize=(18, 5)) + + plt.subplot(1, 3, 1) + plt.plot(epochs, train_losses_ce, label='Train CE') + plt.plot(epochs, val_losses_ce, label='Val CE') + plt.title('Cross-Entropy Loss') + plt.xlabel('Epochs') + plt.ylabel('Loss') + plt.legend() + plt.grid(True) + + plt.subplot(1, 3, 2) + plt.plot(epochs, train_losses_surv, label='Train Survival') + plt.plot(epochs, val_losses_surv, label='Val Survival') + plt.title('Survival Loss') + plt.xlabel('Epochs') + plt.ylabel('Loss') + plt.legend() + plt.grid(True) + + plt.subplot(1, 3, 3) + plt.plot(epochs, train_losses_total, label='Train Total') + plt.plot(epochs, val_losses_total, label='Val Total') + plt.title('Total Loss') + plt.xlabel('Epochs') + plt.ylabel('Loss') + plt.legend() + plt.grid(True) + + plt.tight_layout() + plt.savefig('loss_curves.png') + print("\nLoss curves saved to loss_curves.png") + + cleanup_ddp() + +if __name__ == '__main__': + main()
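
Launch note: train_multigpu.py assumes a torchrun-style launcher. setup_ddp() reads LOCAL_RANK from the environment and initializes an NCCL process group, so running the script as a plain python invocation will fail. A minimal single-node launch sketch, assuming 4 GPUs (an assumption; set --nproc_per_node to the number of GPUs actually available):

    # one worker process per GPU; torchrun supplies RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
    torchrun --standalone --nproc_per_node=4 train_multigpu.py

Only rank 0 writes the config JSON, the checkpoint, the losses text file and loss_curves.png; every rank memory-maps the .bin data files and trains on its own DistributedSampler shard of the dataset.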