utils.load_model: take only config; infer checkpoint name from config (with legacy fallback) and vocab from checkpoint

This commit is contained in:
2025-10-22 11:39:10 +08:00
parent 92a5bd4a83
commit 6801e5bdbb

View File

@@ -1,10 +1,10 @@
import os
import torch
import numpy as np
import random
from collections import defaultdict
import json
from models import TimeAwareGPT2, TimeAwareGPT2Learnable
from typing import Optional
class PatientEventDataset(torch.utils.data.Dataset):
@@ -112,21 +112,21 @@ class PatientEventDataset(torch.utils.data.Dataset):
return event_tensor, time_tensor
def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'):
def load_model(config_path: str, device: str = 'cpu'):
"""
Load a trained model based on the training configuration and checkpoint.
Load a trained model based on the training configuration, inferring the
checkpoint filename from the configuration.
According to train.py, models may be either 'TimeAwareGPT2' or
'TimeAwareGPT2Learnable'. This function:
- Reads the config JSON to get architecture hyperparameters
- Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
- Infers vocab_size from the checkpoint if not provided
- Infers the checkpoint path from the config values
- Infers vocab_size from the checkpoint
- Loads weights and returns the model in eval mode on the requested device
Args:
config_path: Path to the JSON configuration file saved during training.
model_path: Path to the saved model state dict (.pt).
vocab_size: Optional. If not provided, inferred from checkpoint weight shapes.
device: 'cpu' or 'cuda'.
Returns:
@@ -153,32 +153,47 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
}.get(model_name, TimeAwareGPT2)
# 3) Infer vocab_size from checkpoint if not provided
if vocab_size is None:
state = torch.load(model_path, map_location='cpu')
# Try typical parameter names first
if 'wte.weight' in state:
vocab_size = state['wte.weight'].shape[0]
elif 'head.weight' in state:
vocab_size = state['head.weight'].shape[0]
# 3) Infer checkpoint filename from config
n_embd = getattr(config, 'n_embd')
n_layer = getattr(config, 'n_layer')
n_head = getattr(config, 'n_head')
# Newer naming (includes model_name) used by train.py when model_name is present
suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
ckpt_with_model = f"best_model_{suffix_with_model}.pt"
# Older naming (without model_name) matches existing repo files
suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
ckpt_legacy = f"best_model_{suffix_legacy}.pt"
# Prefer file that exists on disk
if os.path.exists(ckpt_with_model):
model_path = ckpt_with_model
elif os.path.exists(ckpt_legacy):
model_path = ckpt_legacy
else:
# Fall back to including model_name; if not present in config earlier, user may still have saved this way
model_path = ckpt_with_model
print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}")
# 4) Infer vocab_size from checkpoint
state_preview = torch.load(model_path, map_location='cpu')
if 'wte.weight' in state_preview:
vocab_size = state_preview['wte.weight'].shape[0]
elif 'head.weight' in state_preview:
vocab_size = state_preview['head.weight'].shape[0]
else:
# Fallback: try to find any (V, D) weight that likely encodes vocab
candidate = None
for k, v in state.items():
for k, v in state_preview.items():
if isinstance(v, torch.Tensor) and v.ndim == 2:
# Heuristic: the larger dim is probably vocab
V = max(v.shape)
if candidate is None or V > candidate:
candidate = V
if candidate is None:
raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.")
raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
vocab_size = candidate
# 4) Build model from config
# Be tolerant to configs from earlier runs that did not save some fields
n_embd = getattr(config, 'n_embd')
n_layer = getattr(config, 'n_layer')
n_head = getattr(config, 'n_head')
# 5) Build model from config (tolerant to missing fields)
pdrop = getattr(config, 'pdrop', 0.1)
token_pdrop = getattr(config, 'token_pdrop', 0.1)
@@ -191,7 +206,7 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
token_pdrop=token_pdrop,
).to(device)
# 5) Load weights
# 6) Load weights
state_dict = torch.load(model_path, map_location=device)
missing, unexpected = model.load_state_dict(state_dict, strict=False)