From 6801e5bdbb26e11809b9fc3dedbaa35f98f0074b Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Wed, 22 Oct 2025 11:39:10 +0800
Subject: [PATCH] utils.load_model: take only config; infer checkpoint name
 from config (with legacy fallback) and vocab from checkpoint

---
 utils.py | 75 +++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 45 insertions(+), 30 deletions(-)

diff --git a/utils.py b/utils.py
index 5402dd8..f63b37d 100644
--- a/utils.py
+++ b/utils.py
@@ -1,10 +1,10 @@
+import os
 import torch
 import numpy as np
 import random
 from collections import defaultdict
 import json
 from models import TimeAwareGPT2, TimeAwareGPT2Learnable
-from typing import Optional
 
 
 class PatientEventDataset(torch.utils.data.Dataset):
@@ -112,21 +112,21 @@
         return event_tensor, time_tensor
 
 
-def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'):
+def load_model(config_path: str, device: str = 'cpu'):
     """
-    Load a trained model based on the training configuration and checkpoint.
+    Load a trained model based on the training configuration, inferring the
+    checkpoint filename from the configuration.
 
     According to train.py, models may be either 'TimeAwareGPT2' or
     'TimeAwareGPT2Learnable'. This function:
     - Reads the config JSON to get architecture hyperparameters
     - Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
-    - Infers vocab_size from the checkpoint if not provided
+    - Infers the checkpoint path from the config values
+    - Infers vocab_size from the checkpoint
     - Loads weights and returns the model in eval mode on the requested device
 
     Args:
         config_path: Path to the JSON configuration file saved during training.
-        model_path: Path to the saved model state dict (.pt).
-        vocab_size: Optional. If not provided, inferred from checkpoint weight shapes.
         device: 'cpu' or 'cuda'.
 
     Returns:
@@ -153,32 +153,47 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
         'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
     }.get(model_name, TimeAwareGPT2)
 
-    # 3) Infer vocab_size from checkpoint if not provided
-    if vocab_size is None:
-        state = torch.load(model_path, map_location='cpu')
-        # Try typical parameter names first
-        if 'wte.weight' in state:
-            vocab_size = state['wte.weight'].shape[0]
-        elif 'head.weight' in state:
-            vocab_size = state['head.weight'].shape[0]
-        else:
-            # Fallback: try to find any (V, D) weight that likely encodes vocab
-            candidate = None
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and v.ndim == 2:
-                    # Heuristic: the larger dim is probably vocab
-                    V = max(v.shape)
-                    if candidate is None or V > candidate:
-                        candidate = V
-            if candidate is None:
-                raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.")
-            vocab_size = candidate
-
-    # 4) Build model from config
-    # Be tolerant to configs from earlier runs that did not save some fields
+    # 3) Infer checkpoint filename from config
     n_embd = getattr(config, 'n_embd')
     n_layer = getattr(config, 'n_layer')
     n_head = getattr(config, 'n_head')
+
+    # Newer naming (includes model_name), used by train.py when model_name is set
+    suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
+    ckpt_with_model = f"best_model_{suffix_with_model}.pt"
+
+    # Older naming (without model_name), matching existing files in the repo
+    suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
+    ckpt_legacy = f"best_model_{suffix_legacy}.pt"
+
+    # Prefer whichever file exists on disk
+    if os.path.exists(ckpt_with_model):
+        model_path = ckpt_with_model
+    elif os.path.exists(ckpt_legacy):
+        model_path = ckpt_legacy
+    else:
+        # Fall back to the newer naming; torch.load below will fail loudly if the file is truly absent
+        model_path = ckpt_with_model
+        print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}")
+
+    # 4) Infer vocab_size from checkpoint
+    state_preview = torch.load(model_path, map_location='cpu')
+    if 'wte.weight' in state_preview:
+        vocab_size = state_preview['wte.weight'].shape[0]
+    elif 'head.weight' in state_preview:
+        vocab_size = state_preview['head.weight'].shape[0]
+    else:
+        candidate = None
+        for k, v in state_preview.items():
+            if isinstance(v, torch.Tensor) and v.ndim == 2:
+                V = max(v.shape)
+                if candidate is None or V > candidate:
+                    candidate = V
+        if candidate is None:
+            raise ValueError("Unable to infer vocab_size from checkpoint: no 2-D weight tensors found.")
+        vocab_size = candidate
+
+    # 5) Build model from config (tolerant to missing fields)
     pdrop = getattr(config, 'pdrop', 0.1)
     token_pdrop = getattr(config, 'token_pdrop', 0.1)
 
@@ -191,7 +206,7 @@ def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = No
         token_pdrop=token_pdrop,
     ).to(device)
 
-    # 5) Load weights
+    # 6) Load weights
     state_dict = torch.load(model_path, map_location=device)
     missing, unexpected = model.load_state_dict(state_dict, strict=False)
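
Example usage under the new signature (a minimal sketch: it assumes train.py
left a config file and one of the best_model_*.pt checkpoints named above in
the working directory, and 'config.json' stands in for whatever config path
the run actually saved):

    import torch
    from utils import load_model

    # Only the config is needed now; the checkpoint filename is rebuilt
    # from model_name / n_embd / n_layer / n_head stored in the config,
    # and vocab_size is read off the checkpoint's embedding weights.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model('config.json', device=device)
    # The model is returned in eval mode on the requested device.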