feat: Add command-line arguments to train.py
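
Hyperparameters that were previously hard-coded in TrainConfig can now be overridden from the command line; any flag that is omitted falls back to the parser default. An illustrative invocation (example values, not tuned recommendations):

    python train.py --n_layer 6 --n_embd 256 --n_head 8 --batch_size 64 --lr_initial 3e-4 --betas 0.9 0.95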

commit f7356b183c
parent 3390bc025e
2025-10-18 10:23:12 +08:00


@@ -7,6 +7,7 @@ import math
 import tqdm
 import matplotlib.pyplot as plt
 import json
+import argparse
 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
@@ -33,6 +34,7 @@ class TrainConfig:
     weight_decay = 2e-1
     warmup_epochs = 10
     early_stopping_patience = 10
+    betas = (0.9, 0.99)

     # Loss parameters
     # 0 = padding, 1 = "no event"
@@ -43,7 +45,38 @@ class TrainConfig:
 # --- Main Training Script ---
 def main():
+    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
+    parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
+    parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
+    parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
+    parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
+    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
+    parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
+    parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
+    parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
+    parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
+    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
+    parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
+    parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
+    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
+    args = parser.parse_args()
+
     config = TrainConfig()
+    config.n_layer = args.n_layer
+    config.n_embd = args.n_embd
+    config.n_head = args.n_head
+    config.max_epoch = args.max_epoch
+    config.batch_size = args.batch_size
+    config.lr_initial = args.lr_initial
+    config.lr_final = args.lr_final
+    config.weight_decay = args.weight_decay
+    config.warmup_epochs = args.warmup_epochs
+    config.early_stopping_patience = args.early_stopping_patience
+    config.pdrop = args.pdrop
+    config.token_pdrop = args.token_pdrop
+    config.betas = tuple(args.betas)
+
     model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
     checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
@@ -84,7 +117,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)

     # --- 3. Training Loop ---
     best_val_loss = float('inf')