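Rewrite load_model to match train.py: support both model variants (TimeAwareGPT2 and TimeAwareGPT2Learnable), infer vocab_size from the checkpoint when it is not passed, and load the state dict with strict=False, warning about missing or unexpected keys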
This commit is contained in:
		
							
								
								
									
										104
									
								
								utils.py
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								utils.py
									
									
									
									
									
								
							| @@ -3,7 +3,8 @@ import numpy as np | ||||
| import random | ||||
| from collections import defaultdict | ||||
| import json | ||||
| from models import TimeAwareGPT2 | ||||
| from models import TimeAwareGPT2, TimeAwareGPT2Learnable | ||||
| from typing import Optional | ||||
|  | ||||
|  | ||||
| class PatientEventDataset(torch.utils.data.Dataset): | ||||
| @@ -111,49 +112,100 @@ class PatientEventDataset(torch.utils.data.Dataset): | ||||
|  | ||||
|         return event_tensor, time_tensor | ||||
|  | ||||
| def load_model(config_path, model_path, vocab_size, device='cpu'): | ||||
| def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'): | ||||
|     """ | ||||
|     Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary. | ||||
|     Load a trained model based on the training configuration and checkpoint. | ||||
|  | ||||
|     According to train.py, models may be either 'TimeAwareGPT2' or | ||||
|     'TimeAwareGPT2Learnable'. This function: | ||||
|       - Reads the config JSON to get architecture hyperparameters | ||||
|       - Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent) | ||||
|       - Infers vocab_size from the checkpoint if not provided | ||||
|       - Loads weights and returns the model in eval mode on the requested device | ||||
|  | ||||
|     Args: | ||||
|         config_path (str): Path to the JSON configuration file. | ||||
|         model_path (str): Path to the saved model state dictionary (.pt file). | ||||
|         vocab_size (int): The vocabulary size used during training. | ||||
|         device (str): The device to load the model onto ('cpu' or 'cuda'). | ||||
|         config_path: Path to the JSON configuration file saved during training. | ||||
|         model_path: Path to the saved model state dict (.pt). | ||||
|         vocab_size: Optional. If not provided, inferred from checkpoint weight shapes. | ||||
|         device: 'cpu' or 'cuda'. | ||||
|  | ||||
|     Returns: | ||||
|         (TimeAwareGPT2): The loaded and initialized model. | ||||
|         torch.nn.Module: Loaded model ready for inference. | ||||
|     """ | ||||
|     # 1) Read config | ||||
|     with open(config_path, 'r') as f: | ||||
|         config_dict = json.load(f) | ||||
|  | ||||
|     print(f"Model config: {config_dict}") | ||||
|  | ||||
|     # Create a config object from the dictionary | ||||
|     # Access config entries with attribute-style access while staying tolerant to missing keys | ||||
|     class AttrDict(dict): | ||||
|         def __init__(self, *args, **kwargs): | ||||
|             super(AttrDict, self).__init__(*args, **kwargs) | ||||
|             self.__dict__ = self | ||||
|         def __getattr__(self, item): | ||||
|             try: | ||||
|                 return self[item] | ||||
|             except KeyError: | ||||
|                 raise AttributeError(item) | ||||
|  | ||||
|     config = AttrDict(config_dict) | ||||
|  | ||||
|     # Initialize the model with parameters from the config | ||||
|     model = TimeAwareGPT2( | ||||
|     # 2) Decide model class (train.py supports two variants) | ||||
|     model_name = getattr(config, 'model_name', 'TimeAwareGPT2') | ||||
|     model_cls = { | ||||
|         'TimeAwareGPT2': TimeAwareGPT2, | ||||
|         'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable, | ||||
|     }.get(model_name, TimeAwareGPT2) | ||||
|  | ||||
|     # 3) Infer vocab_size from checkpoint if not provided | ||||
|     if vocab_size is None: | ||||
|         state = torch.load(model_path, map_location='cpu') | ||||
|         # Try typical parameter names first | ||||
|         if 'wte.weight' in state: | ||||
|             vocab_size = state['wte.weight'].shape[0] | ||||
|         elif 'head.weight' in state: | ||||
|             vocab_size = state['head.weight'].shape[0] | ||||
|         else: | ||||
|             # Fallback: try to find any (V, D) weight that likely encodes vocab | ||||
|             candidate = None | ||||
|             for k, v in state.items(): | ||||
|                 if isinstance(v, torch.Tensor) and v.ndim == 2: | ||||
|                     # Heuristic: the larger dim is probably vocab | ||||
|                     V = max(v.shape) | ||||
|                     if candidate is None or V > candidate: | ||||
|                         candidate = V | ||||
|             if candidate is None: | ||||
|                 raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.") | ||||
|             vocab_size = candidate | ||||
|  | ||||
|     # 4) Build model from config | ||||
|     # Be tolerant of configs from earlier runs: pdrop and token_pdrop fall back to 0.1 when missing; n_embd, n_layer and n_head are required | ||||
|     n_embd = getattr(config, 'n_embd') | ||||
|     n_layer = getattr(config, 'n_layer') | ||||
|     n_head = getattr(config, 'n_head') | ||||
|     pdrop = getattr(config, 'pdrop', 0.1) | ||||
|     token_pdrop = getattr(config, 'token_pdrop', 0.1) | ||||
|  | ||||
|     model = model_cls( | ||||
|         vocab_size=vocab_size, | ||||
|         n_embd=config.n_embd, | ||||
|         n_layer=config.n_layer, | ||||
|         n_head=config.n_head, | ||||
|         pdrop=config.pdrop, | ||||
|         token_pdrop=config.token_pdrop | ||||
|         n_embd=n_embd, | ||||
|         n_layer=n_layer, | ||||
|         n_head=n_head, | ||||
|         pdrop=pdrop, | ||||
|         token_pdrop=token_pdrop, | ||||
|     ).to(device) | ||||
|  | ||||
|     # Load the saved state dictionary | ||||
|     model.load_state_dict(torch.load(model_path, map_location=device)) | ||||
|     # 5) Load weights | ||||
|     state_dict = torch.load(model_path, map_location=device) | ||||
|     missing, unexpected = model.load_state_dict(state_dict, strict=False) | ||||
|  | ||||
|     if missing: | ||||
|         print(f"Warning: Missing keys when loading state_dict: {missing}") | ||||
|     if unexpected: | ||||
|         print(f"Warning: Unexpected keys when loading state_dict: {unexpected}") | ||||
|  | ||||
|     # Set the model to evaluation mode | ||||
|     model.eval() | ||||
|  | ||||
|     print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.") | ||||
|     try: | ||||
|         num_params_m = model.get_num_params() | ||||
|         print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.") | ||||
|     except Exception: | ||||
|         pass | ||||
|     return model | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user