feat: adapt train.py to multi-GPU environments

2025-10-16 16:16:15 +08:00
parent 4181ead03a
commit b7aad7a774


@@ -42,6 +42,7 @@ class TrainConfig:
 # --- Main Training Script ---
 def main():
     config = TrainConfig()
+    device = torch.device(config.device)
 
     # --- 1. Data Loading ---
     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
@@ -59,7 +60,7 @@ def main():
     val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
 
     # --- 2. Model, Optimizer, and Loss Initialization ---
-    print(f"Initializing model on {config.device}...")
+    print(f"Initializing model on {device}...")
     model = TimeAwareGPT2(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
@@ -67,9 +68,16 @@ def main():
         n_head=config.n_head,
         pdrop=config.pdrop,
         token_pdrop=config.token_pdrop
-    ).to(config.device)
-    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
+    )
+
+    # --- Multi-GPU Support ---
+    if torch.cuda.device_count() > 1:
+        print(f"Using {torch.cuda.device_count()} GPUs!")
+        model = nn.DataParallel(model)
+    model.to(device)
+
+    print(f"Model initialized with {model.module.get_num_params() if isinstance(model, nn.DataParallel) else model.get_num_params():.2f}M trainable parameters.")
 
     loss_fn = CombinedLoss(config.ignored_token_ids)
     optimizer = Adam(model.parameters(), lr=config.lr_initial)
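Note: this hunk assumes `torch.nn` is already imported as `nn` in train.py; no import change appears in the diff. The wrapping pattern itself, in isolation (a toy `nn.Linear` stands in for TimeAwareGPT2):

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 4)  # toy stand-in for TimeAwareGPT2
    if torch.cuda.device_count() > 1:
        # DataParallel replicates the module on every visible GPU, splits each
        # input batch along dim 0, and gathers the outputs on device 0.
        model = nn.DataParallel(model)
    model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    # Custom methods live on the inner module, hence the
    # model.module.get_num_params() indirection in the hunk above.
    base = model.module if isinstance(model, nn.DataParallel) else model

One design consequence: because DataParallel splits along the batch dimension, `config.batch_size` becomes the global batch size, divided across the GPUs on each step.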
@@ -101,7 +109,7 @@ def main():
         pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
         for event_seq, time_seq in pbar:
-            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
+            event_seq, time_seq = event_seq.to(device), time_seq.to(device)
 
             # Prepare inputs and targets
             input_events = event_seq[:, :-1]
@@ -138,7 +146,7 @@ def main():
         with torch.no_grad():
             pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
             for event_seq, time_seq in pbar_val:
-                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
+                event_seq, time_seq = event_seq.to(device), time_seq.to(device)
 
                 input_events = event_seq[:, :-1]
                 input_times = time_seq[:, :-1]
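Note: both loops keep the standard next-token shift; only the input side is visible in these truncated hunks. A self-contained sketch of the presumed pairing (the target slice is an assumption, not shown in the diff):

    import torch

    event_seq = torch.randint(0, 100, (8, 32))  # dummy (batch, seq_len) batch
    input_events = event_seq[:, :-1]            # as in the hunks above
    target_events = event_seq[:, 1:]            # assumption: shifted-by-one targets
    assert input_events.shape == target_events.shape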
@@ -170,7 +178,9 @@ def main():
             best_val_loss = total_val_loss
             patience_counter = 0
             print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
-            torch.save(model.state_dict(), 'best_model_checkpoint.pt')
+            # Save the underlying model state_dict when using DataParallel
+            model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
+            torch.save(model_state, 'best_model_checkpoint.pt')
         else:
             if epoch >= config.warmup_epochs:
                 patience_counter += 1
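The unwrap before saving matters because DataParallel registers the real model as a submodule named `module`, so every state_dict key gains a `module.` prefix; a checkpoint saved from the wrapper will not load into a bare model. A small demonstration (runs on CPU too, since DataParallel degrades gracefully without GPUs):

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    wrapped = nn.DataParallel(net)

    print(list(net.state_dict()))      # ['weight', 'bias']
    print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

    # Saving the unwrapped state_dict keeps the checkpoint loadable without
    # DataParallel, which is what the hunk above guarantees.
    torch.save(net.state_dict(), 'best_model_checkpoint.pt')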
@@ -183,9 +193,11 @@ def main():
     # --- Save Best Model at the End ---
     if best_val_loss != float('inf'):
         print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
-        model.load_state_dict(torch.load('best_model_checkpoint.pt'))
+        # Load the state dict into the base model, not the DataParallel wrapper
+        base_model = model.module if isinstance(model, nn.DataParallel) else model
+        base_model.load_state_dict(torch.load('best_model_checkpoint.pt'))
 
         print("Saving final best model to best_model.pt")
-        torch.save(model.state_dict(), 'best_model.pt')
+        torch.save(base_model.state_dict(), 'best_model.pt')
     else:
         print("\nTraining finished. No best model to save as validation loss never improved.")