- Added `model.py` containing the DelphiBERT architecture, including the TabularEncoder and AutoDiscretization classes for handling tabular features.
- Introduced `prepare_data.R` for merging disease and other data from UK Biobank, ensuring proper column selection and data integrity.
- Created `prepare_data.py` to process UK Biobank data, including mapping field IDs, handling date variables, and preparing tabular features and event data for model training.
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import Any, Dict, List, Sequence, Tuple


class HealthDataset(Dataset):
    """
    Dataset for health records.

    Args:
        data_prefix (str): Prefix for data files.
        max_len (int): Maximum number of past events kept per patient.
        covariate_list (List[str] | None): List of covariates to include.
    """

    def __init__(
        self,
        data_prefix: str,
        max_len: int = 64,
        covariate_list: List[str] | None = None,
    ):
        event_data = np.load(f"{data_prefix}_event_data.npy")
        basic_info = pd.read_csv(
            f"{data_prefix}_basic_info.csv", index_col='eid')
        cont_cols = []
        cate_cols = []
        self.cate_dims = []
        self.n_cont = 0
        self.n_cate = 0
        if covariate_list is not None:
            tabular_data = pd.read_csv(
                f"{data_prefix}_table.csv", index_col='eid')[covariate_list]
            tabular_data = tabular_data.convert_dtypes()
            for col in tabular_data.columns:
                if pd.api.types.is_float_dtype(tabular_data[col]):
                    cont_cols.append(col)
                elif pd.api.types.is_integer_dtype(tabular_data[col]):
                    series = tabular_data[col]
                    unique_vals = series.dropna().unique()
                    if len(unique_vals) > 11:
                        cont_cols.append(col)
                    else:
                        cate_cols.append(col)
                        self.cate_dims.append(int(series.max()) + 1)
            self.cont_features = tabular_data[cont_cols].to_numpy(
                dtype=np.float32).copy()
            self.cate_features = tabular_data[cate_cols].to_numpy(
                dtype=np.int64).copy()
            self.n_cont = self.cont_features.shape[1]
            self.n_cate = self.cate_features.shape[1]
            self.cont_features = torch.from_numpy(self.cont_features)
            self.cate_features = torch.from_numpy(self.cate_features)
        else:
            tabular_data = None
            self.cont_features = None
            self.cate_features = None

        patient_events = defaultdict(list)
        vocab_size = 0
        for patient_id, time_in_days, event_code in event_data:
            patient_events[patient_id].append((time_in_days, event_code))
            if event_code > vocab_size:
                vocab_size = event_code
        # Token layout: 0 = padding, 1 = CLS, 2 = checkup,
        # 3..vocab_size - 1 = diseases, vocab_size = death.
        self.n_event = vocab_size - 2  # Exclude padding, CLS and checkup codes
        self.n_disease = self.n_event - 1  # Exclude death code
        self.max_len = max_len
        self.death_token = vocab_size
        self.basic_info = basic_info.convert_dtypes()
        self.patient_ids = self.basic_info.index.tolist()
        self.patient_events = dict(patient_events)
        for patient_id, records in self.patient_events.items():
            records.sort(key=lambda x: x[0])

        self._doa = self.basic_info.loc[
            self.patient_ids, 'date_of_assessment'
        ].to_numpy(dtype=np.float32)
        self._sex = self.basic_info.loc[
            self.patient_ids, 'sex'
        ].to_numpy(dtype=np.int64)

    def __len__(self) -> int:
        return len(self.patient_ids)

    def __getitem__(self, idx: int):
        patient_id = self.patient_ids[idx]
        records = list(self.patient_events.get(patient_id, []))

        if self.cont_features is not None or self.cate_features is not None:
            doa = float(self._doa[idx])
            # Insert a checkup event at the date of assessment
            records.append((doa, 2))  # 2 is the checkup token
            records.sort(key=lambda x: x[0])

        n_events = len(records)
        has_died = (records[-1][1] ==
                    self.death_token) if n_events > 0 else False

        if has_died:
            valid_max_split = n_events - 1
        else:
            valid_max_split = n_events

        if valid_max_split < 1:
            split_idx = 0
        else:
            split_idx = np.random.randint(1, valid_max_split + 1)

        past_records = records[:split_idx]
        future_records = records[split_idx:]

        if len(past_records) > 0:
            current_time = past_records[-1][0]
        else:
            current_time = 0

        if len(past_records) > self.max_len:
            past_records = past_records[-self.max_len:]

        event_seq = [1] + [item[1] for item in past_records]  # 1 is CLS token
        time_seq = [current_time] + [item[0] for item in past_records]

        event_tensor = torch.tensor(event_seq, dtype=torch.long)
        time_tensor = torch.tensor(time_seq, dtype=torch.float)

        labels = torch.zeros(self.n_disease, 2, dtype=torch.float32)
        global_censoring_time = records[-1][0] - current_time

        labels[:, 1] = global_censoring_time

        loss_mask = torch.ones(self.n_disease, dtype=torch.float32)

        death_time = 0.0

        for t, token in past_records:
            if token > 2:
                disease_idx = token - 3
                if 0 <= disease_idx < self.n_disease:
                    loss_mask[disease_idx] = 0.0

        for t, token in future_records:
            rel_time = t - current_time
            if rel_time < 1e-6:
                rel_time = 1e-6

            if token == self.death_token:
                death_time = rel_time
                break

            if token > 2:
                disease_idx = token - 3
                if 0 <= disease_idx < self.n_disease:
                    if labels[disease_idx, 0] == 0.0:
                        labels[disease_idx, 0] = 1.0
                        labels[disease_idx, 1] = rel_time

        if self.cont_features is not None:
            # Extract this patient's static tabular feature vectors
            cur_cont = self.cont_features[idx].to(
                dtype=torch.float)  # (n_cont,)
            cur_cate = self.cate_features[idx].to(
                dtype=torch.long)  # (n_cate,)
        else:
            cur_cont = None
            cur_cate = None
        sex = int(self._sex[idx])

        return {
            'event_seq': event_tensor,
            'time_seq': time_tensor,
            'labels': labels,
            'loss_mask': loss_mask,
            'sex': sex,
            'death_time': death_time,
            'global_censoring_time': global_censoring_time,
            'cont_features': cur_cont,
            'cate_features': cur_cate,
        }


def health_collate_fn(batch):
    # 1) pad variable-length sequences
    event_seqs = [item['event_seq'] for item in batch]  # 1D LongTensor
    time_seqs = [item['time_seq'] for item in batch]  # 1D FloatTensor

    event_seq_padded = pad_sequence(
        event_seqs, batch_first=True, padding_value=0).long()
    time_seq_padded = pad_sequence(
        time_seqs, batch_first=True, padding_value=0.0).float()

    attn_mask = (event_seq_padded != 0)  # True for real tokens

    # 2) stack fixed-size tensors
    labels = torch.stack(
        [item['labels'] for item in batch], dim=0).float()  # (B, n_disease, 2)
    loss_mask = torch.stack(
        [item['loss_mask'] for item in batch], dim=0).float()  # (B, n_disease)

    # 3) scalar fields -> tensor
    sex = torch.tensor(
        [int(item['sex']) for item in batch], dtype=torch.long)  # (B,)
    death_time = torch.tensor(
        [float(item['death_time']) for item in batch],
        dtype=torch.float32)  # (B,)
    global_censoring_time = torch.tensor(
        [float(item['global_censoring_time']) for item in batch],
        dtype=torch.float32)  # (B,)

    # 4) optional static tabular features
    cont_list = [item['cont_features'] for item in batch]
    cate_list = [item['cate_features'] for item in batch]

    has_cont = any(x is not None for x in cont_list)
    has_cate = any(x is not None for x in cate_list)

    if has_cont and any(x is None for x in cont_list):
        raise ValueError("Mixed None/non-None cont_features in batch.")
    if has_cate and any(x is None for x in cate_list):
        raise ValueError("Mixed None/non-None cate_features in batch.")

    cont_features = torch.stack(
        cont_list, dim=0).float() if has_cont else None  # (B, n_cont)
    cate_features = torch.stack(
        cate_list, dim=0).long() if has_cate else None  # (B, n_cate)

    return {
        'event_seq': event_seq_padded,
        'time_seq': time_seq_padded,
        'attn_mask': attn_mask,
        'labels': labels,
        'loss_mask': loss_mask,
        'sex': sex,
        'death_time': death_time,
        'global_censoring_time': global_censoring_time,
        'cont_features': cont_features,
        'cate_features': cate_features,
    }
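

# A minimal usage sketch, not part of the original module: it shows how
# HealthDataset and health_collate_fn could be wired into a DataLoader.
# The prefix "data/ukb" and the covariate names "age" and "smoking_status"
# are hypothetical placeholders; it assumes prepare_data.py has written
# data/ukb_event_data.npy, data/ukb_basic_info.csv and data/ukb_table.csv.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = HealthDataset(
        data_prefix="data/ukb",
        max_len=64,
        covariate_list=["age", "smoking_status"],
    )
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=health_collate_fn,
    )

    batch = next(iter(loader))
    # Padded event codes/times are (B, L); attn_mask marks real tokens.
    print(batch['event_seq'].shape, batch['time_seq'].shape,
          batch['attn_mask'].shape)
    # labels holds per-disease (event indicator, time) pairs: (B, n_disease, 2).
    print(batch['labels'].shape, batch['loss_mask'].shape)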