import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import List


class HealthDataset(Dataset):
    """
    Dataset for health records.

    Args:
        data_prefix (str): Prefix for data files.
        max_len (int): Maximum number of past events kept per sample.
        covariate_list (List[str] | None): List of covariates to include.
    """

    def __init__(
        self,
        data_prefix: str,
        max_len: int = 64,
        covariate_list: List[str] | None = None,
    ):
        event_data = np.load(f"{data_prefix}_event_data.npy")
        basic_info = pd.read_csv(
            f"{data_prefix}_basic_info.csv", index_col='eid')

        cont_cols = []
        cate_cols = []
        self.cate_dims = []
        self.n_cont = 0
        self.n_cate = 0
        if covariate_list is not None:
            tabular_data = pd.read_csv(
                f"{data_prefix}_table.csv", index_col='eid')[covariate_list]
            tabular_data = tabular_data.convert_dtypes()
            for col in tabular_data.columns:
                if pd.api.types.is_float_dtype(tabular_data[col]):
                    cont_cols.append(col)
                elif pd.api.types.is_integer_dtype(tabular_data[col]):
                    series = tabular_data[col]
                    unique_vals = series.dropna().unique()
                    # Integer columns with many distinct values are treated
                    # as continuous; the rest as categorical.
                    if len(unique_vals) > 11:
                        cont_cols.append(col)
                    else:
                        cate_cols.append(col)
                        self.cate_dims.append(int(series.max()) + 1)
            self.cont_features = tabular_data[cont_cols].to_numpy(
                dtype=np.float32).copy()
            self.cate_features = tabular_data[cate_cols].to_numpy(
                dtype=np.int64).copy()
            self.n_cont = self.cont_features.shape[1]
            self.n_cate = self.cate_features.shape[1]
            self.cont_features = torch.from_numpy(self.cont_features)
            self.cate_features = torch.from_numpy(self.cate_features)
        else:
            self.cont_features = None
            self.cate_features = None

        patient_events = defaultdict(list)
        vocab_size = 0
        for patient_id, time_in_days, event_code in event_data:
            # Cast to plain Python types so the dict keys match the integer
            # 'eid' index of basic_info used for lookups in __getitem__.
            patient_id = int(patient_id)
            event_code = int(event_code)
            patient_events[patient_id].append(
                (float(time_in_days), event_code))
            if event_code > vocab_size:
                vocab_size = event_code

        self.n_event = vocab_size - 2  # exclude padding, CLS, and checkup tokens
        self.n_disease = self.n_event - 1  # exclude the death code
        self.max_len = max_len
        self.death_token = vocab_size
        self.basic_info = basic_info.convert_dtypes()
        self.patient_ids = self.basic_info.index.tolist()
        self.patient_events = dict(patient_events)
        for patient_id, records in self.patient_events.items():
            records.sort(key=lambda x: x[0])
        self._doa = self.basic_info.loc[
            self.patient_ids, 'date_of_assessment'
        ].to_numpy(dtype=np.float32)
        self._sex = self.basic_info.loc[
            self.patient_ids, 'sex'
        ].to_numpy(dtype=np.int64)

    def __len__(self) -> int:
        return len(self.patient_ids)
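    # Token layout assumed by __getitem__ (inferred from the code below,
    # not guaranteed by the data files themselves):
    #   0                     -> padding
    #   1                     -> CLS, prepended to every sequence
    #   2                     -> checkup, inserted at the date of assessment
    #   3 .. death_token - 1  -> disease codes (disease_idx = token - 3)
    #   death_token           -> death (the largest code in the event file)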
    def __getitem__(self, idx: int):
        patient_id = self.patient_ids[idx]
        records = list(self.patient_events.get(patient_id, []))
        if self.cont_features is not None or self.cate_features is not None:
            doa = float(self._doa[idx])
            # insert checkup event at date of assessment
            records.append((doa, 2))  # 2 is the checkup token
            records.sort(key=lambda x: x[0])

        n_events = len(records)
        has_died = (records[-1][1] == self.death_token) if n_events > 0 else False
        # Never split after the death event: death may only appear in the
        # future half, so the last admissible split point is just before it.
        if has_died:
            valid_max_split = n_events - 1
        else:
            valid_max_split = n_events
        if valid_max_split < 1:
            split_idx = 0
        else:
            split_idx = np.random.randint(1, valid_max_split + 1)

        past_records = records[:split_idx]
        future_records = records[split_idx:]
        if len(past_records) > 0:
            current_time = past_records[-1][0]
        else:
            current_time = 0.0
        if len(past_records) > self.max_len:
            past_records = past_records[-self.max_len:]

        event_seq = [1] + [item[1] for item in past_records]  # 1 is the CLS token
        time_seq = [current_time] + [item[0] for item in past_records]
        event_tensor = torch.tensor(event_seq, dtype=torch.long)
        time_tensor = torch.tensor(time_seq, dtype=torch.float)

        # labels[:, 0] is the event indicator, labels[:, 1] the time to
        # event (or to censoring), both relative to current_time.
        labels = torch.zeros(self.n_disease, 2, dtype=torch.float32)
        if n_events > 0:
            global_censoring_time = records[-1][0] - current_time
        else:
            # guard against patients with no events at all
            global_censoring_time = 0.0
        labels[:, 1] = global_censoring_time

        loss_mask = torch.ones(self.n_disease, dtype=torch.float32)
        death_time = 0.0
        # Diseases already observed in the past are masked out of the loss.
        for t, token in past_records:
            if token > 2:
                disease_idx = token - 3
                if 0 <= disease_idx < self.n_disease:
                    loss_mask[disease_idx] = 0.0
        for t, token in future_records:
            rel_time = max(t - current_time, 1e-6)
            if token == self.death_token:
                death_time = rel_time
                break
            if token > 2:
                disease_idx = token - 3
                if 0 <= disease_idx < self.n_disease:
                    # record only the first occurrence of each disease
                    if labels[disease_idx, 0] == 0.0:
                        labels[disease_idx, 0] = 1.0
                        labels[disease_idx, 1] = rel_time

        if self.cont_features is not None:
            # extract this patient's static feature vectors
            cur_cont = self.cont_features[idx].to(
                dtype=torch.float)  # (n_cont,)
            cur_cate = self.cate_features[idx].to(
                dtype=torch.long)  # (n_cate,)
        else:
            cur_cont = None
            cur_cate = None

        sex = int(self._sex[idx])

        return {
            'event_seq': event_tensor,
            'time_seq': time_tensor,
            'labels': labels,
            'loss_mask': loss_mask,
            'sex': sex,
            'death_time': death_time,
            'global_censoring_time': global_censoring_time,
            'cont_features': cur_cont,
            'cate_features': cur_cate,
        }


def health_collate_fn(batch):
    # 1) pad variable-length sequences
    event_seqs = [item['event_seq'] for item in batch]  # 1D LongTensor
    time_seqs = [item['time_seq'] for item in batch]    # 1D FloatTensor
    event_seq_padded = pad_sequence(
        event_seqs, batch_first=True, padding_value=0).long()
    time_seq_padded = pad_sequence(
        time_seqs, batch_first=True, padding_value=0.0).float()
    attn_mask = (event_seq_padded != 0)  # True for real tokens

    # 2) stack fixed-size tensors
    labels = torch.stack([item['labels'] for item in batch],
                         dim=0).float()  # (B, n_disease, 2)
    loss_mask = torch.stack([item['loss_mask'] for item in batch],
                            dim=0).float()  # (B, n_disease)

    # 3) scalar fields -> tensor
    sex = torch.tensor([int(item['sex']) for item in batch],
                       dtype=torch.long)  # (B,)
    death_time = torch.tensor([float(item['death_time']) for item in batch],
                              dtype=torch.float32)  # (B,)
    global_censoring_time = torch.tensor(
        [float(item['global_censoring_time']) for item in batch],
        dtype=torch.float32
    )  # (B,)

    # 4) optional static tabular features
    cont_list = [item['cont_features'] for item in batch]
    cate_list = [item['cate_features'] for item in batch]
    has_cont = any(x is not None for x in cont_list)
    has_cate = any(x is not None for x in cate_list)
    if has_cont and any(x is None for x in cont_list):
        raise ValueError("Mixed None/non-None cont_features in batch.")
    if has_cate and any(x is None for x in cate_list):
        raise ValueError("Mixed None/non-None cate_features in batch.")
    cont_features = torch.stack(
        cont_list, dim=0).float() if has_cont else None  # (B, n_cont)
    cate_features = torch.stack(
        cate_list, dim=0).long() if has_cate else None  # (B, n_cate)

    return {
        'event_seq': event_seq_padded,
        'time_seq': time_seq_padded,
        'attn_mask': attn_mask,
        'labels': labels,
        'loss_mask': loss_mask,
        'sex': sex,
        'death_time': death_time,
        'global_censoring_time': global_censoring_time,
        'cont_features': cont_features,
        'cate_features': cate_features,
    }
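
# --- Usage sketch (illustrative only) ---
# A minimal example of wiring HealthDataset into a DataLoader with the
# custom collate function. The "ukb" data prefix and the covariate names
# below are hypothetical placeholders; substitute your own files.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = HealthDataset(
        data_prefix="ukb",  # hypothetical; expects ukb_event_data.npy etc.
        max_len=64,
        covariate_list=["bmi", "smoking_status"],  # hypothetical columns
    )
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=health_collate_fn,  # pads sequences, stacks the rest
    )
    batch = next(iter(loader))
    print(batch['event_seq'].shape)  # (B, L) padded event codes
    print(batch['attn_mask'].shape)  # (B, L) True for real tokens
    print(batch['labels'].shape)     # (B, n_disease, 2)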