feat: Implement time-aware GPT-2 for patient event prediction
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data. Key components include: - `models.py`: - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information. - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings. - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing. - `utils.py`: - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation. - `train.py`: - A comprehensive training script that initializes the model, data loaders, and loss function. - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss. - `prepare_data.py`: - Script for preprocessing raw UK Biobank data into a format suitable for the model. - `GEMINI.md`: - Project documentation outlining the structure, coding style, and framework.
This commit is contained in:
104
utils.py
Normal file
104
utils.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
class PatientEventDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
A PyTorch Dataset for handling temporal sequences of patient events.
|
||||
|
||||
This class processes a raw NumPy array of patient records, groups them by
|
||||
patient ID, and prepares them for training by imputing gaps, padding, or
|
||||
truncating sequences to a fixed length.
|
||||
"""
|
||||
|
||||
def __init__(self, data: np.ndarray, block_length: int):
|
||||
"""
|
||||
Initializes the dataset by pre-processing the patient event data.
|
||||
|
||||
Args:
|
||||
data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32.
|
||||
The columns represent (patient_id, time_in_days, event_code).
|
||||
block_length (int): The fixed length for the output sequences.
|
||||
"""
|
||||
self.block_length = block_length
|
||||
|
||||
# Group (time_in_days, event_code) pairs by patient_id.
|
||||
# This pre-processing step allows for efficient lookups in __getitem__.
|
||||
patient_events = defaultdict(list)
|
||||
for patient_id, time, event in data:
|
||||
patient_events[patient_id].append((time, event))
|
||||
|
||||
# Store a list of unique patient_ids to map indices to patients.
|
||||
self.patient_ids = list(patient_events.keys())
|
||||
self.patient_events = dict(patient_events)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Returns the total number of unique patients in the dataset.
|
||||
"""
|
||||
return len(self.patient_ids)
|
||||
|
||||
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Retrieves, processes, and returns a single patient's event sequence.
|
||||
|
||||
Args:
|
||||
idx (int): The index of the patient to retrieve.
|
||||
|
||||
Returns:
|
||||
A tuple of two torch.long tensors: (event_sequence, time_sequence),
|
||||
both of shape (block_length,).
|
||||
"""
|
||||
# 1. Retrieve and Sort
|
||||
patient_id = self.patient_ids[idx]
|
||||
records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
|
||||
|
||||
# 2. Impute "No Event" Gaps
|
||||
imputed_sequence = []
|
||||
if not records:
|
||||
# Handle cases with no records for a patient if necessary, though
|
||||
# the constructor logic would typically prevent this.
|
||||
pass
|
||||
else:
|
||||
imputed_sequence.append(records[0])
|
||||
for i in range(len(records) - 1):
|
||||
prev_time, _ = records[i]
|
||||
next_time, _ = records[i+1]
|
||||
time_gap = next_time - prev_time
|
||||
|
||||
# If the gap is 5 years (1826 days) or more, insert "no event" records.
|
||||
if time_gap >= 1826:
|
||||
num_no_event_intervals = time_gap // 1826
|
||||
for j in range(1, num_no_event_intervals + 1):
|
||||
no_event_time = prev_time + j * 1826
|
||||
imputed_sequence.append((no_event_time, 1)) # event_code=1 for "no event"
|
||||
|
||||
imputed_sequence.append(records[i+1])
|
||||
|
||||
# 3. Adjust Sequence Length
|
||||
seq_len = len(imputed_sequence)
|
||||
|
||||
if seq_len > self.block_length:
|
||||
# If longer, randomly select a contiguous sub-sequence.
|
||||
start_index = random.randint(0, seq_len - self.block_length)
|
||||
final_sequence = imputed_sequence[start_index : start_index + self.block_length]
|
||||
elif seq_len < self.block_length:
|
||||
# If shorter, pad the sequence at the end.
|
||||
padding_needed = self.block_length - seq_len
|
||||
# Use event_code=0 and time_in_days=36525 for padding.
|
||||
padding = [(36525, 0)] * padding_needed
|
||||
final_sequence = imputed_sequence + padding
|
||||
else:
|
||||
# If equal, use the sequence as is.
|
||||
final_sequence = imputed_sequence
|
||||
|
||||
# 4. Return Tensors
|
||||
# Separate the sequence into event codes and time, then convert to tensors.
|
||||
event_codes = [item[1] for item in final_sequence]
|
||||
time_stamps = [item[0] for item in final_sequence]
|
||||
|
||||
event_tensor = torch.tensor(event_codes, dtype=torch.long)
|
||||
time_tensor = torch.tensor(time_stamps, dtype=torch.long)
|
||||
|
||||
return event_tensor, time_tensor
|
Reference in New Issue
Block a user