diff --git a/utils.py b/utils.py index ef25e81..5402dd8 100644 --- a/utils.py +++ b/utils.py @@ -3,7 +3,8 @@ import numpy as np import random from collections import defaultdict import json -from models import TimeAwareGPT2 +from models import TimeAwareGPT2, TimeAwareGPT2Learnable +from typing import Optional class PatientEventDataset(torch.utils.data.Dataset): @@ -111,49 +112,100 @@ class PatientEventDataset(torch.utils.data.Dataset): return event_tensor, time_tensor -def load_model(config_path, model_path, vocab_size, device='cpu'): +def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'): """ - Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary. + Load a trained model based on the training configuration and checkpoint. + + According to train.py, models may be either 'TimeAwareGPT2' or + 'TimeAwareGPT2Learnable'. This function: + - Reads the config JSON to get architecture hyperparameters + - Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent) + - Infers vocab_size from the checkpoint if not provided + - Loads weights and returns the model in eval mode on the requested device Args: - config_path (str): Path to the JSON configuration file. - model_path (str): Path to the saved model state dictionary (.pt file). - vocab_size (int): The vocabulary size used during training. - device (str): The device to load the model onto ('cpu' or 'cuda'). + config_path: Path to the JSON configuration file saved during training. + model_path: Path to the saved model state dict (.pt). + vocab_size: Optional. If not provided, inferred from checkpoint weight shapes. + device: 'cpu' or 'cuda'. Returns: - (TimeAwareGPT2): The loaded and initialized model. + torch.nn.Module: Loaded model ready for inference. """ + # 1) Read config with open(config_path, 'r') as f: config_dict = json.load(f) - print(f"Model config: {config_dict}") - - # Create a config object from the dictionary + # Access config entries with attribute-style access while staying tolerant to missing keys class AttrDict(dict): - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(item) config = AttrDict(config_dict) - # Initialize the model with parameters from the config - model = TimeAwareGPT2( + # 2) Decide model class (train.py supports two variants) + model_name = getattr(config, 'model_name', 'TimeAwareGPT2') + model_cls = { + 'TimeAwareGPT2': TimeAwareGPT2, + 'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable, + }.get(model_name, TimeAwareGPT2) + + # 3) Infer vocab_size from checkpoint if not provided + if vocab_size is None: + state = torch.load(model_path, map_location='cpu') + # Try typical parameter names first + if 'wte.weight' in state: + vocab_size = state['wte.weight'].shape[0] + elif 'head.weight' in state: + vocab_size = state['head.weight'].shape[0] + else: + # Fallback: try to find any (V, D) weight that likely encodes vocab + candidate = None + for k, v in state.items(): + if isinstance(v, torch.Tensor) and v.ndim == 2: + # Heuristic: the larger dim is probably vocab + V = max(v.shape) + if candidate is None or V > candidate: + candidate = V + if candidate is None: + raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.") + vocab_size = candidate + + # 4) Build model from config + # Be tolerant to configs from earlier runs that did not save some fields + n_embd = getattr(config, 'n_embd') + n_layer = getattr(config, 'n_layer') + n_head = getattr(config, 'n_head') + pdrop = getattr(config, 'pdrop', 0.1) + token_pdrop = getattr(config, 'token_pdrop', 0.1) + + model = model_cls( vocab_size=vocab_size, - n_embd=config.n_embd, - n_layer=config.n_layer, - n_head=config.n_head, - pdrop=config.pdrop, - token_pdrop=config.token_pdrop + n_embd=n_embd, + n_layer=n_layer, + n_head=n_head, + pdrop=pdrop, + token_pdrop=token_pdrop, ).to(device) - # Load the saved state dictionary - model.load_state_dict(torch.load(model_path, map_location=device)) + # 5) Load weights + state_dict = torch.load(model_path, map_location=device) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + + if missing: + print(f"Warning: Missing keys when loading state_dict: {missing}") + if unexpected: + print(f"Warning: Unexpected keys when loading state_dict: {unexpected}") - # Set the model to evaluation mode model.eval() - - print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.") + try: + num_params_m = model.get_num_params() + print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.") + except Exception: + pass return model