feat: Add command-line arguments to train.py
train.py | 35
@@ -7,6 +7,7 @@ import math
 import tqdm
 import matplotlib.pyplot as plt
 import json
+import argparse
 
 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
@@ -33,6 +34,7 @@ class TrainConfig:
     weight_decay = 2e-1
     warmup_epochs = 10
     early_stopping_patience = 10
+    betas = (0.9, 0.99)
 
     # Loss parameters
     # 0 = padding, 1 = "no event"
@@ -43,7 +45,38 @@ class TrainConfig:
 
 # --- Main Training Script ---
 def main():
+    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
+    parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
+    parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
+    parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
+    parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
+    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
+    parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
+    parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
+    parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
+    parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
+    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
+    parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
+    parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
+    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
+
+    args = parser.parse_args()
+
+    config = TrainConfig()
+    config.n_layer = args.n_layer
+    config.n_embd = args.n_embd
+    config.n_head = args.n_head
+    config.max_epoch = args.max_epoch
+    config.batch_size = args.batch_size
+    config.lr_initial = args.lr_initial
+    config.lr_final = args.lr_final
+    config.weight_decay = args.weight_decay
+    config.warmup_epochs = args.warmup_epochs
+    config.early_stopping_patience = args.early_stopping_patience
+    config.pdrop = args.pdrop
+    config.token_pdrop = args.token_pdrop
+    config.betas = tuple(args.betas)
 
 
     model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
     checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
@@ -84,7 +117,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
 
     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)
 
     # --- 3. Training Loop ---
     best_val_loss = float('inf')
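
With these flags in place, hyperparameters that were previously hard-coded in TrainConfig can be overridden from the command line; any flag that is omitted falls back to the default shown in the diff. A hypothetical invocation (the override values below are illustrative, not taken from the commit):

    python train.py --n_layer 6 --n_embd 240 --n_head 8 --batch_size 64 --betas 0.9 0.95

Note that --betas is declared with nargs=2, so it accepts two floats on the command line and is converted back to a tuple before being passed through config.betas to the AdamW optimizer.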