refactor: Use AdamW optimizer and increase early stopping patience
This commit is contained in:
4
requirements.txt
Normal file
4
requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
torch
|
||||
numpy
|
||||
tqdm
|
||||
matplotlib
|
6
train.py
6
train.py
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
import math
|
||||
@@ -30,7 +30,7 @@ class TrainConfig:
|
||||
lr_initial = 6e-4
|
||||
lr_final = 6e-5
|
||||
warmup_epochs = 10
|
||||
early_stopping_patience = 5
|
||||
early_stopping_patience = 10
|
||||
|
||||
# Loss parameters
|
||||
# 0 = padding, 1 = "no event"
|
||||
@@ -72,7 +72,7 @@ def main():
|
||||
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
|
||||
|
||||
loss_fn = CombinedLoss(config.ignored_token_ids)
|
||||
optimizer = Adam(model.parameters(), lr=config.lr_initial)
|
||||
optimizer = AdamW(model.parameters(), lr=config.lr_initial)
|
||||
|
||||
# --- 3. Training Loop ---
|
||||
best_val_loss = float('inf')
|
||||
|
Reference in New Issue
Block a user