feat: Add load_model function and update training script
Added a `load_model` function to `utils.py` that rebuilds a trained `TimeAwareGPT2` model from a JSON configuration file and a saved state dictionary. The training script `train_iter.py` was also updated: `pdrop` and `token_pdrop` in `TrainConfig` were lowered from 0.1 to 0.0, disabling dropout.
train_iter.py

@@ -23,8 +23,8 @@ class TrainConfig:
     n_embd = 120
     n_layer = 12
     n_head = 12
-    pdrop = 0.1
-    token_pdrop = 0.1
+    pdrop = 0.0
+    token_pdrop = 0.0
 
     # Training parameters
     max_iter = 200000
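The new loader in `utils.py` (below) expects the training configuration to be serialized as a JSON file whose keys match these `TrainConfig` attributes (`n_embd`, `n_layer`, `n_head`, `pdrop`, `token_pdrop`). The saving side is not part of this diff, so the following is only a minimal sketch; the `save_config` helper and the `checkpoints/config.json` path are hypothetical.

# Minimal sketch, not part of this commit: serializing TrainConfig to JSON so that
# utils.load_model can later rebuild the model. Helper name and output path are hypothetical.
import json

def save_config(config_cls, path='checkpoints/config.json'):
    # Keep only plain, non-callable class attributes (n_embd, n_layer, pdrop, ...).
    config_dict = {
        k: v for k, v in vars(config_cls).items()
        if not k.startswith('__') and not callable(v)
    }
    with open(path, 'w') as f:
        json.dump(config_dict, f, indent=2)

# e.g. save_config(TrainConfig) in train_iter.py would write something like
# {"n_embd": 120, "n_layer": 12, "n_head": 12, "pdrop": 0.0, "token_pdrop": 0.0, "max_iter": 200000, ...}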
utils.py (46 lines changed)
@@ -2,6 +2,9 @@ import torch
 import numpy as np
 import random
 from collections import defaultdict
+import json
+from models import TimeAwareGPT2
+
 
 class PatientEventDataset(torch.utils.data.Dataset):
     """
@@ -102,3 +105,46 @@ class PatientEventDataset(torch.utils.data.Dataset):
         time_tensor = torch.tensor(time_stamps, dtype=torch.long)
 
         return event_tensor, time_tensor
+
+def load_model(config_path, model_path, vocab_size, device='cpu'):
+    """
+    Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary.
+
+    Args:
+        config_path (str): Path to the JSON configuration file.
+        model_path (str): Path to the saved model state dictionary (.pt file).
+        vocab_size (int): The vocabulary size used during training.
+        device (str): The device to load the model onto ('cpu' or 'cuda').
+
+    Returns:
+        (TimeAwareGPT2): The loaded and initialized model.
+    """
+    with open(config_path, 'r') as f:
+        config_dict = json.load(f)
+
+    # Create a config object from the dictionary
+    class AttrDict(dict):
+        def __init__(self, *args, **kwargs):
+            super(AttrDict, self).__init__(*args, **kwargs)
+            self.__dict__ = self
+
+    config = AttrDict(config_dict)
+
+    # Initialize the model with parameters from the config
+    model = TimeAwareGPT2(
+        vocab_size=vocab_size,
+        n_embd=config.n_embd,
+        n_layer=config.n_layer,
+        n_head=config.n_head,
+        pdrop=config.pdrop,
+        token_pdrop=config.token_pdrop
+    ).to(device)
+
+    # Load the saved state dictionary
+    model.load_state_dict(torch.load(model_path, map_location=device))
+
+    # Set the model to evaluation mode
+    model.eval()
+
+    print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
+    return model
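For illustration, a minimal usage sketch of the new function; the paths and vocabulary size below are hypothetical and not taken from this commit.

# Usage sketch (hypothetical paths and vocab_size; not part of this commit).
import torch

from utils import load_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model(
    config_path='checkpoints/config.json',
    model_path='checkpoints/model.pt',
    vocab_size=1270,   # must match the vocabulary used during training
    device=device,
)

# load_model returns the model in eval mode, so inference only needs no_grad:
with torch.no_grad():
    pass  # run the event/time tensors from PatientEventDataset through the model

The `AttrDict` wrapper inside `load_model` simply exposes the JSON keys as attributes (`config.n_embd` rather than `config_dict['n_embd']`), mirroring how the `TrainConfig` attributes are accessed in `train_iter.py`.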