update models and training scripts

This commit is contained in:
2025-10-22 08:36:55 +08:00
parent e348086e52
commit bd88daa8c2
2 changed files with 56 additions and 90 deletions

View File

@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
import json
import argparse
from models import TimeAwareGPT2, CombinedLoss
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
@@ -25,6 +25,7 @@ class TrainConfig:
n_head = 12
pdrop = 0.1
token_pdrop = 0.1
model_name = 'TimeAwareGPT2'
# Training parameters
max_epoch = 200
@@ -59,6 +60,7 @@ def main():
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.')
args = parser.parse_args()
@@ -76,10 +78,11 @@ def main():
config.pdrop = args.pdrop
config.token_pdrop = args.token_pdrop
config.betas = tuple(args.betas)
config.model_name = args.model
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
model_filename = f"best_model_{model_suffix}.pt"
checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
# --- 0. Save Configuration ---
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
@@ -105,7 +108,12 @@ def main():
# --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2(
model_cls = {
'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
}[config.model_name]
model = model_cls(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
@@ -235,7 +243,7 @@ def main():
print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Save losses to a txt file ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
losses_filename = f"losses_{model_suffix}.txt"
with open(losses_filename, 'w') as f:
f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
for i in range(len(train_losses_total)):