Rewrite load_model to match train.py: support model variants and infer vocab_size from checkpoint; load state dict robustly

2025-10-22 11:30:59 +08:00
parent bd88daa8c2
commit dfdf64da9a

utils.py (104 lines changed)

@@ -3,7 +3,8 @@ import numpy as np
 import random
 from collections import defaultdict
 import json
-from models import TimeAwareGPT2
+from models import TimeAwareGPT2, TimeAwareGPT2Learnable
+from typing import Optional


 class PatientEventDataset(torch.utils.data.Dataset):
@@ -111,49 +112,100 @@ class PatientEventDataset(torch.utils.data.Dataset):
         return event_tensor, time_tensor

-def load_model(config_path, model_path, vocab_size, device='cpu'):
+def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'):
     """
-    Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary.
+    Load a trained model based on the training configuration and checkpoint.
+
+    According to train.py, models may be either 'TimeAwareGPT2' or
+    'TimeAwareGPT2Learnable'. This function:
+    - Reads the config JSON to get architecture hyperparameters
+    - Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
+    - Infers vocab_size from the checkpoint if not provided
+    - Loads weights and returns the model in eval mode on the requested device

     Args:
-        config_path (str): Path to the JSON configuration file.
-        model_path (str): Path to the saved model state dictionary (.pt file).
-        vocab_size (int): The vocabulary size used during training.
-        device (str): The device to load the model onto ('cpu' or 'cuda').
+        config_path: Path to the JSON configuration file saved during training.
+        model_path: Path to the saved model state dict (.pt).
+        vocab_size: Optional. If not provided, inferred from checkpoint weight shapes.
+        device: 'cpu' or 'cuda'.

     Returns:
-        (TimeAwareGPT2): The loaded and initialized model.
+        torch.nn.Module: Loaded model ready for inference.
     """
+    # 1) Read config
     with open(config_path, 'r') as f:
         config_dict = json.load(f)

-    print(f"Model config: {config_dict}")
-
-    # Create a config object from the dictionary
+    # Access config entries with attribute-style access while staying tolerant to missing keys
     class AttrDict(dict):
-        def __init__(self, *args, **kwargs):
-            super(AttrDict, self).__init__(*args, **kwargs)
-            self.__dict__ = self
+        def __getattr__(self, item):
+            try:
+                return self[item]
+            except KeyError:
+                raise AttributeError(item)
+
     config = AttrDict(config_dict)

-    # Initialize the model with parameters from the config
-    model = TimeAwareGPT2(
+    # 2) Decide model class (train.py supports two variants)
+    model_name = getattr(config, 'model_name', 'TimeAwareGPT2')
+    model_cls = {
+        'TimeAwareGPT2': TimeAwareGPT2,
+        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
+    }.get(model_name, TimeAwareGPT2)
+
+    # 3) Infer vocab_size from checkpoint if not provided
+    if vocab_size is None:
+        state = torch.load(model_path, map_location='cpu')
+        # Try typical parameter names first
+        if 'wte.weight' in state:
+            vocab_size = state['wte.weight'].shape[0]
+        elif 'head.weight' in state:
+            vocab_size = state['head.weight'].shape[0]
+        else:
+            # Fallback: try to find any (V, D) weight that likely encodes vocab
+            candidate = None
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor) and v.ndim == 2:
+                    # Heuristic: the larger dim is probably vocab
+                    V = max(v.shape)
+                    if candidate is None or V > candidate:
+                        candidate = V
+            if candidate is None:
+                raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.")
+            vocab_size = candidate
+
+    # 4) Build model from config
+    # Be tolerant to configs from earlier runs that did not save some fields
+    n_embd = getattr(config, 'n_embd')
+    n_layer = getattr(config, 'n_layer')
+    n_head = getattr(config, 'n_head')
+    pdrop = getattr(config, 'pdrop', 0.1)
+    token_pdrop = getattr(config, 'token_pdrop', 0.1)
+
+    model = model_cls(
         vocab_size=vocab_size,
-        n_embd=config.n_embd,
-        n_layer=config.n_layer,
-        n_head=config.n_head,
-        pdrop=config.pdrop,
-        token_pdrop=config.token_pdrop
+        n_embd=n_embd,
+        n_layer=n_layer,
+        n_head=n_head,
+        pdrop=pdrop,
+        token_pdrop=token_pdrop,
     ).to(device)

-    # Load the saved state dictionary
-    model.load_state_dict(torch.load(model_path, map_location=device))
+    # 5) Load weights
+    state_dict = torch.load(model_path, map_location=device)
+    missing, unexpected = model.load_state_dict(state_dict, strict=False)
+    if missing:
+        print(f"Warning: Missing keys when loading state_dict: {missing}")
+    if unexpected:
+        print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")

-    # Set the model to evaluation mode
     model.eval()

-    print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
+    try:
+        num_params_m = model.get_num_params()
+        print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.")
+    except Exception:
+        pass

     return model
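
For reference, a minimal usage sketch of the new signature (file names and the explicit vocab_size value below are hypothetical; the config JSON is assumed to carry the fields train.py saves: model_name, n_embd, n_layer, n_head, pdrop, token_pdrop):

    # Usage sketch for the reworked load_model (hypothetical paths).
    from utils import load_model

    # vocab_size omitted: it is inferred from the checkpoint, trying
    # wte.weight / head.weight before falling back to the shape heuristic.
    model = load_model('model_config.json', 'model.pt', device='cpu')

    # Passing vocab_size explicitly (illustrative value) bypasses inference.
    model = load_model('model_config.json', 'model.pt', vocab_size=1024, device='cpu')

Note on the strict=False choice: a checkpoint whose key set differs from the instantiated variant loads without raising, with the missing/unexpected keys surfaced as printed warnings; tensors that exist under the same name but with mismatched shapes still raise.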