Revert "feat: adapt train.py to multi-GPU environment"
This reverts commit b7aad7a774.
train.py: 28 lines changed
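For context, the change being reverted wrapped the model in torch.nn.DataParallel and unwrapped it again around checkpointing. The sketch below reproduces that pattern in isolation; it is a minimal illustration under assumptions, not the project's code, and the nn.Linear stand-in only takes the place of the TimeAwareGPT2 model used in train.py.

# Minimal, self-contained sketch of the DataParallel pattern this commit reverts.
# nn.Linear is a stand-in for the project's TimeAwareGPT2 model.
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.device_count() > 1:
    # Replicate the module across visible GPUs; each batch is split along dim 0.
    model = nn.DataParallel(model)
model.to(device)

# Save the underlying module's weights so the checkpoint keys carry no
# "module." prefix and can be loaded into a plain (unwrapped) model later.
base_model = model.module if isinstance(model, nn.DataParallel) else model
torch.save(base_model.state_dict(), 'best_model_checkpoint.pt')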
@@ -42,7 +42,6 @@ class TrainConfig:
 # --- Main Training Script ---
 def main():
     config = TrainConfig()
-    device = torch.device(config.device)
 
     # --- 1. Data Loading ---
     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
@@ -60,7 +59,7 @@ def main():
     val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
 
     # --- 2. Model, Optimizer, and Loss Initialization ---
-    print(f"Initializing model on {device}...")
+    print(f"Initializing model on {config.device}...")
     model = TimeAwareGPT2(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
@@ -68,16 +67,9 @@ def main():
         n_head=config.n_head,
         pdrop=config.pdrop,
         token_pdrop=config.token_pdrop
-    )
+    ).to(config.device)
 
-    # --- Multi-GPU Support ---
-    if torch.cuda.device_count() > 1:
-        print(f"Using {torch.cuda.device_count()} GPUs!")
-        model = nn.DataParallel(model)
-
-    model.to(device)
-
-    print(f"Model initialized with {model.module.get_num_params() if isinstance(model, nn.DataParallel) else model.get_num_params():.2f}M trainable parameters.")
+    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
 
     loss_fn = CombinedLoss(config.ignored_token_ids)
     optimizer = Adam(model.parameters(), lr=config.lr_initial)
@@ -109,7 +101,7 @@ def main():
 
         pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
         for event_seq, time_seq in pbar:
-            event_seq, time_seq = event_seq.to(device), time_seq.to(device)
+            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
 
             # Prepare inputs and targets
             input_events = event_seq[:, :-1]
@@ -146,7 +138,7 @@ def main():
         with torch.no_grad():
             pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
             for event_seq, time_seq in pbar_val:
-                event_seq, time_seq = event_seq.to(device), time_seq.to(device)
+                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
 
                 input_events = event_seq[:, :-1]
                 input_times = time_seq[:, :-1]
@@ -178,9 +170,7 @@ def main():
             best_val_loss = total_val_loss
             patience_counter = 0
             print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
-            # Save the underlying model state_dict when using DataParallel
-            model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
-            torch.save(model_state, 'best_model_checkpoint.pt')
+            torch.save(model.state_dict(), 'best_model_checkpoint.pt')
         else:
             if epoch >= config.warmup_epochs:
                 patience_counter += 1
@@ -193,11 +183,9 @@ def main():
     # --- Save Best Model at the End ---
     if best_val_loss != float('inf'):
         print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
-        # Load the state dict into the base model, not the DataParallel wrapper
-        base_model = model.module if isinstance(model, nn.DataParallel) else model
-        base_model.load_state_dict(torch.load('best_model_checkpoint.pt'))
+        model.load_state_dict(torch.load('best_model_checkpoint.pt'))
         print("Saving final best model to best_model.pt")
-        torch.save(base_model.state_dict(), 'best_model.pt')
+        torch.save(model.state_dict(), 'best_model.pt')
     else:
         print("\nTraining finished. No best model to save as validation loss never improved.")
 
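A compatibility note, not part of the commit itself: a checkpoint saved while the model was wrapped in nn.DataParallel (via model.state_dict() on the wrapper) carries a "module." prefix on every key, which is exactly what the reverted code avoided by saving model.module.state_dict(). If such a checkpoint ever needs to be loaded into the restored single-device model, the prefix can be stripped first; the helper below is hypothetical and the path in the usage comment is illustrative.

import torch

def strip_dataparallel_prefix(state_dict):
    # nn.DataParallel stores parameters under "module.<name>"; drop the prefix
    # so the checkpoint loads into an unwrapped model.
    return {k.removeprefix("module."): v for k, v in state_dict.items()}

# Illustrative usage (checkpoint path and model variable are placeholders):
# state = torch.load('best_model_checkpoint.pt', map_location='cpu')
# model.load_state_dict(strip_dataparallel_prefix(state))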