Refactor: Improve attention mechanism and early stopping
- Refactor the self-attention mechanism in `models.py` to use `nn.MultiheadAttention` for better performance and clarity. - Disable early stopping check during warmup epochs in `train.py` to improve training stability.
This commit is contained in:
72
train.py
72
train.py
@@ -5,6 +5,7 @@ from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
import math
|
||||
import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from models import TimeAwareGPT2, CombinedLoss
|
||||
from utils import PatientEventDataset
|
||||
@@ -14,7 +15,7 @@ class TrainConfig:
|
||||
# Data parameters
|
||||
train_data_path = 'ukb_real_train.bin'
|
||||
val_data_path = 'ukb_real_val.bin'
|
||||
block_length = 256 # Sequence length
|
||||
block_length = 24 # Sequence length
|
||||
|
||||
# Model parameters
|
||||
n_embd = 256
|
||||
@@ -76,6 +77,11 @@ def main():
|
||||
# --- 3. Training Loop ---
|
||||
best_val_loss = float('inf')
|
||||
patience_counter = 0
|
||||
|
||||
# Lists to store losses
|
||||
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
|
||||
val_losses_ce, val_losses_surv, val_losses_total = [], [], []
|
||||
|
||||
print("Starting training...")
|
||||
for epoch in range(config.max_epoch):
|
||||
# --- Learning Rate Scheduling ---
|
||||
@@ -120,6 +126,9 @@ def main():
|
||||
|
||||
avg_train_loss_ce = train_loss_ce_acc / train_steps
|
||||
avg_train_loss_surv = train_loss_surv_acc / train_steps
|
||||
train_losses_ce.append(avg_train_loss_ce)
|
||||
train_losses_surv.append(avg_train_loss_surv)
|
||||
train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
|
||||
|
||||
# --- Validation Phase ---
|
||||
model.eval()
|
||||
@@ -147,6 +156,9 @@ def main():
|
||||
avg_val_loss_ce = val_loss_ce_acc / val_steps
|
||||
avg_val_loss_surv = val_loss_surv_acc / val_steps
|
||||
total_val_loss = avg_val_loss_ce + avg_val_loss_surv
|
||||
val_losses_ce.append(avg_val_loss_ce)
|
||||
val_losses_surv.append(avg_val_loss_surv)
|
||||
val_losses_total.append(total_val_loss)
|
||||
|
||||
print(f"Epoch {epoch+1} Summary: \n"
|
||||
f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
|
||||
@@ -157,14 +169,66 @@ def main():
|
||||
if total_val_loss < best_val_loss:
|
||||
best_val_loss = total_val_loss
|
||||
patience_counter = 0
|
||||
print(f"Validation loss improved to {best_val_loss:.4f}. Resetting patience.")
|
||||
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
|
||||
torch.save(model.state_dict(), 'best_model_checkpoint.pt')
|
||||
else:
|
||||
patience_counter += 1
|
||||
print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
|
||||
if epoch >= config.warmup_epochs:
|
||||
patience_counter += 1
|
||||
print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
|
||||
|
||||
if patience_counter >= config.early_stopping_patience:
|
||||
print("\nEarly stopping triggered due to no improvement in validation loss.")
|
||||
break
|
||||
|
||||
# --- Save Best Model at the End ---
|
||||
if best_val_loss != float('inf'):
|
||||
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
|
||||
model.load_state_dict(torch.load('best_model_checkpoint.pt'))
|
||||
print("Saving final best model to best_model.pt")
|
||||
torch.save(model.state_dict(), 'best_model.pt')
|
||||
else:
|
||||
print("\nTraining finished. No best model to save as validation loss never improved.")
|
||||
|
||||
# --- Plot and Save Loss Curves ---
|
||||
num_epochs = len(train_losses_total)
|
||||
epochs = range(1, num_epochs + 1)
|
||||
|
||||
plt.figure(figsize=(18, 5))
|
||||
|
||||
# Plot CE Loss
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.plot(epochs, train_losses_ce, label='Train CE')
|
||||
plt.plot(epochs, val_losses_ce, label='Val CE')
|
||||
plt.title('Cross-Entropy Loss')
|
||||
plt.xlabel('Epochs')
|
||||
plt.ylabel('Loss')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
# Plot Survival Loss
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.plot(epochs, train_losses_surv, label='Train Survival')
|
||||
plt.plot(epochs, val_losses_surv, label='Val Survival')
|
||||
plt.title('Survival Loss')
|
||||
plt.xlabel('Epochs')
|
||||
plt.ylabel('Loss')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
# Plot Total Loss
|
||||
plt.subplot(1, 3, 3)
|
||||
plt.plot(epochs, train_losses_total, label='Train Total')
|
||||
plt.plot(epochs, val_losses_total, label='Val Total')
|
||||
plt.title('Total Loss')
|
||||
plt.xlabel('Epochs')
|
||||
plt.ylabel('Loss')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('loss_curves.png')
|
||||
print("\nLoss curves saved to loss_curves.png")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Reference in New Issue
Block a user