# DeepHealth/utils.py

import torch
import numpy as np
import random
from collections import defaultdict
import json
from models import TimeAwareGPT2, TimeAwareGPT2Learnable
from typing import Optional
class PatientEventDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for handling temporal sequences of patient events.

    This class processes a raw NumPy array of patient records, groups them by
    patient ID, and prepares them for training by imputing gaps, padding, or
    truncating sequences to a fixed length.

    Grouping and per-patient time sorting are done once in the constructor;
    ``__getitem__`` only performs the per-sample work (gap imputation, random
    cropping or padding, tensor conversion).
    """

    def __init__(
        self,
        data: np.ndarray,
        block_length: int,
        *,
        no_event_interval: int = 1826,
        no_event_token: int = 1,
        pad_token: int = 0,
        pad_time: int = 36525,
    ):
        """
        Initializes the dataset by pre-processing the patient event data.

        Args:
            data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32.
                The columns represent (patient_id, time_in_days, event_code).
            block_length (int): The fixed length for the output sequences.
            no_event_interval (int): Gap length in days (default 1826, ~5 years).
                When two consecutive events are at least this far apart, a
                "no event" record is inserted at every multiple of this interval
                after the earlier event.
            no_event_token (int): Event code for imputed "no event" records (default 1).
            pad_token (int): Event code used for padding (default 0).
            pad_time (int): time_in_days used for padding (default 36525, ~100 years).

        Raises:
            ValueError: If ``no_event_interval`` is not positive.
        """
        if no_event_interval <= 0:
            raise ValueError(f"no_event_interval must be positive, got {no_event_interval}")
        self.block_length = block_length
        self.no_event_interval = no_event_interval
        self.no_event_token = no_event_token
        self.pad_token = pad_token
        self.pad_time = pad_time

        # Group (time_in_days, event_code) pairs by patient_id.
        patient_events = defaultdict(list)
        for patient_id, time, event in data:
            patient_events[patient_id].append((time, event))
        # Sort each patient's records by time once here instead of on every
        # __getitem__ call. list.sort is stable, so same-day events keep their
        # original input order.
        for records in patient_events.values():
            records.sort(key=lambda record: record[0])

        # Index -> patient_id mapping, in order of first appearance in `data`.
        self.patient_ids = list(patient_events.keys())
        self.patient_events = dict(patient_events)

    def __len__(self) -> int:
        """
        Returns the total number of unique patients in the dataset.
        """
        return len(self.patient_ids)

    def __getitem__(self, idx):
        """
        Retrieves, processes, and returns a single patient's event sequence,
        or a list of sequences if a slice is provided.

        Args:
            idx (int or slice): The index or slice of the patient(s) to retrieve.

        Returns:
            If idx is an int, a tuple of two torch.long tensors:
            (event_sequence, time_sequence), both of shape (block_length,).
            If idx is a slice, a list of such tuples.
        """
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]

        records = self.patient_events[self.patient_ids[idx]]
        sequence = self._fit_length(self._impute_gaps(records))

        # Explicit int() so tensor creation does not depend on torch accepting
        # numpy unsigned scalars (the input is documented as np.uint32).
        event_tensor = torch.tensor([int(event) for _, event in sequence], dtype=torch.long)
        time_tensor = torch.tensor([int(time) for time, _ in sequence], dtype=torch.long)
        return event_tensor, time_tensor

    def _impute_gaps(self, records):
        """Return time-sorted `records` with "no event" entries inserted into long gaps."""
        if not records:
            return []
        interval = self.no_event_interval
        imputed = [records[0]]
        for (prev_time, _), next_record in zip(records, records[1:]):
            # Gaps shorter than one interval yield zero insertions. For a gap that
            # is an exact multiple of the interval, the last inserted record shares
            # its timestamp with next_record (same as the original behavior).
            num_intervals = (next_record[0] - prev_time) // interval
            for j in range(1, num_intervals + 1):
                imputed.append((prev_time + j * interval, self.no_event_token))
            imputed.append(next_record)
        return imputed

    def _fit_length(self, sequence):
        """Randomly crop `sequence` to block_length, or pad it at the end up to block_length."""
        length = len(sequence)
        if length > self.block_length:
            # Random contiguous window. The randint call matches the original,
            # so RNG consumption is unchanged.
            start = random.randint(0, length - self.block_length)
            return sequence[start:start + self.block_length]
        # Pad at the end; this is a no-op copy when length == block_length.
        return sequence + [(self.pad_time, self.pad_token)] * (self.block_length - length)
def _infer_vocab_size(state_dict) -> int:
    """
    Guess the vocabulary size from checkpoint weight shapes.

    Uses the first dimension of 'wte.weight' or 'head.weight' when present;
    otherwise falls back to the largest dimension of any 2-D tensor in the
    checkpoint. The fallback is a heuristic that assumes the vocabulary is the
    largest matrix dimension in the model.

    Raises:
        ValueError: If no 2-D tensor is available to infer from.
    """
    for key in ('wte.weight', 'head.weight'):
        if key in state_dict:
            return state_dict[key].shape[0]
    dims = [
        max(value.shape)
        for value in state_dict.values()
        if isinstance(value, torch.Tensor) and value.ndim == 2
    ]
    if not dims:
        raise ValueError("Unable to infer vocab_size from checkpoint. Please pass vocab_size explicitly.")
    return max(dims)


def load_model(config_path: str, model_path: str, vocab_size: Optional[int] = None, device: str = 'cpu'):
    """
    Load a trained model based on the training configuration and checkpoint.

    According to train.py, models may be either 'TimeAwareGPT2' or
    'TimeAwareGPT2Learnable'. This function:
    - Reads the config JSON to get architecture hyperparameters
    - Selects the model class using config.model_name (defaults to TimeAwareGPT2
      if absent; an unrecognized name also falls back to TimeAwareGPT2, with a warning)
    - Reads the checkpoint once and infers vocab_size from it if not provided
    - Loads weights (non-strict; missing/unexpected keys are reported) and
      returns the model in eval mode on the requested device

    Args:
        config_path: Path to the JSON configuration file saved during training.
        model_path: Path to the saved model state dict (.pt).
        vocab_size: Optional. If not provided, inferred from checkpoint weight shapes.
        device: 'cpu' or 'cuda'.

    Returns:
        torch.nn.Module: Loaded model ready for inference.

    Raises:
        KeyError: If the config lacks 'n_embd', 'n_layer' or 'n_head'.
        ValueError: If vocab_size is None and cannot be inferred from the checkpoint.
    """
    # 1) Read config
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    def require(key):
        # Architecture fields with no safe default must be present in the config.
        if key not in config:
            raise KeyError(f"Config {config_path!r} is missing required key {key!r}")
        return config[key]

    # 2) Decide model class (train.py supports two variants)
    model_classes = {
        'TimeAwareGPT2': TimeAwareGPT2,
        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
    }
    model_name = config.get('model_name', 'TimeAwareGPT2')
    model_cls = model_classes.get(model_name)
    if model_cls is None:
        # Keep the historical fallback, but make it visible: loading weights into
        # the wrong architecture is otherwise only hinted at by key warnings.
        print(f"Warning: Unknown model_name {model_name!r} in config; falling back to TimeAwareGPT2.")
        model_cls = TimeAwareGPT2

    # 3) Read the checkpoint once, onto the CPU. load_state_dict below copies each
    # tensor onto the model's device, so there is no second disk read and no
    # transient duplicate copy of the weights on `device`.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')
    if vocab_size is None:
        vocab_size = _infer_vocab_size(state_dict)

    # 4) Build model from config. Dropout fields may be absent in configs from
    # earlier runs, so they default to 0.1.
    model = model_cls(
        vocab_size=vocab_size,
        n_embd=require('n_embd'),
        n_layer=require('n_layer'),
        n_head=require('n_head'),
        pdrop=config.get('pdrop', 0.1),
        token_pdrop=config.get('token_pdrop', 0.1),
    ).to(device)

    # 5) Load weights
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Warning: Missing keys when loading state_dict: {missing}")
    if unexpected:
        print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")
    model.eval()

    # The parameter-count report is informational only; skip it if the model
    # lacks get_num_params or its result is not numeric.
    get_num_params = getattr(model, 'get_num_params', None)
    if get_num_params is not None:
        try:
            print(f"Model loaded from {model_path} with {get_num_params():.2f}M parameters.")
        except (TypeError, ValueError):
            pass
    return model
def get_batch(
    dataset: "PatientEventDataset",
    batch_slice: slice,
    ignore_tokens: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Retrieves a batch of data from a PatientEventDataset and prepares it for model training.

    Each sequence is split into next-step prediction pairs: inputs are all
    positions but the last, targets are all positions but the first.

    Args:
        dataset (PatientEventDataset): The dataset to retrieve data from. Slicing
            it must return a list of (event_tensor, time_tensor) pairs of 1-D
            tensors with a common length.
        batch_slice (slice): The slice defining the batch of patients to retrieve.
        ignore_tokens (list, optional): A list of token IDs to be ignored in the
            target events. These tokens will be replaced with -100 (the default
            ``ignore_index`` of ``torch.nn.functional.cross_entropy``); input
            events are left untouched. Defaults to None (no masking).

    Returns:
        A tuple containing four contiguous tensors:
        - input_events: (batch_size, sequence_length - 1)
        - input_times: (batch_size, sequence_length - 1)
        - target_events: (batch_size, sequence_length - 1)
        - target_times: (batch_size, sequence_length - 1)

    Raises:
        ValueError: If ``batch_slice`` selects no patients.
    """
    batch_data = dataset[batch_slice]
    if not batch_data:
        raise ValueError(f"batch_slice {batch_slice} selects no patients from the dataset")

    # Stacking per-item slices (rather than slicing one stacked tensor) gives
    # contiguous outputs that do not share storage. Masking targets in place
    # therefore cannot leak into the inputs.
    input_events = torch.stack([item[0][:-1] for item in batch_data])
    input_times = torch.stack([item[1][:-1] for item in batch_data])
    target_events = torch.stack([item[0][1:] for item in batch_data])
    target_times = torch.stack([item[1][1:] for item in batch_data])

    if ignore_tokens:
        for token in ignore_tokens:
            target_events[target_events == token] = -100

    return input_events, input_times, target_events, target_times