diff --git a/train.py b/train.py
index 92fa286..a745fa3 100644
--- a/train.py
+++ b/train.py
@@ -42,6 +42,7 @@ class TrainConfig:
 # --- Main Training Script ---
 def main():
     config = TrainConfig()
+    device = torch.device(config.device)
 
     # --- 1. Data Loading ---
     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
@@ -59,7 +60,7 @@ def main():
     val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
 
     # --- 2. Model, Optimizer, and Loss Initialization ---
-    print(f"Initializing model on {config.device}...")
+    print(f"Initializing model on {device}...")
     model = TimeAwareGPT2(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
@@ -67,9 +68,16 @@ def main():
         n_head=config.n_head,
         pdrop=config.pdrop,
         token_pdrop=config.token_pdrop
-    ).to(config.device)
+    )
 
-    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
+    # --- Multi-GPU Support ---
+    if torch.cuda.device_count() > 1:
+        print(f"Using {torch.cuda.device_count()} GPUs!")
+        model = nn.DataParallel(model)
+
+    model.to(device)
+
+    print(f"Model initialized with {model.module.get_num_params() if isinstance(model, nn.DataParallel) else model.get_num_params():.2f}M trainable parameters.")
 
     loss_fn = CombinedLoss(config.ignored_token_ids)
     optimizer = Adam(model.parameters(), lr=config.lr_initial)
@@ -101,7 +109,7 @@ def main():
         pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
 
         for event_seq, time_seq in pbar:
-            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
+            event_seq, time_seq = event_seq.to(device), time_seq.to(device)
 
             # Prepare inputs and targets
             input_events = event_seq[:, :-1]
@@ -138,7 +146,7 @@ def main():
         with torch.no_grad():
             pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
             for event_seq, time_seq in pbar_val:
-                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
+                event_seq, time_seq = event_seq.to(device), time_seq.to(device)
 
                 input_events = event_seq[:, :-1]
                 input_times = time_seq[:, :-1]
@@ -170,7 +178,9 @@ def main():
             best_val_loss = total_val_loss
             patience_counter = 0
             print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
-            torch.save(model.state_dict(), 'best_model_checkpoint.pt')
+            # Save the underlying model state_dict when using DataParallel
+            model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
+            torch.save(model_state, 'best_model_checkpoint.pt')
         else:
             if epoch >= config.warmup_epochs:
                 patience_counter += 1
@@ -183,9 +193,11 @@ def main():
 
     # --- Save Best Model at the End ---
    if best_val_loss != float('inf'):
         print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
-        model.load_state_dict(torch.load('best_model_checkpoint.pt'))
+        # Load the state dict into the base model, not the DataParallel wrapper
+        base_model = model.module if isinstance(model, nn.DataParallel) else model
+        base_model.load_state_dict(torch.load('best_model_checkpoint.pt'))
         print("Saving final best model to best_model.pt")
-        torch.save(model.state_dict(), 'best_model.pt')
+        torch.save(base_model.state_dict(), 'best_model.pt')
     else:
         print("\nTraining finished. No best model to save as validation loss never improved.")
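
Note on this patch: it assumes `torch.nn` is already imported as `nn` in train.py (needed for `nn.DataParallel`); if it is not, that import must be added alongside this change. Because checkpoints are written from `model.module.state_dict()` whenever DataParallel is active, the saved keys carry no `module.` prefix and load directly into a bare `TimeAwareGPT2`. A minimal loading sketch under those assumptions follows; the `model` import path, the `vocab_size` value, and the `n_layer` argument are placeholders or inferences, not part of this patch.

```python
import torch

from train import TrainConfig      # TrainConfig is defined in train.py (see hunk context above)
from model import TimeAwareGPT2    # placeholder import path; adjust to wherever the class lives

config = TrainConfig()
vocab_size = 1024                  # placeholder; must match the vocabulary used during training

# Rebuild the bare (non-DataParallel) model with the same hyperparameters used for training.
model = TimeAwareGPT2(
    vocab_size=vocab_size,
    n_embd=config.n_embd,
    n_layer=config.n_layer,        # assumed; this argument falls outside the diff context shown
    n_head=config.n_head,
    pdrop=config.pdrop,
    token_pdrop=config.token_pdrop,
)

# The checkpoint was saved from model.module, so its keys match the bare model directly.
model.load_state_dict(torch.load('best_model.pt', map_location='cpu'))
model.eval()
```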