utils.load_model: take only config; infer checkpoint name from config (with legacy fallback) and vocab from checkpoint
This commit is contained in:
75
utils.py
75
utils.py
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
from collections import defaultdict
|
||||
import json
|
||||
from models import TimeAwareGPT2, TimeAwareGPT2Learnable
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class PatientEventDataset(torch.utils.data.Dataset):
|
||||
@@ -112,21 +112,21 @@ class PatientEventDataset(torch.utils.data.Dataset):
|
||||
|
||||
return event_tensor, time_tensor
|
||||
|
||||
def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'):
|
||||
def load_model(config_path: str, device: str = 'cpu'):
|
||||
"""
|
||||
Load a trained model based on the training configuration and checkpoint.
|
||||
Load a trained model based on the training configuration, inferring the
|
||||
checkpoint filename from the configuration.
|
||||
|
||||
According to train.py, models may be either 'TimeAwareGPT2' or
|
||||
'TimeAwareGPT2Learnable'. This function:
|
||||
- Reads the config JSON to get architecture hyperparameters
|
||||
- Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
|
||||
- Infers vocab_size from the checkpoint if not provided
|
||||
- Infers the checkpoint path from the config values
|
||||
- Infers vocab_size from the checkpoint
|
||||
- Loads weights and returns the model in eval mode on the requested device
|
||||
|
||||
Args:
|
||||
config_path: Path to the JSON configuration file saved during training.
|
||||
model_path: Path to the saved model state dict (.pt).
|
||||
vocab_size: Optional. If not provided, inferred from checkpoint weight shapes.
|
||||
device: 'cpu' or 'cuda'.
|
||||
|
||||
Returns:
|
||||
@@ -153,32 +153,47 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
|
||||
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
|
||||
}.get(model_name, TimeAwareGPT2)
|
||||
|
||||
# 3) Infer vocab_size from checkpoint if not provided
|
||||
if vocab_size is None:
|
||||
state = torch.load(model_path, map_location='cpu')
|
||||
# Try typical parameter names first
|
||||
if 'wte.weight' in state:
|
||||
vocab_size = state['wte.weight'].shape[0]
|
||||
elif 'head.weight' in state:
|
||||
vocab_size = state['head.weight'].shape[0]
|
||||
else:
|
||||
# Fallback: try to find any (V, D) weight that likely encodes vocab
|
||||
candidate = None
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor) and v.ndim == 2:
|
||||
# Heuristic: the larger dim is probably vocab
|
||||
V = max(v.shape)
|
||||
if candidate is None or V > candidate:
|
||||
candidate = V
|
||||
if candidate is None:
|
||||
raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.")
|
||||
vocab_size = candidate
|
||||
|
||||
# 4) Build model from config
|
||||
# Be tolerant to configs from earlier runs that did not save some fields
|
||||
# 3) Infer checkpoint filename from config
|
||||
n_embd = getattr(config, 'n_embd')
|
||||
n_layer = getattr(config, 'n_layer')
|
||||
n_head = getattr(config, 'n_head')
|
||||
|
||||
# Newer naming (includes model_name) used by train.py when model_name is present
|
||||
suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
|
||||
ckpt_with_model = f"best_model_{suffix_with_model}.pt"
|
||||
|
||||
# Older naming (without model_name) matches existing repo files
|
||||
suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
|
||||
ckpt_legacy = f"best_model_{suffix_legacy}.pt"
|
||||
|
||||
# Prefer file that exists on disk
|
||||
if os.path.exists(ckpt_with_model):
|
||||
model_path = ckpt_with_model
|
||||
elif os.path.exists(ckpt_legacy):
|
||||
model_path = ckpt_legacy
|
||||
else:
|
||||
# Fall back to including model_name; if not present in config earlier, user may still have saved this way
|
||||
model_path = ckpt_with_model
|
||||
print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}")
|
||||
|
||||
# 4) Infer vocab_size from checkpoint
|
||||
state_preview = torch.load(model_path, map_location='cpu')
|
||||
if 'wte.weight' in state_preview:
|
||||
vocab_size = state_preview['wte.weight'].shape[0]
|
||||
elif 'head.weight' in state_preview:
|
||||
vocab_size = state_preview['head.weight'].shape[0]
|
||||
else:
|
||||
candidate = None
|
||||
for k, v in state_preview.items():
|
||||
if isinstance(v, torch.Tensor) and v.ndim == 2:
|
||||
V = max(v.shape)
|
||||
if candidate is None or V > candidate:
|
||||
candidate = V
|
||||
if candidate is None:
|
||||
raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
|
||||
vocab_size = candidate
|
||||
|
||||
# 5) Build model from config (tolerant to missing fields)
|
||||
pdrop = getattr(config, 'pdrop', 0.1)
|
||||
token_pdrop = getattr(config, 'token_pdrop', 0.1)
|
||||
|
||||
@@ -191,7 +206,7 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
|
||||
token_pdrop=token_pdrop,
|
||||
).to(device)
|
||||
|
||||
# 5) Load weights
|
||||
# 6) Load weights
|
||||
state_dict = torch.load(model_path, map_location=device)
|
||||
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
Reference in New Issue
Block a user