Update models and training scripts
This commit is contained in:
20
train.py
20
train.py
@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
|
||||
import json
|
||||
import argparse
|
||||
|
||||
from models import TimeAwareGPT2, CombinedLoss
|
||||
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
|
||||
from utils import PatientEventDataset
|
||||
|
||||
# --- Configuration ---
|
||||
@@ -25,6 +25,7 @@ class TrainConfig:
|
||||
n_head = 12
|
||||
pdrop = 0.1
|
||||
token_pdrop = 0.1
|
||||
model_name = 'TimeAwareGPT2'
|
||||
|
||||
# Training parameters
|
||||
max_epoch = 200
|
||||
@@ -59,6 +60,7 @@ def main():
|
||||
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
|
||||
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
|
||||
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
|
||||
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -76,10 +78,11 @@ def main():
|
||||
config.pdrop = args.pdrop
|
||||
config.token_pdrop = args.token_pdrop
|
||||
config.betas = tuple(args.betas)
|
||||
config.model_name = args.model
|
||||
|
||||
|
||||
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
|
||||
checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
|
||||
model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
|
||||
model_filename = f"best_model_{model_suffix}.pt"
|
||||
checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
|
||||
|
||||
# --- 0. Save Configuration ---
|
||||
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
|
||||
@@ -105,7 +108,12 @@ def main():
|
||||
|
||||
# --- 2. Model, Optimizer, and Loss Initialization ---
|
||||
print(f"Initializing model on {config.device}...")
|
||||
model = TimeAwareGPT2(
|
||||
model_cls = {
|
||||
'TimeAwareGPT2': TimeAwareGPT2,
|
||||
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
|
||||
}[config.model_name]
|
||||
|
||||
model = model_cls(
|
||||
vocab_size=vocab_size,
|
||||
n_embd=config.n_embd,
|
||||
n_layer=config.n_layer,
|
||||
@@ -235,7 +243,7 @@ def main():
|
||||
print("\nTraining finished. No best model to save as validation loss never improved.")
|
||||
|
||||
# --- Save losses to a txt file ---
|
||||
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
|
||||
losses_filename = f"losses_{model_suffix}.txt"
|
||||
with open(losses_filename, 'w') as f:
|
||||
f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
|
||||
for i in range(len(train_losses_total)):
|
||||
|
Reference in New Issue
Block a user