Revert "feat: adapt train.py to multi-GPU environment"

This reverts commit b7aad7a774.
This commit is contained in:
2025-10-16 16:23:38 +08:00
parent 2b20299e36
commit c7296381b8

View File

@@ -42,7 +42,6 @@ class TrainConfig:
# --- Main Training Script --- # --- Main Training Script ---
def main(): def main():
config = TrainConfig() config = TrainConfig()
device = torch.device(config.device)
# --- 1. Data Loading --- # --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
@@ -60,7 +59,7 @@ def main():
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True) val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
# --- 2. Model, Optimizer, and Loss Initialization --- # --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {device}...") print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2( model = TimeAwareGPT2(
vocab_size=vocab_size, vocab_size=vocab_size,
n_embd=config.n_embd, n_embd=config.n_embd,
@@ -68,16 +67,9 @@ def main():
n_head=config.n_head, n_head=config.n_head,
pdrop=config.pdrop, pdrop=config.pdrop,
token_pdrop=config.token_pdrop token_pdrop=config.token_pdrop
) ).to(config.device)
# --- Multi-GPU Support --- print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
if torch.cuda.device_count() > 1:
print(f"Using {torch.cuda.device_count()} GPUs!")
model = nn.DataParallel(model)
model.to(device)
print(f"Model initialized with {model.module.get_num_params() if isinstance(model, nn.DataParallel) else model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids) loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = Adam(model.parameters(), lr=config.lr_initial) optimizer = Adam(model.parameters(), lr=config.lr_initial)
@@ -109,7 +101,7 @@ def main():
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]") pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
for event_seq, time_seq in pbar: for event_seq, time_seq in pbar:
event_seq, time_seq = event_seq.to(device), time_seq.to(device) event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
# Prepare inputs and targets # Prepare inputs and targets
input_events = event_seq[:, :-1] input_events = event_seq[:, :-1]
@@ -146,7 +138,7 @@ def main():
with torch.no_grad(): with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]") pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
for event_seq, time_seq in pbar_val: for event_seq, time_seq in pbar_val:
event_seq, time_seq = event_seq.to(device), time_seq.to(device) event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1] input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1] input_times = time_seq[:, :-1]
@@ -178,9 +170,7 @@ def main():
best_val_loss = total_val_loss best_val_loss = total_val_loss
patience_counter = 0 patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...") print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
# Save the underlying model state_dict when using DataParallel torch.save(model.state_dict(), 'best_model_checkpoint.pt')
model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
torch.save(model_state, 'best_model_checkpoint.pt')
else: else:
if epoch >= config.warmup_epochs: if epoch >= config.warmup_epochs:
patience_counter += 1 patience_counter += 1
@@ -193,11 +183,9 @@ def main():
# --- Save Best Model at the End --- # --- Save Best Model at the End ---
if best_val_loss != float('inf'): if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
# Load the state dict into the base model, not the DataParallel wrapper model.load_state_dict(torch.load('best_model_checkpoint.pt'))
base_model = model.module if isinstance(model, nn.DataParallel) else model
base_model.load_state_dict(torch.load('best_model_checkpoint.pt'))
print("Saving final best model to best_model.pt") print("Saving final best model to best_model.pt")
torch.save(base_model.state_dict(), 'best_model.pt') torch.save(model.state_dict(), 'best_model.pt')
else: else:
print("\nTraining finished. No best model to save as validation loss never improved.") print("\nTraining finished. No best model to save as validation loss never improved.")