diff --git a/train.py b/train.py
index 275f12c..53a9837 100644
--- a/train.py
+++ b/train.py
@@ -29,6 +29,7 @@ class TrainConfig:
     batch_size = 128
     lr_initial = 6e-4
     lr_final = 6e-5
+    weight_decay = 2e-1
     warmup_epochs = 10
     early_stopping_patience = 10
 
@@ -75,7 +76,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
 
     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = AdamW(model.parameters(), lr=config.lr_initial)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
 
     # --- 3. Training Loop ---
     best_val_loss = float('inf')