diff --git a/train.py b/train.py
index edd59e2..1bcc023 100644
--- a/train.py
+++ b/train.py
@@ -7,6 +7,7 @@ import math
 import tqdm
 import matplotlib.pyplot as plt
 import json
+import argparse
 
 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
@@ -33,6 +34,7 @@ class TrainConfig:
     weight_decay = 2e-1
     warmup_epochs = 10
     early_stopping_patience = 10
+    betas = (0.9, 0.99)
 
     # Loss parameters
     # 0 = padding, 1 = "no event"
@@ -43,7 +45,38 @@ class TrainConfig:
 
 # --- Main Training Script ---
 
 def main():
+    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
+    parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
+    parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
+    parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
+    parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
+    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
+    parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
+    parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
+    parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
+    parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
+    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
+    parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
+    parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
+    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
+
+    args = parser.parse_args()
+    config = TrainConfig()
+    config.n_layer = args.n_layer
+    config.n_embd = args.n_embd
+    config.n_head = args.n_head
+    config.max_epoch = args.max_epoch
+    config.batch_size = args.batch_size
+    config.lr_initial = args.lr_initial
+    config.lr_final = args.lr_final
+    config.weight_decay = args.weight_decay
+    config.warmup_epochs = args.warmup_epochs
+    config.early_stopping_patience = args.early_stopping_patience
+    config.pdrop = args.pdrop
+    config.token_pdrop = args.token_pdrop
+    config.betas = tuple(args.betas)
+
     model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
     checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
 
@@ -84,7 +117,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
 
     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)
 
     # --- 3. Training Loop ---
     best_val_loss = float('inf')
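
For reference, a hypothetical invocation of the updated script using the new CLI overrides. The flag names and defaults come from the argparse definitions above; the specific values shown here are only an illustration, not values used by the authors:

    python train.py --n_layer 6 --n_embd 96 --n_head 6 --batch_size 64 --lr_initial 3e-4 --betas 0.9 0.95

Any flag left unspecified falls back to its argparse default, which mirrors the previous hard-coded TrainConfig values.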