import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import argparse

from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48  # Sequence length

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1
    model_name = 'TimeAwareGPT2'

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_epochs = 10
    early_stopping_patience = 10
    betas = (0.9, 0.99)

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # Example ignored token IDs

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


# --- Main Training Script ---
def main():
    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
    parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
    parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
    parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
    parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
    parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
    parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
    parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
    parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
    parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
    parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable', 'TimeAwareGPT2TemporalConv'], default='TimeAwareGPT2', help='Model architecture to train.')
    args = parser.parse_args()
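
    # Example invocation (hypothetical script name and values, shown for illustration only):
    #   python train.py --model TimeAwareGPT2 --n_layer 12 --n_embd 120 --n_head 12 \
    #       --batch_size 128 --lr_initial 6e-4 --lr_final 6e-5 --betas 0.9 0.99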

    config = TrainConfig()
    config.n_layer = args.n_layer
    config.n_embd = args.n_embd
    config.n_head = args.n_head
    config.max_epoch = args.max_epoch
    config.batch_size = args.batch_size
    config.lr_initial = args.lr_initial
    config.lr_final = args.lr_final
    config.weight_decay = args.weight_decay
    config.warmup_epochs = args.warmup_epochs
    config.early_stopping_patience = args.early_stopping_patience
    config.pdrop = args.pdrop
    config.token_pdrop = args.token_pdrop
    config.betas = tuple(args.betas)
    config.model_name = args.model

    model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
    model_filename = f"best_model_{model_suffix}.pt"
    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"

    # --- 0. Save Configuration ---
    # Include the model class in the filename to distinguish runs across architectures.
    config_filename = f"config_{model_suffix}.json"
    # vars(config) only holds attributes set on this instance (the CLI overrides),
    # so merge in the class-level defaults to record the full configuration.
    config_dict = {k: v for k, v in vars(TrainConfig).items() if not k.startswith('__')}
    config_dict.update(vars(config))
    with open(config_filename, 'w') as f:
        json.dump(config_dict, f, indent=4)
    print(f"Configuration saved to {config_filename}")

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
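    # The .bin files are flat uint32 arrays reshaped to (N, 3) records; the third
    # column holds the event token IDs (used below to infer the vocabulary size),
    # and the first two are presumed to identify the patient and the event time.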

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {config.device}...")
    model_cls = {
        'TimeAwareGPT2': TimeAwareGPT2,
        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
        'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
    }[config.model_name]
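    # All three variants are assumed to share the constructor signature used below;
    # the architecture is selected purely by name via the --model argument.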
    model = model_cls(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop,
    ).to(config.device)

    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)
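    # Note: AdamW applies weight decay decoupled from the gradient update; betas are
    # the exponential decay rates for the first- and second-moment estimates.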

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    # Lists to store losses
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    print("Starting training...")
    for epoch in range(config.max_epoch):
        # --- Learning Rate Scheduling ---
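        # The LR is held at lr_initial for the first warmup_epochs, then cosine-annealed
        # toward lr_final over the remaining epochs:
        #   lr = lr_final + 0.5 * (lr_initial - lr_final) * (1 + cos(pi * progress))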
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0
        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config.max_epoch} [Train]")
        for event_seq, time_seq in pbar:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
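            # Each position predicts the next event token (target_events) and the waiting
            # time until it occurs (difference of consecutive timestamps); CombinedLoss
            # returns the cross-entropy and survival terms for these two targets.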

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1
            pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        avg_train_loss_ce = train_loss_ce_acc / train_steps
        avg_train_loss_surv = train_loss_surv_acc / train_steps
        train_losses_ce.append(avg_train_loss_ce)
        train_losses_surv.append(avg_train_loss_surv)
        train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)

        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0
        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch + 1}/{config.max_epoch} [Val]")
            for event_seq, time_seq in pbar_val:
                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1
                pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

        avg_val_loss_ce = val_loss_ce_acc / val_steps
        avg_val_loss_surv = val_loss_surv_acc / val_steps
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        val_losses_ce.append(avg_val_loss_ce)
        val_losses_surv.append(avg_val_loss_surv)
        val_losses_total.append(total_val_loss)

        print(f"Epoch {epoch + 1} Summary:\n"
              f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
              f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
              f"  Learning Rate: {lr:.6f}")

        # --- Early Stopping Check ---
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            patience_counter = 0
            print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
            torch.save(model.state_dict(), checkpoint_filename)
        else:
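            # Patience only accrues after the warmup epochs, so early stopping cannot
            # trigger while the learning rate is still held at lr_initial.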
            if epoch >= config.warmup_epochs:
                patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")

        if patience_counter >= config.early_stopping_patience:
            print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    # --- Save Best Model at the End ---
    if best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        model.load_state_dict(torch.load(checkpoint_filename))
        print(f"Saving final best model to {model_filename}")
        torch.save(model.state_dict(), model_filename)
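        # Only the state_dict (weights) is written; reloading it later requires
        # re-instantiating the model class with the same architecture settings.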
    else:
        print("\nTraining finished. No best model to save as validation loss never improved.")

    # --- Save losses to a txt file ---
    losses_filename = f"losses_{model_suffix}.txt"
    with open(losses_filename, 'w') as f:
        f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
        for i in range(len(train_losses_total)):
            f.write(f"{i + 1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
    print(f"\nLosses saved to {losses_filename}")

    # --- Plot and Save Loss Curves ---
    num_epochs = len(train_losses_total)
    epochs = range(1, num_epochs + 1)

    plt.figure(figsize=(18, 5))

    # Plot CE Loss
    plt.subplot(1, 3, 1)
    plt.plot(epochs, train_losses_ce, label='Train CE')
    plt.plot(epochs, val_losses_ce, label='Val CE')
    plt.title('Cross-Entropy Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Survival Loss
    plt.subplot(1, 3, 2)
    plt.plot(epochs, train_losses_surv, label='Train Survival')
    plt.plot(epochs, val_losses_surv, label='Val Survival')
    plt.title('Survival Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Total Loss
    plt.subplot(1, 3, 3)
    plt.plot(epochs, train_losses_total, label='Train Total')
    plt.plot(epochs, val_losses_total, label='Val Total')
    plt.title('Total Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('loss_curves.png')
    print("\nLoss curves saved to loss_curves.png")


if __name__ == '__main__':
    main()