import torch from torch.utils.data import Dataset from torch.nn.utils.rnn import pad_sequence import pandas as pd import numpy as np from collections import defaultdict from typing import List class HealthDataset(Dataset): """ Dataset for health records. Args: data_prefix (str): Prefix for data files. covariate_list (List[str] | None): List of covariates to include. """ def __init__( self, data_prefix: str, covariate_list: List[str] | None = None, cache_event_tensors: bool = True, ): basic_info = pd.read_csv( f"{data_prefix}_basic_info.csv", index_col='eid') tabular_data = pd.read_csv(f"{data_prefix}_table.csv", index_col='eid') event_data = np.load(f"{data_prefix}_event_data.npy") patient_events = defaultdict(list) vocab_size = 0 for patient_id, time_in_days, event_code in event_data: patient_events[patient_id].append((time_in_days, event_code)) if event_code > vocab_size: vocab_size = event_code self.n_disease = vocab_size - 1 self.basic_info = basic_info.convert_dtypes() self.patient_ids = self.basic_info.index.tolist() self.patient_events = dict(patient_events) for patient_id, records in self.patient_events.items(): records.sort(key=lambda x: x[0]) tabular_data = tabular_data.convert_dtypes() cont_cols = [] cate_cols = [] self.cate_dims = [] if covariate_list is not None: tabular_data = tabular_data[covariate_list] for col in tabular_data.columns: if pd.api.types.is_float_dtype(tabular_data[col]): cont_cols.append(col) elif pd.api.types.is_integer_dtype(tabular_data[col]): series = tabular_data[col] unique_vals = series.dropna().unique() if len(unique_vals) > 11: cont_cols.append(col) else: cate_cols.append(col) self.cate_dims.append(int(series.max()) + 1) self.cont_features = tabular_data[cont_cols].to_numpy( dtype=np.float32).copy() self.cate_features = tabular_data[cate_cols].to_numpy( dtype=np.int64).copy() self.n_cont = self.cont_features.shape[1] self.n_cate = self.cate_features.shape[1] self._doa = self.basic_info.loc[ self.patient_ids, 'date_of_assessment' ].to_numpy(dtype=np.float32) self._sex = self.basic_info.loc[ self.patient_ids, 'sex' ].to_numpy(dtype=np.int64) self.cont_features = torch.from_numpy(self.cont_features) self.cate_features = torch.from_numpy(self.cate_features) # Optional cache for the DOA-inserted sequences produced by __getitem__. # This preserves outputs exactly (we reuse the same construction logic), # but avoids re-building Python lists on repeated access. self._cache_event_tensors = bool(cache_event_tensors) self._cached_event_tensors: List[torch.Tensor | None] = [None] * len( self.patient_ids ) self._cached_time_tensors: List[torch.Tensor | None] = [None] * len( self.patient_ids ) def __len__(self) -> int: return len(self.patient_ids) def __getitem__(self, idx): if self._cache_event_tensors: cached_e = self._cached_event_tensors[idx] cached_t = self._cached_time_tensors[idx] if cached_e is not None and cached_t is not None: event_tensor = cached_e time_tensor = cached_t else: patient_id = self.patient_ids[idx] records = self.patient_events.get(patient_id, []) event_seq = [item[1] for item in records] time_seq = [item[0] for item in records] doa = float(self._doa[idx]) insert_pos = np.searchsorted(time_seq, doa) time_seq.insert(insert_pos, doa) # assuming 1 is the code for 'DOA' event event_seq.insert(insert_pos, 1) event_tensor = torch.tensor(event_seq, dtype=torch.long) time_tensor = torch.tensor(time_seq, dtype=torch.float) self._cached_event_tensors[idx] = event_tensor self._cached_time_tensors[idx] = time_tensor else: patient_id = self.patient_ids[idx] records = self.patient_events.get(patient_id, []) event_seq = [item[1] for item in records] time_seq = [item[0] for item in records] doa = float(self._doa[idx]) insert_pos = np.searchsorted(time_seq, doa) time_seq.insert(insert_pos, doa) # assuming 1 is the code for 'DOA' event event_seq.insert(insert_pos, 1) event_tensor = torch.tensor(event_seq, dtype=torch.long) time_tensor = torch.tensor(time_seq, dtype=torch.float) cont_tensor = self.cont_features[idx, :].to(dtype=torch.float) cate_tensor = self.cate_features[idx, :].to(dtype=torch.long) sex = int(self._sex[idx]) return (event_tensor, time_tensor, cont_tensor, cate_tensor, sex) def health_collate_fn(batch): event_seqs, time_seqs, cont_feats, cate_feats, sexes = zip(*batch) event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0) time_batch = pad_sequence( time_seqs, batch_first=True, padding_value=36525.0) cont_batch = torch.stack(cont_feats, dim=0) cont_batch = cont_batch.unsqueeze(1) # (B, 1, n_cont) cate_batch = torch.stack(cate_feats, dim=0) cate_batch = cate_batch.unsqueeze(1) # (B, 1, n_cate) sex_batch = torch.tensor(sexes, dtype=torch.long) return event_batch, time_batch, cont_batch, cate_batch, sex_batch