config: Add weight decay to training configuration
Adds a weight_decay parameter to the TrainConfig and applies it to the AdamW optimizer.
 train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -29,6 +29,7 @@ class TrainConfig:
     batch_size = 128
     lr_initial = 6e-4
     lr_final = 6e-5
+    weight_decay = 2e-1
     warmup_epochs = 10
     early_stopping_patience = 10
 
@@ -75,7 +76,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
 
     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = AdamW(model.parameters(), lr=config.lr_initial)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
 
     # --- 3. Training Loop ---
     best_val_loss = float('inf')
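For context, a minimal self-contained sketch of the wiring this change produces, assuming the AdamW in train.py is PyTorch's torch.optim.AdamW (the diff does not show the import); the tiny Linear model is a hypothetical stand-in for the real one:

    import torch
    from torch.optim import AdamW

    class TrainConfig:
        batch_size = 128
        lr_initial = 6e-4
        lr_final = 6e-5
        weight_decay = 2e-1  # new: decoupled weight decay passed to AdamW
        warmup_epochs = 10
        early_stopping_patience = 10

    config = TrainConfig()
    model = torch.nn.Linear(8, 8)  # placeholder for the real model

    # PyTorch's AdamW decays parameters directly (p *= 1 - lr * weight_decay)
    # rather than folding an L2 term into the gradient, so the decay scales
    # with the learning-rate schedule implied by lr_initial/lr_final.
    optimizer = AdamW(model.parameters(), lr=config.lr_initial,
                      weight_decay=config.weight_decay)

If this is torch.optim.AdamW, then at the initial learning rate each step shrinks every weight by a factor of 1 - 6e-4 * 2e-1 ≈ 0.99988 on top of the Adam update.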