config: Add weight decay to training configuration

Adds a weight_decay parameter to the TrainConfig and applies it to the AdamW optimizer.
2025-10-17 13:47:37 +08:00
parent d4d25ac9c7
commit 053f86f4da


@@ -29,6 +29,7 @@ class TrainConfig:
     batch_size = 128
     lr_initial = 6e-4
     lr_final = 6e-5
+    weight_decay = 2e-1
     warmup_epochs = 10
     early_stopping_patience = 10
@@ -75,7 +76,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = AdamW(model.parameters(), lr=config.lr_initial)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)

     # --- 3. Training Loop ---
     best_val_loss = float('inf')
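
For context, a minimal sketch of how the new field flows from the config into the optimizer, assuming the TrainConfig attributes shown in the hunk above; the model here is a placeholder, and the real CombinedLoss/model setup is elided. Note that AdamW applies its decoupled weight decay to every parameter it receives, including biases and normalization weights, unless parameter groups are used to exclude them.

import torch
from torch.optim import AdamW


class TrainConfig:
    # Attributes mirror the diff above; only the fields relevant here.
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1          # new field added by this commit
    warmup_epochs = 10
    early_stopping_patience = 10


config = TrainConfig()
model = torch.nn.Linear(16, 16)  # placeholder model, for illustration only

# Weight decay is passed straight through from the config to AdamW.
optimizer = AdamW(
    model.parameters(),
    lr=config.lr_initial,
    weight_decay=config.weight_decay,
)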