feat: Add load_model function and update training script
Added a `load_model` function to `utils.py` to allow loading of trained models from configuration and state dictionary files. The `train_iter.py` script was also modified, likely to incorporate or test this new functionality.
This commit is contained in:
@@ -23,8 +23,8 @@ class TrainConfig:
|
||||
n_embd = 120
|
||||
n_layer = 12
|
||||
n_head = 12
|
||||
pdrop = 0.1
|
||||
token_pdrop = 0.1
|
||||
pdrop = 0.0
|
||||
token_pdrop = 0.0
|
||||
|
||||
# Training parameters
|
||||
max_iter = 200000
|
||||
|
46
utils.py
46
utils.py
@@ -2,6 +2,9 @@ import torch
|
||||
import numpy as np
|
||||
import random
|
||||
from collections import defaultdict
|
||||
import json
|
||||
from models import TimeAwareGPT2
|
||||
|
||||
|
||||
class PatientEventDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
@@ -102,3 +105,46 @@ class PatientEventDataset(torch.utils.data.Dataset):
|
||||
time_tensor = torch.tensor(time_stamps, dtype=torch.long)
|
||||
|
||||
return event_tensor, time_tensor
|
||||
|
||||
def load_model(config_path, model_path, vocab_size, device='cpu'):
|
||||
"""
|
||||
Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary.
|
||||
|
||||
Args:
|
||||
config_path (str): Path to the JSON configuration file.
|
||||
model_path (str): Path to the saved model state dictionary (.pt file).
|
||||
vocab_size (int): The vocabulary size used during training.
|
||||
device (str): The device to load the model onto ('cpu' or 'cuda').
|
||||
|
||||
Returns:
|
||||
(TimeAwareGPT2): The loaded and initialized model.
|
||||
"""
|
||||
with open(config_path, 'r') as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Create a config object from the dictionary
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
config = AttrDict(config_dict)
|
||||
|
||||
# Initialize the model with parameters from the config
|
||||
model = TimeAwareGPT2(
|
||||
vocab_size=vocab_size,
|
||||
n_embd=config.n_embd,
|
||||
n_layer=config.n_layer,
|
||||
n_head=config.n_head,
|
||||
pdrop=config.pdrop,
|
||||
token_pdrop=config.token_pdrop
|
||||
).to(device)
|
||||
|
||||
# Load the saved state dictionary
|
||||
model.load_state_dict(torch.load(model_path, map_location=device))
|
||||
|
||||
# Set the model to evaluation mode
|
||||
model.eval()
|
||||
|
||||
print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
|
||||
return model
|
||||
|
Reference in New Issue
Block a user