2025-10-16 14:21:36 +08:00
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
import random
|
|
|
|
from collections import defaultdict
|
2025-10-18 11:07:59 +08:00
|
|
|
import json
|
|
|
|
from models import TimeAwareGPT2
|
|
|
|
|
2025-10-16 14:21:36 +08:00
|
|
|
|
|
|
|
class PatientEventDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for handling temporal sequences of patient events.

    Raw ``(patient_id, time_in_days, event_code)`` rows are grouped by patient
    and sorted by time once, at construction. Each item access then imputes
    "no event" records into long gaps and pads or randomly crops the sequence
    to ``block_length`` entries.
    """

    # A synthetic "no event" record is inserted for every full interval of
    # this many days (roughly 5 years) between two consecutive events.
    NO_EVENT_INTERVAL_DAYS = 1826
    # Event code used for the synthetic "no event" records.
    NO_EVENT_CODE = 1
    # Event code and timestamp used for right-padding short sequences.
    PAD_EVENT_CODE = 0
    PAD_TIME_DAYS = 36525

    def __init__(self, data: np.ndarray, block_length: int):
        """
        Initializes the dataset by pre-processing the patient event data.

        Args:
            data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32.
                The columns represent (patient_id, time_in_days, event_code).
            block_length (int): The fixed length for the output sequences.
        """
        self.block_length = block_length

        # Group (time_in_days, event_code) pairs by patient_id. ``tolist()``
        # converts the whole array to plain Python ints in one pass, which is
        # faster than iterating array rows and keeps NumPy scalars out of the
        # stored records and dictionary keys.
        patient_events = defaultdict(list)
        for patient_id, time, event in np.asarray(data).tolist():
            patient_events[patient_id].append((time, event))

        # Sort each patient's records by time once here instead of on every
        # __getitem__ call. The sort is stable, so records sharing a
        # timestamp keep their input order.
        for records in patient_events.values():
            records.sort(key=lambda record: record[0])

        # Patient ids in order of first appearance; maps indices to patients.
        self.patient_ids = list(patient_events)
        self.patient_events = dict(patient_events)

    def __len__(self) -> int:
        """
        Returns the total number of unique patients in the dataset.
        """
        return len(self.patient_ids)

    def __getitem__(self, idx):
        """
        Retrieves, processes, and returns a single patient's event sequence,
        or a list of sequences if a slice is provided.

        Args:
            idx (int or slice): The index or slice of the patient(s) to retrieve.

        Returns:
            If idx is an int, a tuple of two torch.long tensors:
            (event_sequence, time_sequence), both of shape (block_length,).
            If idx is a slice, a list of such tuples.
        """
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]

        # Records were already sorted by time in __init__.
        records = self.patient_events[self.patient_ids[idx]]
        sequence = self._fit_to_block(self._impute_gaps(records))

        event_tensor = torch.tensor(
            [event for _, event in sequence], dtype=torch.long
        )
        time_tensor = torch.tensor(
            [time for time, _ in sequence], dtype=torch.long
        )
        return event_tensor, time_tensor

    def _impute_gaps(self, records):
        """
        Returns a copy of ``records`` with "no event" entries inserted into gaps.

        For two consecutive records whose times differ by ``gap`` days,
        ``gap // NO_EVENT_INTERVAL_DAYS`` synthetic records are inserted at
        ``prev_time + k * NO_EVENT_INTERVAL_DAYS`` for k = 1, 2, ...
        Note that when ``gap`` is an exact multiple of the interval, the last
        synthetic record shares its timestamp with the following real event.

        Args:
            records (list): Time-sorted (time_in_days, event_code) tuples.

        Returns:
            list: The records with synthetic (time, NO_EVENT_CODE) tuples added.
        """
        if not records:
            return []

        interval = self.NO_EVENT_INTERVAL_DAYS
        imputed = [records[0]]
        for (prev_time, _), record in zip(records, records[1:]):
            # Zero for gaps shorter than one interval, so nothing is inserted.
            full_intervals = (record[0] - prev_time) // interval
            imputed.extend(
                (prev_time + k * interval, self.NO_EVENT_CODE)
                for k in range(1, full_intervals + 1)
            )
            imputed.append(record)
        return imputed

    def _fit_to_block(self, sequence):
        """
        Crops or right-pads ``sequence`` to exactly ``block_length`` entries.

        Longer sequences are cropped to a randomly positioned contiguous
        window (drawn with the global ``random`` module); shorter ones are
        padded at the end with (PAD_TIME_DAYS, PAD_EVENT_CODE).
        """
        surplus = len(sequence) - self.block_length
        if surplus > 0:
            start_index = random.randint(0, surplus)
            return sequence[start_index:start_index + self.block_length]

        # surplus <= 0 here, so -surplus is the number of padding entries
        # (zero when the sequence already has the right length).
        padding = [(self.PAD_TIME_DAYS, self.PAD_EVENT_CODE)] * -surplus
        return sequence + padding
|
2025-10-18 11:07:59 +08:00
|
|
|
|
|
|
|
def load_model(config_path, model_path, vocab_size, device='cpu'):
    """
    Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary.

    Args:
        config_path (str): Path to the JSON configuration file. It must
            provide the keys n_embd, n_layer, n_head, pdrop and token_pdrop.
        model_path (str): Path to the saved model state dictionary (.pt file).
        vocab_size (int): The vocabulary size used during training.
        device (str): The device to load the model onto ('cpu' or 'cuda').

    Returns:
        (TimeAwareGPT2): The loaded model, moved to ``device`` and set to
        evaluation mode.

    Raises:
        AttributeError: If a required hyperparameter is missing from the config.
    """
    from types import SimpleNamespace

    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)

    print(f"Model config: {config_dict}")

    # Expose the config keys as attributes; a missing key surfaces as an
    # AttributeError when it is accessed below.
    config = SimpleNamespace(**config_dict)

    # Initialize the model with parameters from the config
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(device)

    # Load the saved state dictionary.
    # NOTE(review): torch.load unpickles the file; depending on the installed
    # torch version this may execute arbitrary code from an untrusted
    # checkpoint. Consider passing weights_only=True where supported.
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Set the model to evaluation mode
    model.eval()

    # NOTE(review): the message labels the value as millions; this assumes
    # get_num_params() already returns the count in millions - confirm.
    print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
    return model
|
2025-10-20 13:47:50 +08:00
|
|
|
|
|
|
|
|
|
|
|
def get_batch(dataset: "PatientEventDataset", batch_slice: slice, ignore_tokens=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Retrieves a batch of data from a PatientEventDataset and prepares it for model training.

    Each sequence is split into an input (all positions but the last) and a
    target (all positions but the first), so the target at position i is the
    element that follows the input at position i.

    Args:
        dataset (PatientEventDataset): The dataset to retrieve data from.
        batch_slice (slice): The slice defining the batch of patients to retrieve.
        ignore_tokens (list, optional): A list of token IDs to be ignored in the target events.
            These tokens will be replaced with -100. Defaults to None, which
            leaves the target events unchanged.

    Returns:
        A tuple containing four tensors:
        - input_events: (batch_size, sequence_length - 1)
        - input_times: (batch_size, sequence_length - 1)
        - target_events: (batch_size, sequence_length - 1)
        - target_times: (batch_size, sequence_length - 1)
    """
    batch_data = dataset[batch_slice]

    # Slice per item and then stack, so each returned tensor is a fresh
    # contiguous tensor that does not share storage with the dataset items.
    input_events = torch.stack([events[:-1] for events, _ in batch_data])
    input_times = torch.stack([times[:-1] for _, times in batch_data])
    target_events = torch.stack([events[1:] for events, _ in batch_data])
    target_times = torch.stack([times[1:] for _, times in batch_data])

    if ignore_tokens:
        # Build one mask covering every ignored token id, then overwrite the
        # matching targets with -100. Only targets are masked; inputs keep
        # their original token ids.
        ignore_mask = torch.zeros_like(target_events, dtype=torch.bool)
        for token in ignore_tokens:
            ignore_mask |= target_events == token
        target_events = target_events.masked_fill(ignore_mask, -100)

    return input_events, input_times, target_events, target_times
|