refactor: Use AdamW optimizer and increase early stopping patience

commit 02d84a7eca
parent cb7575a229
Date: 2025-10-17 10:31:12 +08:00

2 changed files with 7 additions and 3 deletions

requirements.txt (new file)

@@ -0,0 +1,4 @@
+torch
+numpy
+tqdm
+matplotlib


@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from torch.optim import Adam
+from torch.optim import AdamW
 from torch.utils.data import DataLoader
 import numpy as np
 import math
@@ -30,7 +30,7 @@ class TrainConfig:
     lr_initial = 6e-4
     lr_final = 6e-5
     warmup_epochs = 10
-    early_stopping_patience = 5
+    early_stopping_patience = 10

     # Loss parameters
     # 0 = padding, 1 = "no event"
@@ -72,7 +72,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = Adam(model.parameters(), lr=config.lr_initial)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial)

     # --- 3. Training Loop ---
     best_val_loss = float('inf')
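
For context on the optimizer change: torch.optim.AdamW applies decoupled weight decay, and its default weight_decay is 0.01, whereas Adam defaults to 0. Because the call above only passes lr, the swap also turns on weight decay at the default strength. The sketch below is a minimal illustration of that, plus one common way the increased early_stopping_patience could drive a training loop; the placeholder model, the explicit weight_decay argument, and the EarlyStopper class are assumptions for illustration, not code from this repository.

import torch.nn as nn
from torch.optim import AdamW

# Placeholder model just to build an optimizer; the project's real model is not shown in this diff.
model = nn.Linear(16, 4)

# AdamW decouples weight decay from the Adam update. Unlike Adam (default weight_decay=0),
# AdamW defaults to weight_decay=0.01, so the swap in this commit enables weight decay even
# though the call site only passes lr. Spelling it out here is an assumption made for clarity.
optimizer = AdamW(model.parameters(), lr=6e-4, weight_decay=0.01)

class EarlyStopper:
    """Stop training after `patience` epochs without validation-loss improvement."""

    def __init__(self, patience: int = 10):
        self.patience = patience
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, val_loss: float) -> bool:
        # Reset the counter on any improvement; otherwise count the stale epoch.
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience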