Revert "feat: adapt train.py to multi-GPU environment"

This reverts commit b7aad7a774.
This commit is contained in:
2025-10-16 16:23:38 +08:00
parent 2b20299e36
commit c7296381b8

View File

@@ -42,7 +42,6 @@ class TrainConfig:
# --- Main Training Script ---
def main():
config = TrainConfig()
device = torch.device(config.device)
# --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
@@ -60,7 +59,7 @@ def main():
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
# --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {device}...")
print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
@@ -68,16 +67,9 @@ def main():
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
)
).to(config.device)
# --- Multi-GPU Support ---
if torch.cuda.device_count() > 1:
print(f"Using {torch.cuda.device_count()} GPUs!")
model = nn.DataParallel(model)
model.to(device)
print(f"Model initialized with {model.module.get_num_params() if isinstance(model, nn.DataParallel) else model.get_num_params():.2f}M trainable parameters.")
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = Adam(model.parameters(), lr=config.lr_initial)
@@ -109,7 +101,7 @@ def main():
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
for event_seq, time_seq in pbar:
event_seq, time_seq = event_seq.to(device), time_seq.to(device)
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
# Prepare inputs and targets
input_events = event_seq[:, :-1]
@@ -146,7 +138,7 @@ def main():
with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
for event_seq, time_seq in pbar_val:
event_seq, time_seq = event_seq.to(device), time_seq.to(device)
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
@@ -178,9 +170,7 @@ def main():
best_val_loss = total_val_loss
patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
# Save the underlying model state_dict when using DataParallel
model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
torch.save(model_state, 'best_model_checkpoint.pt')
torch.save(model.state_dict(), 'best_model_checkpoint.pt')
else:
if epoch >= config.warmup_epochs:
patience_counter += 1
@@ -193,11 +183,9 @@ def main():
# --- Save Best Model at the End ---
if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
# Load the state dict into the base model, not the DataParallel wrapper
base_model = model.module if isinstance(model, nn.DataParallel) else model
base_model.load_state_dict(torch.load('best_model_checkpoint.pt'))
model.load_state_dict(torch.load('best_model_checkpoint.pt'))
print("Saving final best model to best_model.pt")
torch.save(base_model.state_dict(), 'best_model.pt')
torch.save(model.state_dict(), 'best_model.pt')
else:
print("\nTraining finished. No best model to save as validation loss never improved.")