Rewrite load_model to match train.py: support model variants and infer vocab_size from checkpoint; load state dict robustly
This commit is contained in:
104
utils.py
104
utils.py
@@ -3,7 +3,8 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import json
|
import json
|
||||||
from models import TimeAwareGPT2
|
from models import TimeAwareGPT2, TimeAwareGPT2Learnable
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class PatientEventDataset(torch.utils.data.Dataset):
|
class PatientEventDataset(torch.utils.data.Dataset):
|
||||||
@@ -111,49 +112,100 @@ class PatientEventDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
return event_tensor, time_tensor
|
return event_tensor, time_tensor
|
||||||
|
|
||||||
def load_model(config_path, model_path, vocab_size, device='cpu'):
|
def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'):
|
||||||
"""
|
"""
|
||||||
Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary.
|
Load a trained model based on the training configuration and checkpoint.
|
||||||
|
|
||||||
|
According to train.py, models may be either 'TimeAwareGPT2' or
|
||||||
|
'TimeAwareGPT2Learnable'. This function:
|
||||||
|
- Reads the config JSON to get architecture hyperparameters
|
||||||
|
- Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
|
||||||
|
- Infers vocab_size from the checkpoint if not provided
|
||||||
|
- Loads weights and returns the model in eval mode on the requested device
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_path (str): Path to the JSON configuration file.
|
config_path: Path to the JSON configuration file saved during training.
|
||||||
model_path (str): Path to the saved model state dictionary (.pt file).
|
model_path: Path to the saved model state dict (.pt).
|
||||||
vocab_size (int): The vocabulary size used during training.
|
vocab_size: Optional. If not provided, inferred from checkpoint weight shapes.
|
||||||
device (str): The device to load the model onto ('cpu' or 'cuda').
|
device: 'cpu' or 'cuda'.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(TimeAwareGPT2): The loaded and initialized model.
|
torch.nn.Module: Loaded model ready for inference.
|
||||||
"""
|
"""
|
||||||
|
# 1) Read config
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
|
|
||||||
print(f"Model config: {config_dict}")
|
# Access config entries with attribute-style access while staying tolerant to missing keys
|
||||||
|
|
||||||
# Create a config object from the dictionary
|
|
||||||
class AttrDict(dict):
|
class AttrDict(dict):
|
||||||
def __init__(self, *args, **kwargs):
|
def __getattr__(self, item):
|
||||||
super(AttrDict, self).__init__(*args, **kwargs)
|
try:
|
||||||
self.__dict__ = self
|
return self[item]
|
||||||
|
except KeyError:
|
||||||
|
raise AttributeError(item)
|
||||||
|
|
||||||
config = AttrDict(config_dict)
|
config = AttrDict(config_dict)
|
||||||
|
|
||||||
# Initialize the model with parameters from the config
|
# 2) Decide model class (train.py supports two variants)
|
||||||
model = TimeAwareGPT2(
|
model_name = getattr(config, 'model_name', 'TimeAwareGPT2')
|
||||||
|
model_cls = {
|
||||||
|
'TimeAwareGPT2': TimeAwareGPT2,
|
||||||
|
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
|
||||||
|
}.get(model_name, TimeAwareGPT2)
|
||||||
|
|
||||||
|
# 3) Infer vocab_size from checkpoint if not provided
|
||||||
|
if vocab_size is None:
|
||||||
|
state = torch.load(model_path, map_location='cpu')
|
||||||
|
# Try typical parameter names first
|
||||||
|
if 'wte.weight' in state:
|
||||||
|
vocab_size = state['wte.weight'].shape[0]
|
||||||
|
elif 'head.weight' in state:
|
||||||
|
vocab_size = state['head.weight'].shape[0]
|
||||||
|
else:
|
||||||
|
# Fallback: try to find any (V, D) weight that likely encodes vocab
|
||||||
|
candidate = None
|
||||||
|
for k, v in state.items():
|
||||||
|
if isinstance(v, torch.Tensor) and v.ndim == 2:
|
||||||
|
# Heuristic: the larger dim is probably vocab
|
||||||
|
V = max(v.shape)
|
||||||
|
if candidate is None or V > candidate:
|
||||||
|
candidate = V
|
||||||
|
if candidate is None:
|
||||||
|
raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.")
|
||||||
|
vocab_size = candidate
|
||||||
|
|
||||||
|
# 4) Build model from config
|
||||||
|
# Be tolerant to configs from earlier runs that did not save some fields
|
||||||
|
n_embd = getattr(config, 'n_embd')
|
||||||
|
n_layer = getattr(config, 'n_layer')
|
||||||
|
n_head = getattr(config, 'n_head')
|
||||||
|
pdrop = getattr(config, 'pdrop', 0.1)
|
||||||
|
token_pdrop = getattr(config, 'token_pdrop', 0.1)
|
||||||
|
|
||||||
|
model = model_cls(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
n_embd=config.n_embd,
|
n_embd=n_embd,
|
||||||
n_layer=config.n_layer,
|
n_layer=n_layer,
|
||||||
n_head=config.n_head,
|
n_head=n_head,
|
||||||
pdrop=config.pdrop,
|
pdrop=pdrop,
|
||||||
token_pdrop=config.token_pdrop
|
token_pdrop=token_pdrop,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
# Load the saved state dictionary
|
# 5) Load weights
|
||||||
model.load_state_dict(torch.load(model_path, map_location=device))
|
state_dict = torch.load(model_path, map_location=device)
|
||||||
|
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
print(f"Warning: Missing keys when loading state_dict: {missing}")
|
||||||
|
if unexpected:
|
||||||
|
print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")
|
||||||
|
|
||||||
# Set the model to evaluation mode
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
try:
|
||||||
print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
|
num_params_m = model.get_num_params()
|
||||||
|
print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user