feat: Add command-line arguments to train.py

2025-10-18 10:23:12 +08:00
parent 3390bc025e
commit f7356b183c

train.py

@@ -7,6 +7,7 @@ import math
 import tqdm
 import matplotlib.pyplot as plt
 import json
+import argparse
 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
@@ -33,6 +34,7 @@ class TrainConfig:
     weight_decay = 2e-1
     warmup_epochs = 10
     early_stopping_patience = 10
+    betas = (0.9, 0.99)
     # Loss parameters
     # 0 = padding, 1 = "no event"
@@ -43,7 +45,38 @@ class TrainConfig:
 # --- Main Training Script ---
 def main():
+    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
+    parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
+    parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
+    parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
+    parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
+    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
+    parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
+    parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
+    parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
+    parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
+    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
+    parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
+    parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
+    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
+    args = parser.parse_args()
+    config = TrainConfig()
+    config.n_layer = args.n_layer
+    config.n_embd = args.n_embd
+    config.n_head = args.n_head
+    config.max_epoch = args.max_epoch
+    config.batch_size = args.batch_size
+    config.lr_initial = args.lr_initial
+    config.lr_final = args.lr_final
+    config.weight_decay = args.weight_decay
+    config.warmup_epochs = args.warmup_epochs
+    config.early_stopping_patience = args.early_stopping_patience
+    config.pdrop = args.pdrop
+    config.token_pdrop = args.token_pdrop
+    config.betas = tuple(args.betas)
+    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
+    checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
@@ -84,7 +117,7 @@ def main():
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)
# --- 3. Training Loop ---
best_val_loss = float('inf')
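
Usage sketch (flag values below are hypothetical, not the defaults from the diff): the script can now be launched with any subset of the new options, and anything omitted falls back to the argparse defaults added above, e.g.

    python train.py --n_layer 6 --n_embd 256 --n_head 8 --batch_size 64 --betas 0.9 0.999

Note that --betas consumes two floats (nargs=2) and is converted to a tuple before being passed to AdamW via config.betas.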