utils.load_model: take only config; infer checkpoint name from config (with legacy fallback) and vocab from checkpoint

This commit is contained in:
2025-10-22 11:39:10 +08:00
parent 92a5bd4a83
commit 6801e5bdbb

View File

@@ -1,10 +1,10 @@
import os
import torch import torch
import numpy as np import numpy as np
import random import random
from collections import defaultdict from collections import defaultdict
import json import json
from models import TimeAwareGPT2, TimeAwareGPT2Learnable from models import TimeAwareGPT2, TimeAwareGPT2Learnable
from typing import Optional
class PatientEventDataset(torch.utils.data.Dataset): class PatientEventDataset(torch.utils.data.Dataset):
@@ -112,21 +112,21 @@ class PatientEventDataset(torch.utils.data.Dataset):
return event_tensor, time_tensor return event_tensor, time_tensor
def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'): def load_model(config_path: str, device: str = 'cpu'):
""" """
Load a trained model based on the training configuration and checkpoint. Load a trained model based on the training configuration, inferring the
checkpoint filename from the configuration.
According to train.py, models may be either 'TimeAwareGPT2' or According to train.py, models may be either 'TimeAwareGPT2' or
'TimeAwareGPT2Learnable'. This function: 'TimeAwareGPT2Learnable'. This function:
- Reads the config JSON to get architecture hyperparameters - Reads the config JSON to get architecture hyperparameters
- Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent) - Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
- Infers vocab_size from the checkpoint if not provided - Infers the checkpoint path from the config values
- Infers vocab_size from the checkpoint
- Loads weights and returns the model in eval mode on the requested device - Loads weights and returns the model in eval mode on the requested device
Args: Args:
config_path: Path to the JSON configuration file saved during training. config_path: Path to the JSON configuration file saved during training.
model_path: Path to the saved model state dict (.pt).
vocab_size: Optional. If not provided, inferred from checkpoint weight shapes.
device: 'cpu' or 'cuda'. device: 'cpu' or 'cuda'.
Returns: Returns:
@@ -153,32 +153,47 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable, 'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
}.get(model_name, TimeAwareGPT2) }.get(model_name, TimeAwareGPT2)
# 3) Infer vocab_size from checkpoint if not provided # 3) Infer checkpoint filename from config
if vocab_size is None: n_embd = getattr(config, 'n_embd')
state = torch.load(model_path, map_location='cpu') n_layer = getattr(config, 'n_layer')
# Try typical parameter names first n_head = getattr(config, 'n_head')
if 'wte.weight' in state:
vocab_size = state['wte.weight'].shape[0] # Newer naming (includes model_name) used by train.py when model_name is present
elif 'head.weight' in state: suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
vocab_size = state['head.weight'].shape[0] ckpt_with_model = f"best_model_{suffix_with_model}.pt"
# Older naming (without model_name) matches existing repo files
suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
ckpt_legacy = f"best_model_{suffix_legacy}.pt"
# Prefer file that exists on disk
if os.path.exists(ckpt_with_model):
model_path = ckpt_with_model
elif os.path.exists(ckpt_legacy):
model_path = ckpt_legacy
else:
# Fall back to including model_name; if not present in config earlier, user may still have saved this way
model_path = ckpt_with_model
print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}")
# 4) Infer vocab_size from checkpoint
state_preview = torch.load(model_path, map_location='cpu')
if 'wte.weight' in state_preview:
vocab_size = state_preview['wte.weight'].shape[0]
elif 'head.weight' in state_preview:
vocab_size = state_preview['head.weight'].shape[0]
else: else:
# Fallback: try to find any (V, D) weight that likely encodes vocab
candidate = None candidate = None
for k, v in state.items(): for k, v in state_preview.items():
if isinstance(v, torch.Tensor) and v.ndim == 2: if isinstance(v, torch.Tensor) and v.ndim == 2:
# Heuristic: the larger dim is probably vocab
V = max(v.shape) V = max(v.shape)
if candidate is None or V > candidate: if candidate is None or V > candidate:
candidate = V candidate = V
if candidate is None: if candidate is None:
raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.") raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
vocab_size = candidate vocab_size = candidate
# 4) Build model from config # 5) Build model from config (tolerant to missing fields)
# Be tolerant to configs from earlier runs that did not save some fields
n_embd = getattr(config, 'n_embd')
n_layer = getattr(config, 'n_layer')
n_head = getattr(config, 'n_head')
pdrop = getattr(config, 'pdrop', 0.1) pdrop = getattr(config, 'pdrop', 0.1)
token_pdrop = getattr(config, 'token_pdrop', 0.1) token_pdrop = getattr(config, 'token_pdrop', 0.1)
@@ -191,7 +206,7 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
token_pdrop=token_pdrop, token_pdrop=token_pdrop,
).to(device) ).to(device)
# 5) Load weights # 6) Load weights
state_dict = torch.load(model_path, map_location=device) state_dict = torch.load(model_path, map_location=device)
missing, unexpected = model.load_state_dict(state_dict, strict=False) missing, unexpected = model.load_state_dict(state_dict, strict=False)