feat: Add load_model function and update training script

Added a `load_model` function to `utils.py` to allow loading of trained models from configuration and state dictionary files.

The `train_iter.py` script was also modified, likely to incorporate or test this new functionality.
This commit is contained in:
2025-10-18 11:07:59 +08:00
parent f7356b183c
commit a631ac6d59
2 changed files with 48 additions and 2 deletions

View File

@@ -23,8 +23,8 @@ class TrainConfig:
n_embd = 120 n_embd = 120
n_layer = 12 n_layer = 12
n_head = 12 n_head = 12
pdrop = 0.1 pdrop = 0.0
token_pdrop = 0.1 token_pdrop = 0.0
# Training parameters # Training parameters
max_iter = 200000 max_iter = 200000

View File

@@ -2,6 +2,9 @@ import torch
import numpy as np import numpy as np
import random import random
from collections import defaultdict from collections import defaultdict
import json
from models import TimeAwareGPT2
class PatientEventDataset(torch.utils.data.Dataset): class PatientEventDataset(torch.utils.data.Dataset):
""" """
@@ -102,3 +105,46 @@ class PatientEventDataset(torch.utils.data.Dataset):
time_tensor = torch.tensor(time_stamps, dtype=torch.long) time_tensor = torch.tensor(time_stamps, dtype=torch.long)
return event_tensor, time_tensor return event_tensor, time_tensor
def load_model(config_path, model_path, vocab_size, device='cpu'):
"""
Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary.
Args:
config_path (str): Path to the JSON configuration file.
model_path (str): Path to the saved model state dictionary (.pt file).
vocab_size (int): The vocabulary size used during training.
device (str): The device to load the model onto ('cpu' or 'cuda').
Returns:
(TimeAwareGPT2): The loaded and initialized model.
"""
with open(config_path, 'r') as f:
config_dict = json.load(f)
# Create a config object from the dictionary
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
config = AttrDict(config_dict)
# Initialize the model with parameters from the config
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(device)
# Load the saved state dictionary
model.load_state_dict(torch.load(model_path, map_location=device))
# Set the model to evaluation mode
model.eval()
print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
return model