From 02d84a7ecaeecabbce415fc27b2630ba5c17d344 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Fri, 17 Oct 2025 10:31:12 +0800
Subject: [PATCH] refactor: Use AdamW optimizer and increase early stopping
 patience

---
 requirements.txt | 4 ++++
 train.py         | 6 +++---
 2 files changed, 7 insertions(+), 3 deletions(-)
 create mode 100644 requirements.txt

diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..534d633
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch
+numpy
+tqdm
+matplotlib
diff --git a/train.py b/train.py
index cebc6cf..1d9ed0c 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from torch.optim import Adam
+from torch.optim import AdamW
 from torch.utils.data import DataLoader
 import numpy as np
 import math
@@ -30,7 +30,7 @@ class TrainConfig:
     lr_initial = 6e-4
     lr_final = 6e-5
     warmup_epochs = 10
-    early_stopping_patience = 5
+    early_stopping_patience = 10

     # Loss parameters
     # 0 = padding, 1 = "no event"
@@ -72,7 +72,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = Adam(model.parameters(), lr=config.lr_initial)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial)

     # --- 3. Training Loop ---
     best_val_loss = float('inf')
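
Note (not part of the patch): switching from Adam to AdamW also changes the default regularization, since PyTorch's AdamW applies decoupled weight decay of 0.01 when no weight_decay argument is given, whereas the replaced Adam call applied none. The sketch below shows how the new optimizer and the larger early-stopping patience typically interact in a generic training loop; the loop structure, the placeholder model, and the dummy validation loss are assumptions, since the full body of main() is not visible in these hunks.

import torch
from torch.optim import AdamW

# Placeholder model; the real project builds its own model in train.py.
model = torch.nn.Linear(16, 4)

# AdamW decouples weight decay from the gradient update; with no explicit
# weight_decay argument, PyTorch applies its default of 0.01.
optimizer = AdamW(model.parameters(), lr=6e-4)

best_val_loss = float('inf')
patience = 10                       # matches the new early_stopping_patience
epochs_without_improvement = 0

for epoch in range(100):
    # ... training and validation would run here ...
    val_loss = 1.0 / (epoch + 1)    # dummy value so the sketch is runnable
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break                   # stop after `patience` stagnant epochs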