utils.load_model: take only config; infer checkpoint name from config (with legacy fallback) and vocab from checkpoint
 utils.py | 75
@@ -1,10 +1,10 @@
+import os
 import torch
 import numpy as np
 import random
 from collections import defaultdict
 import json
 from models import TimeAwareGPT2, TimeAwareGPT2Learnable
-from typing import Optional
 
 
 class PatientEventDataset(torch.utils.data.Dataset):
@@ -112,21 +112,21 @@ class PatientEventDataset(torch.utils.data.Dataset):
 
         return event_tensor, time_tensor
 
-def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'):
+def load_model(config_path: str, device: str = 'cpu'):
     """
-    Load a trained model based on the training configuration and checkpoint.
+    Load a trained model based on the training configuration, inferring the
+    checkpoint filename from the configuration.
 
     According to train.py, models may be either 'TimeAwareGPT2' or
     'TimeAwareGPT2Learnable'. This function:
     - Reads the config JSON to get architecture hyperparameters
     - Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
-    - Infers vocab_size from the checkpoint if not provided
+    - Infers the checkpoint path from the config values
+    - Infers vocab_size from the checkpoint
     - Loads weights and returns the model in eval mode on the requested device
 
     Args:
         config_path: Path to the JSON configuration file saved during training.
-        model_path: Path to the saved model state dict (.pt).
-        vocab_size: Optional. If not provided, inferred from checkpoint weight shapes.
        device: 'cpu' or 'cuda'.
 
     Returns:
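
A minimal sketch of how a caller changes under the new signature; the config filename below is hypothetical, and it assumes load_model returns just the model object:

    # Old call: the checkpoint path (and optionally vocab_size) had to be passed in.
    # model = load_model('config.json', 'best_model_n_embd_256_n_layer_8_n_head_8.pt')

    # New call: only the config is needed; the checkpoint name and vocab_size are inferred.
    model = load_model('config.json', device='cpu')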
@@ -153,32 +153,47 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
         'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
     }.get(model_name, TimeAwareGPT2)
 
-    # 3) Infer vocab_size from checkpoint if not provided
-    if vocab_size is None:
-        state = torch.load(model_path, map_location='cpu')
-        # Try typical parameter names first
-        if 'wte.weight' in state:
-            vocab_size = state['wte.weight'].shape[0]
-        elif 'head.weight' in state:
-            vocab_size = state['head.weight'].shape[0]
-        else:
-            # Fallback: try to find any (V, D) weight that likely encodes vocab
-            candidate = None
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and v.ndim == 2:
-                    # Heuristic: the larger dim is probably vocab
-                    V = max(v.shape)
-                    if candidate is None or V > candidate:
-                        candidate = V
-            if candidate is None:
-                raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.")
-            vocab_size = candidate
-    # 4) Build model from config
-    # Be tolerant to configs from earlier runs that did not save some fields
+    # 3) Infer checkpoint filename from config
     n_embd = getattr(config, 'n_embd')
     n_layer = getattr(config, 'n_layer')
     n_head = getattr(config, 'n_head')
 
+    # Newer naming (includes model_name) used by train.py when model_name is present
+    suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
+    ckpt_with_model = f"best_model_{suffix_with_model}.pt"
+
+    # Older naming (without model_name) matches existing repo files
+    suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
+    ckpt_legacy = f"best_model_{suffix_legacy}.pt"
+
+    # Prefer file that exists on disk
+    if os.path.exists(ckpt_with_model):
+        model_path = ckpt_with_model
+    elif os.path.exists(ckpt_legacy):
+        model_path = ckpt_legacy
+    else:
+        # Fall back to including model_name; if not present in config earlier, user may still have saved this way
+        model_path = ckpt_with_model
+        print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}")
+
+    # 4) Infer vocab_size from checkpoint
+    state_preview = torch.load(model_path, map_location='cpu')
+    if 'wte.weight' in state_preview:
+        vocab_size = state_preview['wte.weight'].shape[0]
+    elif 'head.weight' in state_preview:
+        vocab_size = state_preview['head.weight'].shape[0]
+    else:
+        candidate = None
+        for k, v in state_preview.items():
+            if isinstance(v, torch.Tensor) and v.ndim == 2:
+                V = max(v.shape)
+                if candidate is None or V > candidate:
+                    candidate = V
+        if candidate is None:
+            raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
+        vocab_size = candidate
+
+    # 5) Build model from config (tolerant to missing fields)
     pdrop = getattr(config, 'pdrop', 0.1)
     token_pdrop = getattr(config, 'token_pdrop', 0.1)
 
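
To make the two naming schemes concrete, here is a small sketch with hypothetical hyperparameters (n_embd=256, n_layer=8, n_head=8 are illustrative values, not taken from the repo):

    model_name, n_embd, n_layer, n_head = 'TimeAwareGPT2', 256, 8, 8

    # Newer scheme, including model_name:
    print(f"best_model_{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}.pt")
    # -> best_model_TimeAwareGPT2_n_embd_256_n_layer_8_n_head_8.pt

    # Legacy scheme, matching checkpoints saved before model_name existed:
    print(f"best_model_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}.pt")
    # -> best_model_n_embd_256_n_layer_8_n_head_8.pt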
@@ -191,7 +206,7 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
         token_pdrop=token_pdrop,
     ).to(device)
 
-    # 5) Load weights
+    # 6) Load weights
     state_dict = torch.load(model_path, map_location=device)
     missing, unexpected = model.load_state_dict(state_dict, strict=False)
 
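
Because the weights are loaded with strict=False, key mismatches do not raise; a caller who wants them surfaced can inspect the lists that torch.nn.Module.load_state_dict returns (the warning wording below is illustrative, not part of the commit):

    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Warning: keys missing from checkpoint: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Warning: unexpected keys in checkpoint: {result.unexpected_keys}")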