import torch import numpy as np import random from collections import defaultdict class PatientEventDataset(torch.utils.data.Dataset): """ A PyTorch Dataset for handling temporal sequences of patient events. This class processes a raw NumPy array of patient records, groups them by patient ID, and prepares them for training by imputing gaps, padding, or truncating sequences to a fixed length. """ def __init__(self, data: np.ndarray, block_length: int): """ Initializes the dataset by pre-processing the patient event data. Args: data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32. The columns represent (patient_id, time_in_days, event_code). block_length (int): The fixed length for the output sequences. """ self.block_length = block_length # Group (time_in_days, event_code) pairs by patient_id. # This pre-processing step allows for efficient lookups in __getitem__. patient_events = defaultdict(list) for patient_id, time, event in data: patient_events[patient_id].append((time, event)) # Store a list of unique patient_ids to map indices to patients. self.patient_ids = list(patient_events.keys()) self.patient_events = dict(patient_events) def __len__(self) -> int: """ Returns the total number of unique patients in the dataset. """ return len(self.patient_ids) def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: """ Retrieves, processes, and returns a single patient's event sequence. Args: idx (int): The index of the patient to retrieve. Returns: A tuple of two torch.long tensors: (event_sequence, time_sequence), both of shape (block_length,). """ # 1. Retrieve and Sort patient_id = self.patient_ids[idx] records = sorted(self.patient_events[patient_id], key=lambda x: x[0]) # 2. Impute "No Event" Gaps imputed_sequence = [] if not records: # Handle cases with no records for a patient if necessary, though # the constructor logic would typically prevent this. pass else: imputed_sequence.append(records[0]) for i in range(len(records) - 1): prev_time, _ = records[i] next_time, _ = records[i+1] time_gap = next_time - prev_time # If the gap is 5 years (1826 days) or more, insert "no event" records. if time_gap >= 1826: num_no_event_intervals = time_gap // 1826 for j in range(1, num_no_event_intervals + 1): no_event_time = prev_time + j * 1826 imputed_sequence.append((no_event_time, 1)) # event_code=1 for "no event" imputed_sequence.append(records[i+1]) # 3. Adjust Sequence Length seq_len = len(imputed_sequence) if seq_len > self.block_length: # If longer, randomly select a contiguous sub-sequence. start_index = random.randint(0, seq_len - self.block_length) final_sequence = imputed_sequence[start_index : start_index + self.block_length] elif seq_len < self.block_length: # If shorter, pad the sequence at the end. padding_needed = self.block_length - seq_len # Use event_code=0 and time_in_days=36525 for padding. padding = [(36525, 0)] * padding_needed final_sequence = imputed_sequence + padding else: # If equal, use the sequence as is. final_sequence = imputed_sequence # 4. Return Tensors # Separate the sequence into event codes and time, then convert to tensors. event_codes = [item[1] for item in final_sequence] time_stamps = [item[0] for item in final_sequence] event_tensor = torch.tensor(event_codes, dtype=torch.long) time_tensor = torch.tensor(time_stamps, dtype=torch.long) return event_tensor, time_tensor