From a631ac6d5970b044896b88fd5151b9a7ee9d4297 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Sat, 18 Oct 2025 11:07:59 +0800
Subject: [PATCH] feat: Add load_model function and update training script

Added a `load_model` function to `utils.py` to allow loading of trained
models from configuration and state dictionary files. The function
checks that the required hyperparameter keys are present in the config
and loads the checkpoint with `weights_only=True`.

The `train_iter.py` script was also modified: `pdrop` and `token_pdrop`
in `TrainConfig` are changed from 0.1 to 0.0.
---
Note: the utils.py hunk was revised during review; the post-image blob
id on its index line is the one from the original submission.

 train_iter.py |  4 ++--
 utils.py      | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/train_iter.py b/train_iter.py
index fa04ecb..d64d15a 100644
--- a/train_iter.py
+++ b/train_iter.py
@@ -23,8 +23,8 @@ class TrainConfig:
     n_embd = 120
     n_layer = 12
     n_head = 12
-    pdrop = 0.1
-    token_pdrop = 0.1
+    pdrop = 0.0
+    token_pdrop = 0.0
 
     # Training parameters
     max_iter = 200000
diff --git a/utils.py b/utils.py
index 5128dc8..a3954d9 100644
--- a/utils.py
+++ b/utils.py
@@ -2,6 +2,9 @@ import torch
 import numpy as np
 import random
 from collections import defaultdict
+import json
+from models import TimeAwareGPT2
+
 
 class PatientEventDataset(torch.utils.data.Dataset):
     """
@@ -102,3 +105,53 @@ class PatientEventDataset(torch.utils.data.Dataset):
         time_tensor = torch.tensor(time_stamps, dtype=torch.long)
 
         return event_tensor, time_tensor
+
+def load_model(config_path, model_path, vocab_size, device='cpu'):
+    """
+    Load a trained TimeAwareGPT2 model from a JSON config and a saved state dict.
+
+    The model is built from the hyperparameters in the config file, moved to
+    `device`, populated with the saved weights and put in evaluation mode.
+
+    Args:
+        config_path (str): Path to the JSON configuration file. It must
+            contain `n_embd`, `n_layer`, `n_head`, `pdrop` and
+            `token_pdrop`; any other keys are ignored.
+        model_path (str): Path to the saved model state dictionary (.pt file).
+        vocab_size (int): The vocabulary size used during training.
+        device (str): The device to load the model onto ('cpu' or 'cuda').
+
+    Returns:
+        (TimeAwareGPT2): The loaded model, in evaluation mode.
+
+    Raises:
+        KeyError: If the configuration lacks a required hyperparameter.
+    """
+    required_keys = ('n_embd', 'n_layer', 'n_head', 'pdrop', 'token_pdrop')
+
+    with open(config_path, 'r', encoding='utf-8') as f:
+        config_dict = json.load(f)
+
+    # Report every missing key at once instead of failing on the first
+    # one that happens to be accessed.
+    missing = [key for key in required_keys if key not in config_dict]
+    if missing:
+        raise KeyError(
+            f"Config file {config_path} is missing required keys: {missing}"
+        )
+
+    # Initialize the model with parameters from the config
+    model_kwargs = {key: config_dict[key] for key in required_keys}
+    model = TimeAwareGPT2(vocab_size=vocab_size, **model_kwargs).to(device)
+
+    # weights_only=True limits unpickling to tensors and plain containers,
+    # guarding against code execution from an untrusted checkpoint file.
+    state_dict = torch.load(model_path, map_location=device, weights_only=True)
+    model.load_state_dict(state_dict)
+
+    # Set the model to evaluation mode
+    model.eval()
+
+    # NOTE(review): the "M" suffix assumes get_num_params() returns millions.
+    print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
+    return model