Rewrite load_model to match train.py: support model variants and infer vocab_size from checkpoint; load state dict robustly

This commit is contained in:
2025-10-22 11:30:59 +08:00
parent bd88daa8c2
commit dfdf64da9a

104
utils.py
View File

@@ -3,7 +3,8 @@ import numpy as np
import random
from collections import defaultdict
import json
from models import TimeAwareGPT2
from models import TimeAwareGPT2, TimeAwareGPT2Learnable
from typing import Optional
class PatientEventDataset(torch.utils.data.Dataset):
@@ -111,49 +112,100 @@ class PatientEventDataset(torch.utils.data.Dataset):
return event_tensor, time_tensor
def load_model(config_path, model_path, vocab_size, device='cpu'):
def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'):
"""
Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary.
Load a trained model based on the training configuration and checkpoint.
According to train.py, models may be either 'TimeAwareGPT2' or
'TimeAwareGPT2Learnable'. This function:
- Reads the config JSON to get architecture hyperparameters
- Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
- Infers vocab_size from the checkpoint if not provided
- Loads weights and returns the model in eval mode on the requested device
Args:
config_path (str): Path to the JSON configuration file.
model_path (str): Path to the saved model state dictionary (.pt file).
vocab_size (int): The vocabulary size used during training.
device (str): The device to load the model onto ('cpu' or 'cuda').
config_path: Path to the JSON configuration file saved during training.
model_path: Path to the saved model state dict (.pt).
vocab_size: Optional. If not provided, inferred from checkpoint weight shapes.
device: 'cpu' or 'cuda'.
Returns:
(TimeAwareGPT2): The loaded and initialized model.
torch.nn.Module: Loaded model ready for inference.
"""
# 1) Read config
with open(config_path, 'r') as f:
config_dict = json.load(f)
print(f"Model config: {config_dict}")
# Create a config object from the dictionary
# Access config entries with attribute-style access while staying tolerant to missing keys
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(item)
config = AttrDict(config_dict)
# Initialize the model with parameters from the config
model = TimeAwareGPT2(
# 2) Decide model class (train.py supports two variants)
model_name = getattr(config, 'model_name', 'TimeAwareGPT2')
model_cls = {
'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
}.get(model_name, TimeAwareGPT2)
# 3) Infer vocab_size from checkpoint if not provided
if vocab_size is None:
state = torch.load(model_path, map_location='cpu')
# Try typical parameter names first
if 'wte.weight' in state:
vocab_size = state['wte.weight'].shape[0]
elif 'head.weight' in state:
vocab_size = state['head.weight'].shape[0]
else:
# Fallback: try to find any (V, D) weight that likely encodes vocab
candidate = None
for k, v in state.items():
if isinstance(v, torch.Tensor) and v.ndim == 2:
# Heuristic: the larger dim is probably vocab
V = max(v.shape)
if candidate is None or V > candidate:
candidate = V
if candidate is None:
raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.")
vocab_size = candidate
# 4) Build model from config
# Be tolerant to configs from earlier runs that did not save some fields
n_embd = getattr(config, 'n_embd')
n_layer = getattr(config, 'n_layer')
n_head = getattr(config, 'n_head')
pdrop = getattr(config, 'pdrop', 0.1)
token_pdrop = getattr(config, 'token_pdrop', 0.1)
model = model_cls(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
n_embd=n_embd,
n_layer=n_layer,
n_head=n_head,
pdrop=pdrop,
token_pdrop=token_pdrop,
).to(device)
# Load the saved state dictionary
model.load_state_dict(torch.load(model_path, map_location=device))
# 5) Load weights
state_dict = torch.load(model_path, map_location=device)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
print(f"Warning: Missing keys when loading state_dict: {missing}")
if unexpected:
print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")
# Set the model to evaluation mode
model.eval()
print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
try:
num_params_m = model.get_num_params()
print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.")
except Exception:
pass
return model