Revert "feat: adapt train.py to multi-GPU environment"
This reverts commit b7aad7a774
.
This commit is contained in:
28
train.py
28
train.py
@@ -42,7 +42,6 @@ class TrainConfig:
|
||||
# --- Main Training Script ---
|
||||
def main():
|
||||
config = TrainConfig()
|
||||
device = torch.device(config.device)
|
||||
|
||||
# --- 1. Data Loading ---
|
||||
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
|
||||
@@ -60,7 +59,7 @@ def main():
|
||||
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
|
||||
|
||||
# --- 2. Model, Optimizer, and Loss Initialization ---
|
||||
print(f"Initializing model on {device}...")
|
||||
print(f"Initializing model on {config.device}...")
|
||||
model = TimeAwareGPT2(
|
||||
vocab_size=vocab_size,
|
||||
n_embd=config.n_embd,
|
||||
@@ -68,16 +67,9 @@ def main():
|
||||
n_head=config.n_head,
|
||||
pdrop=config.pdrop,
|
||||
token_pdrop=config.token_pdrop
|
||||
)
|
||||
).to(config.device)
|
||||
|
||||
# --- Multi-GPU Support ---
|
||||
if torch.cuda.device_count() > 1:
|
||||
print(f"Using {torch.cuda.device_count()} GPUs!")
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
model.to(device)
|
||||
|
||||
print(f"Model initialized with {model.module.get_num_params() if isinstance(model, nn.DataParallel) else model.get_num_params():.2f}M trainable parameters.")
|
||||
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
|
||||
|
||||
loss_fn = CombinedLoss(config.ignored_token_ids)
|
||||
optimizer = Adam(model.parameters(), lr=config.lr_initial)
|
||||
@@ -109,7 +101,7 @@ def main():
|
||||
|
||||
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
|
||||
for event_seq, time_seq in pbar:
|
||||
event_seq, time_seq = event_seq.to(device), time_seq.to(device)
|
||||
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
|
||||
|
||||
# Prepare inputs and targets
|
||||
input_events = event_seq[:, :-1]
|
||||
@@ -146,7 +138,7 @@ def main():
|
||||
with torch.no_grad():
|
||||
pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
|
||||
for event_seq, time_seq in pbar_val:
|
||||
event_seq, time_seq = event_seq.to(device), time_seq.to(device)
|
||||
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
|
||||
|
||||
input_events = event_seq[:, :-1]
|
||||
input_times = time_seq[:, :-1]
|
||||
@@ -178,9 +170,7 @@ def main():
|
||||
best_val_loss = total_val_loss
|
||||
patience_counter = 0
|
||||
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
|
||||
# Save the underlying model state_dict when using DataParallel
|
||||
model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
|
||||
torch.save(model_state, 'best_model_checkpoint.pt')
|
||||
torch.save(model.state_dict(), 'best_model_checkpoint.pt')
|
||||
else:
|
||||
if epoch >= config.warmup_epochs:
|
||||
patience_counter += 1
|
||||
@@ -193,11 +183,9 @@ def main():
|
||||
# --- Save Best Model at the End ---
|
||||
if best_val_loss != float('inf'):
|
||||
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
|
||||
# Load the state dict into the base model, not the DataParallel wrapper
|
||||
base_model = model.module if isinstance(model, nn.DataParallel) else model
|
||||
base_model.load_state_dict(torch.load('best_model_checkpoint.pt'))
|
||||
model.load_state_dict(torch.load('best_model_checkpoint.pt'))
|
||||
print("Saving final best model to best_model.pt")
|
||||
torch.save(base_model.state_dict(), 'best_model.pt')
|
||||
torch.save(model.state_dict(), 'best_model.pt')
|
||||
else:
|
||||
print("\nTraining finished. No best model to save as validation loss never improved.")
|
||||
|
||||
|
Reference in New Issue
Block a user