Files
DeepHealthV2/dataset.py

241 lines
8.6 KiB
Python
Raw Normal View History

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import Any, Dict, List, Sequence, Tuple
class HealthDataset(Dataset):
"""
Dataset for health records.
Args:
data_prefix (str): Prefix for data files.
covariate_list (List[str] | None): List of covariates to include.
"""
def __init__(
self,
data_prefix: str,
max_len: int = 64,
covariate_list: List[str] | None = None,
):
event_data = np.load(f"{data_prefix}_event_data.npy")
basic_info = pd.read_csv(
f"{data_prefix}_basic_info.csv", index_col='eid')
cont_cols = []
cate_cols = []
self.cate_dims = []
self.n_cont = 0
self.n_cate = 0
if covariate_list is not None:
tabular_data = pd.read_csv(
f"{data_prefix}_table.csv", index_col='eid')[covariate_list]
tabular_data = tabular_data.convert_dtypes()
for col in tabular_data.columns:
if pd.api.types.is_float_dtype(tabular_data[col]):
cont_cols.append(col)
elif pd.api.types.is_integer_dtype(tabular_data[col]):
series = tabular_data[col]
unique_vals = series.dropna().unique()
if len(unique_vals) > 11:
cont_cols.append(col)
else:
cate_cols.append(col)
self.cate_dims.append(int(series.max()) + 1)
self.cont_features = tabular_data[cont_cols].to_numpy(
dtype=np.float32).copy()
self.cate_features = tabular_data[cate_cols].to_numpy(
dtype=np.int64).copy()
self.n_cont = self.cont_features.shape[1]
self.n_cate = self.cate_features.shape[1]
self.cont_features = torch.from_numpy(self.cont_features)
self.cate_features = torch.from_numpy(self.cate_features)
else:
tabular_data = None
self.cont_features = None
self.cate_features = None
patient_events = defaultdict(list)
vocab_size = 0
for patient_id, time_in_days, event_code in event_data:
patient_events[patient_id].append((time_in_days, event_code))
if event_code > vocab_size:
vocab_size = event_code
self.n_event = vocab_size - 2 # Exclude padding and CLS codes
self.n_disease = self.n_event - 1 # Exclude death code
self.max_len = max_len
self.death_token = vocab_size
self.basic_info = basic_info.convert_dtypes()
self.patient_ids = self.basic_info.index.tolist()
self.patient_events = dict(patient_events)
for patient_id, records in self.patient_events.items():
records.sort(key=lambda x: x[0])
self._doa = self.basic_info.loc[
self.patient_ids, 'date_of_assessment'
].to_numpy(dtype=np.float32)
self._sex = self.basic_info.loc[
self.patient_ids, 'sex'
].to_numpy(dtype=np.int64)
def __len__(self) -> int:
return len(self.patient_ids)
def __getitem__(self, idx: int):
patient_id = self.patient_ids[idx]
records = list(self.patient_events.get(patient_id, []))
if self.cont_features is not None or self.cate_features is not None:
doa = float(self._doa[idx])
# insert checkup event at date of assessment
records.append((doa, 2)) # 1 is checkup token
records.sort(key=lambda x: x[0])
n_events = len(records)
has_died = (records[-1][1] ==
self.death_token) if n_events > 0 else False
if has_died:
valid_max_split = n_events - 1
else:
valid_max_split = n_events
if valid_max_split < 1:
split_idx = 0
else:
split_idx = np.random.randint(1, valid_max_split + 1)
past_records = records[:split_idx]
future_records = records[split_idx:]
if len(past_records) > 0:
current_time = past_records[-1][0]
else:
current_time = 0
if len(past_records) > self.max_len:
past_records = past_records[-self.max_len:]
event_seq = [1] + [item[1] for item in past_records] # 1 is CLS token
time_seq = [current_time] + [item[0] for item in past_records]
event_tensor = torch.tensor(event_seq, dtype=torch.long)
time_tensor = torch.tensor(time_seq, dtype=torch.float)
labels = torch.zeros(self.n_disease, 2, dtype=torch.float32)
global_censoring_time = records[-1][0] - current_time
labels[:, 1] = global_censoring_time
loss_mask = torch.ones(self.n_disease, dtype=torch.float32)
death_time = 0.0
for t, token in past_records:
if token > 2:
disease_idx = token - 3
if 0 <= disease_idx < self.n_disease:
loss_mask[disease_idx] = 0.0
for t, token in future_records:
rel_time = t - current_time
if rel_time < 1e-6:
rel_time = 1e-6
if token == self.death_token:
death_time = rel_time
break
if token > 2:
disease_idx = token - 3
if 0 <= disease_idx < self.n_disease:
if labels[disease_idx, 0] == 0.0:
labels[disease_idx, 0] = 1.0
labels[disease_idx, 1] = rel_time
if self.cont_features is not None:
# 提取该病人的静态特征向量
cur_cont = self.cont_features[idx].to(
dtype=torch.float) # (n_cont,)
cur_cate = self.cate_features[idx].to(
dtype=torch.long) # (n_cate,)
else:
cur_cont = None
cur_cate = None
sex = int(self._sex[idx])
return {
'event_seq': event_tensor,
'time_seq': time_tensor,
'labels': labels,
'loss_mask': loss_mask,
'sex': sex,
'death_time': death_time,
'global_censoring_time': global_censoring_time,
'cont_features': cur_cont,
'cate_features': cur_cate,
}
def health_collate_fn(batch):
    """Collate a list of HealthDataset samples into one batched dict.

    Pads the per-patient event/time sequences to a common length (padding
    id 0), derives a boolean attention mask over real tokens, stacks the
    fixed-size label tensors, and gathers the scalar fields into 1-D
    tensors. Optional static tabular features are stacked when present; a
    batch mixing present and missing features raises ValueError.
    """
    # Variable-length sequences -> right-padded (B, T) tensors.
    event_batch = pad_sequence(
        [sample['event_seq'] for sample in batch],
        batch_first=True, padding_value=0).long()
    time_batch = pad_sequence(
        [sample['time_seq'] for sample in batch],
        batch_first=True, padding_value=0.0).float()
    # Padding id is 0, so every non-zero entry is a real token.
    attention = event_batch != 0

    # Fixed-size per-sample tensors.
    label_batch = torch.stack(
        [sample['labels'] for sample in batch], dim=0).float()      # (B, n_disease, 2)
    mask_batch = torch.stack(
        [sample['loss_mask'] for sample in batch], dim=0).float()   # (B, n_disease)

    # Scalar fields -> (B,) tensors.
    sex_batch = torch.tensor(
        [int(sample['sex']) for sample in batch], dtype=torch.long)
    death_batch = torch.tensor(
        [float(sample['death_time']) for sample in batch],
        dtype=torch.float32)
    censor_batch = torch.tensor(
        [float(sample['global_censoring_time']) for sample in batch],
        dtype=torch.float32)

    # Optional static tabular features: all-or-nothing across the batch.
    cont_items = [sample['cont_features'] for sample in batch]
    cate_items = [sample['cate_features'] for sample in batch]
    cont_present = any(item is not None for item in cont_items)
    cate_present = any(item is not None for item in cate_items)
    if cont_present and any(item is None for item in cont_items):
        raise ValueError("Mixed None/non-None cont_features in batch.")
    if cate_present and any(item is None for item in cate_items):
        raise ValueError("Mixed None/non-None cate_features in batch.")
    cont_batch = torch.stack(cont_items, dim=0).float() if cont_present else None  # (B, n_cont)
    cate_batch = torch.stack(cate_items, dim=0).long() if cate_present else None   # (B, n_cate)

    return {
        'event_seq': event_batch,
        'time_seq': time_batch,
        'attn_mask': attention,
        'labels': label_batch,
        'loss_mask': mask_batch,
        'sex': sex_batch,
        'death_time': death_batch,
        'global_censoring_time': censor_batch,
        'cont_features': cont_batch,
        'cate_features': cate_batch,
    }