# DeepHealth/utils.py

import json
import os
import random
from collections import defaultdict

import numpy as np
import torch

from models import TimeAwareGPT2, TimeAwareGPT2Learnable


class PatientEventDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for handling temporal sequences of patient events.

    This class processes a raw NumPy array of patient records, groups them by
    patient ID, and prepares them for training by imputing gaps, padding, or
    truncating sequences to a fixed length.
    """

    def __init__(self, data: np.ndarray, block_length: int):
        """
        Initializes the dataset by pre-processing the patient event data.

        Args:
            data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32.
                The columns represent (patient_id, time_in_days, event_code).
            block_length (int): The fixed length for the output sequences.
        """
        self.block_length = block_length

        # Group (time_in_days, event_code) pairs by patient_id.
        # This pre-processing step allows for efficient lookups in __getitem__.
        patient_events = defaultdict(list)
        for patient_id, time, event in data:
            patient_events[patient_id].append((time, event))

        # Store a list of unique patient_ids to map indices to patients.
        self.patient_ids = list(patient_events.keys())
        self.patient_events = dict(patient_events)

    def __len__(self) -> int:
        """
        Returns the total number of unique patients in the dataset.
        """
        return len(self.patient_ids)

    def __getitem__(self, idx):
        """
        Retrieves, processes, and returns a single patient's event sequence,
        or a list of sequences if a slice is provided.

        Args:
            idx (int or slice): The index or slice of the patient(s) to retrieve.

        Returns:
            If idx is an int, a tuple of two torch.long tensors:
            (event_sequence, time_sequence), both of shape (block_length,).
            If idx is a slice, a list of such tuples.
        """
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]

        # 1. Retrieve the patient's records and sort them by time.
        patient_id = self.patient_ids[idx]
        records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
        # 2. Impute "no event" gaps.
        imputed_sequence = []
        if not records:
            # A patient with no records should not occur, since the
            # constructor only creates entries for observed patient_ids.
            pass
        else:
            imputed_sequence.append(records[0])
            for i in range(len(records) - 1):
                prev_time, _ = records[i]
                next_time, _ = records[i + 1]
                time_gap = next_time - prev_time
                # If the gap is 5 years (1826 days) or more, insert one
                # "no event" record per full 5-year interval.
                if time_gap >= 1826:
                    num_no_event_intervals = time_gap // 1826
                    for j in range(1, num_no_event_intervals + 1):
                        no_event_time = prev_time + j * 1826
                        imputed_sequence.append((no_event_time, 1))  # event_code=1 for "no event"
                imputed_sequence.append(records[i + 1])
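
        # Worked example (illustrative): events at day 0 and day 4000 give
        # time_gap = 4000, so 4000 // 1826 = 2 "no event" records are inserted
        # at days 1826 and 3652 before the day-4000 event is appended.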
        # 3. Adjust the sequence length to exactly block_length.
        seq_len = len(imputed_sequence)
        if seq_len > self.block_length:
            # If longer, randomly select a contiguous sub-sequence.
            start_index = random.randint(0, seq_len - self.block_length)
            final_sequence = imputed_sequence[start_index : start_index + self.block_length]
        elif seq_len < self.block_length:
            # If shorter, pad the sequence at the end.
            # Use event_code=0 and time_in_days=36525 for padding.
            padding_needed = self.block_length - seq_len
            padding = [(36525, 0)] * padding_needed
            final_sequence = imputed_sequence + padding
        else:
            # If equal, use the sequence as is.
            final_sequence = imputed_sequence

        # 4. Separate the sequence into event codes and times, then convert to tensors.
        event_codes = [item[1] for item in final_sequence]
        time_stamps = [item[0] for item in final_sequence]
        event_tensor = torch.tensor(event_codes, dtype=torch.long)
        time_tensor = torch.tensor(time_stamps, dtype=torch.long)
        return event_tensor, time_tensor
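

# Illustrative usage (values made up): the dataset is built from an (N, 3)
# uint32 array and supports both integer indexing and slicing.
#
#     dataset = PatientEventDataset(records, block_length=16)
#     events, times = dataset[0]  # two torch.long tensors of shape (16,)
#     batch = dataset[0:32]       # list of 32 (events, times) tuples
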
def load_model(config_path: str, device: str = 'cpu'):
    """
    Load a trained model based on the training configuration, inferring the
    checkpoint filename from the configuration.

    According to train.py, models may be either 'TimeAwareGPT2' or
    'TimeAwareGPT2Learnable'. This function:
      - Reads the config JSON to get architecture hyperparameters
      - Selects the model class using config.model_name (defaults to
        TimeAwareGPT2 if absent)
      - Infers the checkpoint path from the config values
      - Infers vocab_size from the checkpoint
      - Loads weights and returns the model in eval mode on the requested device

    Args:
        config_path: Path to the JSON configuration file saved during training.
        device: 'cpu' or 'cuda'.

    Returns:
        torch.nn.Module: Loaded model ready for inference.
    """
    # 1) Read the training config.
    with open(config_path, 'r') as f:
        config_dict = json.load(f)
    # Access config entries with attribute-style access while staying
    # tolerant to missing keys.
    class AttrDict(dict):
        def __getattr__(self, item):
            try:
                return self[item]
            except KeyError:
                raise AttributeError(item)

    config = AttrDict(config_dict)

    # 2) Decide the model class (train.py supports two variants).
    model_name = getattr(config, 'model_name', 'TimeAwareGPT2')
    model_cls = {
        'TimeAwareGPT2': TimeAwareGPT2,
        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
    }.get(model_name, TimeAwareGPT2)
    # 3) Infer the checkpoint filename from the config.
    # Required architecture hyperparameters; raise if absent.
    n_embd = getattr(config, 'n_embd')
    n_layer = getattr(config, 'n_layer')
    n_head = getattr(config, 'n_head')

    # Newer naming (includes model_name), used by train.py when model_name is present.
    suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
    ckpt_with_model = f"best_model_{suffix_with_model}.pt"

    # Older naming (without model_name) matches existing repo files.
    suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
    ckpt_legacy = f"best_model_{suffix_legacy}.pt"

    # Prefer whichever file exists on disk.
    if os.path.exists(ckpt_with_model):
        model_path = ckpt_with_model
    elif os.path.exists(ckpt_legacy):
        model_path = ckpt_legacy
    else:
        # Fall back to the model_name variant; the checkpoint may still have
        # been saved under that name elsewhere.
        model_path = ckpt_with_model
        print(f"Warning: Could not find checkpoint on disk. "
              f"Expected one of: {ckpt_with_model}, {ckpt_legacy}")
    # 4) Infer vocab_size from the checkpoint (loaded once and reused below).
    state_dict = torch.load(model_path, map_location='cpu')
    if 'wte.weight' in state_dict:
        vocab_size = state_dict['wte.weight'].shape[0]
    elif 'head.weight' in state_dict:
        vocab_size = state_dict['head.weight'].shape[0]
    else:
        # Fall back to the largest dimension of any 2-D tensor, which for a
        # language-model-style checkpoint is typically the vocabulary size.
        candidate = None
        for k, v in state_dict.items():
            if isinstance(v, torch.Tensor) and v.ndim == 2:
                V = max(v.shape)
                if candidate is None or V > candidate:
                    candidate = V
        if candidate is None:
            raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
        vocab_size = candidate
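
    # Note: this assumes a raw state_dict in which the token embedding
    # 'wte.weight' (and the output projection 'head.weight') has shape
    # (vocab_size, n_embd), so dimension 0 is the vocabulary size.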
    # 5) Build the model from the config (tolerant to missing dropout fields).
    pdrop = getattr(config, 'pdrop', 0.1)
    token_pdrop = getattr(config, 'token_pdrop', 0.1)
    model = model_cls(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_layer=n_layer,
        n_head=n_head,
        pdrop=pdrop,
        token_pdrop=token_pdrop,
    ).to(device)

    # 6) Load the weights, reusing the state_dict already read above.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Warning: Missing keys when loading state_dict: {missing}")
    if unexpected:
        print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")

    model.eval()
    try:
        num_params_m = model.get_num_params()
        print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.")
    except Exception:
        pass

    return model
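

# Illustrative usage (the config filename is hypothetical; pass whatever
# train.py actually saved):
#
#     model = load_model('train_config.json', device='cpu')
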
def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Retrieves a batch of data from a PatientEventDataset and prepares it for
    next-event prediction by shifting each sequence by one position.

    Args:
        dataset (PatientEventDataset): The dataset to retrieve data from.
        batch_slice (slice): The slice defining the batch of patients to retrieve.

    Returns:
        A tuple containing four tensors, each of shape
        (batch_size, block_length - 1):
        - input_events
        - input_times
        - target_events
        - target_times
    """
    batch_data = dataset[batch_slice]

    # Inputs are the sequences without the last element; targets are the same
    # sequences shifted left by one (next-event prediction).
    input_events = torch.stack([item[0][:-1] for item in batch_data])
    input_times = torch.stack([item[1][:-1] for item in batch_data])
    target_events = torch.stack([item[0][1:] for item in batch_data])
    target_times = torch.stack([item[1][1:] for item in batch_data])

    return input_events, input_times, target_events, target_times
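

if __name__ == '__main__':
    # A minimal smoke test (illustrative values only; not part of training):
    # build a toy dataset and fetch one shifted batch with get_batch.
    records = np.array(
        [
            (1, 0, 5),     # patient 1: event 5 on day 0
            (1, 4000, 7),  # patient 1: event 7 on day 4000 (> 5-year gap,
                           # so "no event" codes are imputed in between)
            (2, 100, 3),   # patient 2: a single event on day 100
        ],
        dtype=np.uint32,
    )
    dataset = PatientEventDataset(records, block_length=8)

    input_events, input_times, target_events, target_times = get_batch(dataset, slice(0, 2))
    print(input_events.shape)   # torch.Size([2, 7])
    print(target_events.shape)  # torch.Size([2, 7])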