# 241 lines, 8.6 KiB, Python
from collections import defaultdict
from typing import Any, Dict, List, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
|
||
|
|
|
||
|
|
|
||
|
|
class HealthDataset(Dataset):
|
||
|
|
"""
|
||
|
|
Dataset for health records.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
data_prefix (str): Prefix for data files.
|
||
|
|
covariate_list (List[str] | None): List of covariates to include.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
data_prefix: str,
|
||
|
|
max_len: int = 64,
|
||
|
|
covariate_list: List[str] | None = None,
|
||
|
|
):
|
||
|
|
event_data = np.load(f"{data_prefix}_event_data.npy")
|
||
|
|
basic_info = pd.read_csv(
|
||
|
|
f"{data_prefix}_basic_info.csv", index_col='eid')
|
||
|
|
cont_cols = []
|
||
|
|
cate_cols = []
|
||
|
|
self.cate_dims = []
|
||
|
|
self.n_cont = 0
|
||
|
|
self.n_cate = 0
|
||
|
|
if covariate_list is not None:
|
||
|
|
tabular_data = pd.read_csv(
|
||
|
|
f"{data_prefix}_table.csv", index_col='eid')[covariate_list]
|
||
|
|
tabular_data = tabular_data.convert_dtypes()
|
||
|
|
for col in tabular_data.columns:
|
||
|
|
if pd.api.types.is_float_dtype(tabular_data[col]):
|
||
|
|
cont_cols.append(col)
|
||
|
|
elif pd.api.types.is_integer_dtype(tabular_data[col]):
|
||
|
|
series = tabular_data[col]
|
||
|
|
unique_vals = series.dropna().unique()
|
||
|
|
if len(unique_vals) > 11:
|
||
|
|
cont_cols.append(col)
|
||
|
|
else:
|
||
|
|
cate_cols.append(col)
|
||
|
|
self.cate_dims.append(int(series.max()) + 1)
|
||
|
|
self.cont_features = tabular_data[cont_cols].to_numpy(
|
||
|
|
dtype=np.float32).copy()
|
||
|
|
self.cate_features = tabular_data[cate_cols].to_numpy(
|
||
|
|
dtype=np.int64).copy()
|
||
|
|
self.n_cont = self.cont_features.shape[1]
|
||
|
|
self.n_cate = self.cate_features.shape[1]
|
||
|
|
self.cont_features = torch.from_numpy(self.cont_features)
|
||
|
|
self.cate_features = torch.from_numpy(self.cate_features)
|
||
|
|
else:
|
||
|
|
tabular_data = None
|
||
|
|
self.cont_features = None
|
||
|
|
self.cate_features = None
|
||
|
|
|
||
|
|
patient_events = defaultdict(list)
|
||
|
|
vocab_size = 0
|
||
|
|
for patient_id, time_in_days, event_code in event_data:
|
||
|
|
patient_events[patient_id].append((time_in_days, event_code))
|
||
|
|
if event_code > vocab_size:
|
||
|
|
vocab_size = event_code
|
||
|
|
self.n_event = vocab_size - 2 # Exclude padding and CLS codes
|
||
|
|
self.n_disease = self.n_event - 1 # Exclude death code
|
||
|
|
self.max_len = max_len
|
||
|
|
self.death_token = vocab_size
|
||
|
|
self.basic_info = basic_info.convert_dtypes()
|
||
|
|
self.patient_ids = self.basic_info.index.tolist()
|
||
|
|
self.patient_events = dict(patient_events)
|
||
|
|
for patient_id, records in self.patient_events.items():
|
||
|
|
records.sort(key=lambda x: x[0])
|
||
|
|
|
||
|
|
self._doa = self.basic_info.loc[
|
||
|
|
self.patient_ids, 'date_of_assessment'
|
||
|
|
].to_numpy(dtype=np.float32)
|
||
|
|
self._sex = self.basic_info.loc[
|
||
|
|
self.patient_ids, 'sex'
|
||
|
|
].to_numpy(dtype=np.int64)
|
||
|
|
|
||
|
|
def __len__(self) -> int:
|
||
|
|
return len(self.patient_ids)
|
||
|
|
|
||
|
|
def __getitem__(self, idx: int):
|
||
|
|
patient_id = self.patient_ids[idx]
|
||
|
|
records = list(self.patient_events.get(patient_id, []))
|
||
|
|
|
||
|
|
if self.cont_features is not None or self.cate_features is not None:
|
||
|
|
doa = float(self._doa[idx])
|
||
|
|
# insert checkup event at date of assessment
|
||
|
|
records.append((doa, 2)) # 1 is checkup token
|
||
|
|
records.sort(key=lambda x: x[0])
|
||
|
|
|
||
|
|
n_events = len(records)
|
||
|
|
has_died = (records[-1][1] ==
|
||
|
|
self.death_token) if n_events > 0 else False
|
||
|
|
|
||
|
|
if has_died:
|
||
|
|
valid_max_split = n_events - 1
|
||
|
|
else:
|
||
|
|
valid_max_split = n_events
|
||
|
|
|
||
|
|
if valid_max_split < 1:
|
||
|
|
split_idx = 0
|
||
|
|
else:
|
||
|
|
split_idx = np.random.randint(1, valid_max_split + 1)
|
||
|
|
|
||
|
|
past_records = records[:split_idx]
|
||
|
|
future_records = records[split_idx:]
|
||
|
|
|
||
|
|
if len(past_records) > 0:
|
||
|
|
current_time = past_records[-1][0]
|
||
|
|
else:
|
||
|
|
current_time = 0
|
||
|
|
|
||
|
|
if len(past_records) > self.max_len:
|
||
|
|
past_records = past_records[-self.max_len:]
|
||
|
|
|
||
|
|
event_seq = [1] + [item[1] for item in past_records] # 1 is CLS token
|
||
|
|
time_seq = [current_time] + [item[0] for item in past_records]
|
||
|
|
|
||
|
|
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
||
|
|
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
||
|
|
|
||
|
|
labels = torch.zeros(self.n_disease, 2, dtype=torch.float32)
|
||
|
|
global_censoring_time = records[-1][0] - current_time
|
||
|
|
|
||
|
|
labels[:, 1] = global_censoring_time
|
||
|
|
|
||
|
|
loss_mask = torch.ones(self.n_disease, dtype=torch.float32)
|
||
|
|
|
||
|
|
death_time = 0.0
|
||
|
|
|
||
|
|
for t, token in past_records:
|
||
|
|
if token > 2:
|
||
|
|
disease_idx = token - 3
|
||
|
|
if 0 <= disease_idx < self.n_disease:
|
||
|
|
loss_mask[disease_idx] = 0.0
|
||
|
|
|
||
|
|
for t, token in future_records:
|
||
|
|
rel_time = t - current_time
|
||
|
|
if rel_time < 1e-6:
|
||
|
|
rel_time = 1e-6
|
||
|
|
|
||
|
|
if token == self.death_token:
|
||
|
|
death_time = rel_time
|
||
|
|
break
|
||
|
|
|
||
|
|
if token > 2:
|
||
|
|
disease_idx = token - 3
|
||
|
|
if 0 <= disease_idx < self.n_disease:
|
||
|
|
if labels[disease_idx, 0] == 0.0:
|
||
|
|
labels[disease_idx, 0] = 1.0
|
||
|
|
labels[disease_idx, 1] = rel_time
|
||
|
|
|
||
|
|
if self.cont_features is not None:
|
||
|
|
# 提取该病人的静态特征向量
|
||
|
|
cur_cont = self.cont_features[idx].to(
|
||
|
|
dtype=torch.float) # (n_cont,)
|
||
|
|
cur_cate = self.cate_features[idx].to(
|
||
|
|
dtype=torch.long) # (n_cate,)
|
||
|
|
else:
|
||
|
|
cur_cont = None
|
||
|
|
cur_cate = None
|
||
|
|
sex = int(self._sex[idx])
|
||
|
|
|
||
|
|
return {
|
||
|
|
'event_seq': event_tensor,
|
||
|
|
'time_seq': time_tensor,
|
||
|
|
'labels': labels,
|
||
|
|
'loss_mask': loss_mask,
|
||
|
|
'sex': sex,
|
||
|
|
'death_time': death_time,
|
||
|
|
'global_censoring_time': global_censoring_time,
|
||
|
|
'cont_features': cur_cont,
|
||
|
|
'cate_features': cur_cate,
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def health_collate_fn(batch):
    """Collate ``HealthDataset`` samples into padded, batched tensors.

    Pads the variable-length event/time sequences with 0, derives a boolean
    attention mask over real tokens, stacks the fixed-size label tensors,
    converts scalar fields to 1-D tensors, and stacks the optional static
    covariates — which must be uniformly present or absent across the batch.
    """
    # 1) pad variable-length sequences to a common length (B, T)
    padded_events = pad_sequence(
        [sample['event_seq'] for sample in batch],
        batch_first=True, padding_value=0).long()
    padded_times = pad_sequence(
        [sample['time_seq'] for sample in batch],
        batch_first=True, padding_value=0.0).float()

    # True wherever a real (non-padding) token sits.
    attn_mask = padded_events != 0

    # 2) stack fixed-size per-sample tensors
    labels = torch.stack(
        [sample['labels'] for sample in batch], dim=0).float()  # (B, n_disease, 2)
    loss_mask = torch.stack(
        [sample['loss_mask'] for sample in batch], dim=0).float()  # (B, n_disease)

    # 3) scalar fields -> (B,) tensors
    sex = torch.tensor(
        [int(sample['sex']) for sample in batch], dtype=torch.long)
    death_time = torch.tensor(
        [float(sample['death_time']) for sample in batch], dtype=torch.float32)
    global_censoring_time = torch.tensor(
        [float(sample['global_censoring_time']) for sample in batch],
        dtype=torch.float32)

    # 4) optional static tabular features: all-or-nothing across the batch
    cont_list = [sample['cont_features'] for sample in batch]
    cate_list = [sample['cate_features'] for sample in batch]

    if any(x is not None for x in cont_list):
        if any(x is None for x in cont_list):
            raise ValueError("Mixed None/non-None cont_features in batch.")
        cont_features = torch.stack(cont_list, dim=0).float()  # (B, n_cont)
    else:
        cont_features = None
    if any(x is not None for x in cate_list):
        if any(x is None for x in cate_list):
            raise ValueError("Mixed None/non-None cate_features in batch.")
        cate_features = torch.stack(cate_list, dim=0).long()  # (B, n_cate)
    else:
        cate_features = None

    return {
        'event_seq': padded_events,
        'time_seq': padded_times,
        'attn_mask': attn_mask,
        'labels': labels,
        'loss_mask': loss_mask,
        'sex': sex,
        'death_time': death_time,
        'global_censoring_time': global_censoring_time,
        'cont_features': cont_features,
        'cate_features': cate_features,
    }
|