Implement DelphiBERT model and data preparation scripts for tabular time series analysis
- Added `model.py` containing the DelphiBERT architecture, including TabularEncoder and AutoDiscretization classes for handling tabular features.
- Introduced `prepare_data.R` for merging disease and other data from UK Biobank, ensuring proper column selection and data integrity.
- Created `prepare_data.py` to process UK Biobank data, including mapping field IDs, handling date variables, and preparing tabular features and event data for model training.
This commit is contained in:
240
dataset.py
Normal file
240
dataset.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
|
||||
|
||||
class HealthDataset(Dataset):
|
||||
"""
|
||||
Dataset for health records.
|
||||
|
||||
Args:
|
||||
data_prefix (str): Prefix for data files.
|
||||
covariate_list (List[str] | None): List of covariates to include.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_prefix: str,
|
||||
max_len: int = 64,
|
||||
covariate_list: List[str] | None = None,
|
||||
):
|
||||
event_data = np.load(f"{data_prefix}_event_data.npy")
|
||||
basic_info = pd.read_csv(
|
||||
f"{data_prefix}_basic_info.csv", index_col='eid')
|
||||
cont_cols = []
|
||||
cate_cols = []
|
||||
self.cate_dims = []
|
||||
self.n_cont = 0
|
||||
self.n_cate = 0
|
||||
if covariate_list is not None:
|
||||
tabular_data = pd.read_csv(
|
||||
f"{data_prefix}_table.csv", index_col='eid')[covariate_list]
|
||||
tabular_data = tabular_data.convert_dtypes()
|
||||
for col in tabular_data.columns:
|
||||
if pd.api.types.is_float_dtype(tabular_data[col]):
|
||||
cont_cols.append(col)
|
||||
elif pd.api.types.is_integer_dtype(tabular_data[col]):
|
||||
series = tabular_data[col]
|
||||
unique_vals = series.dropna().unique()
|
||||
if len(unique_vals) > 11:
|
||||
cont_cols.append(col)
|
||||
else:
|
||||
cate_cols.append(col)
|
||||
self.cate_dims.append(int(series.max()) + 1)
|
||||
self.cont_features = tabular_data[cont_cols].to_numpy(
|
||||
dtype=np.float32).copy()
|
||||
self.cate_features = tabular_data[cate_cols].to_numpy(
|
||||
dtype=np.int64).copy()
|
||||
self.n_cont = self.cont_features.shape[1]
|
||||
self.n_cate = self.cate_features.shape[1]
|
||||
self.cont_features = torch.from_numpy(self.cont_features)
|
||||
self.cate_features = torch.from_numpy(self.cate_features)
|
||||
else:
|
||||
tabular_data = None
|
||||
self.cont_features = None
|
||||
self.cate_features = None
|
||||
|
||||
patient_events = defaultdict(list)
|
||||
vocab_size = 0
|
||||
for patient_id, time_in_days, event_code in event_data:
|
||||
patient_events[patient_id].append((time_in_days, event_code))
|
||||
if event_code > vocab_size:
|
||||
vocab_size = event_code
|
||||
self.n_event = vocab_size - 2 # Exclude padding and CLS codes
|
||||
self.n_disease = self.n_event - 1 # Exclude death code
|
||||
self.max_len = max_len
|
||||
self.death_token = vocab_size
|
||||
self.basic_info = basic_info.convert_dtypes()
|
||||
self.patient_ids = self.basic_info.index.tolist()
|
||||
self.patient_events = dict(patient_events)
|
||||
for patient_id, records in self.patient_events.items():
|
||||
records.sort(key=lambda x: x[0])
|
||||
|
||||
self._doa = self.basic_info.loc[
|
||||
self.patient_ids, 'date_of_assessment'
|
||||
].to_numpy(dtype=np.float32)
|
||||
self._sex = self.basic_info.loc[
|
||||
self.patient_ids, 'sex'
|
||||
].to_numpy(dtype=np.int64)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.patient_ids)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
patient_id = self.patient_ids[idx]
|
||||
records = list(self.patient_events.get(patient_id, []))
|
||||
|
||||
if self.cont_features is not None or self.cate_features is not None:
|
||||
doa = float(self._doa[idx])
|
||||
# insert checkup event at date of assessment
|
||||
records.append((doa, 2)) # 1 is checkup token
|
||||
records.sort(key=lambda x: x[0])
|
||||
|
||||
n_events = len(records)
|
||||
has_died = (records[-1][1] ==
|
||||
self.death_token) if n_events > 0 else False
|
||||
|
||||
if has_died:
|
||||
valid_max_split = n_events - 1
|
||||
else:
|
||||
valid_max_split = n_events
|
||||
|
||||
if valid_max_split < 1:
|
||||
split_idx = 0
|
||||
else:
|
||||
split_idx = np.random.randint(1, valid_max_split + 1)
|
||||
|
||||
past_records = records[:split_idx]
|
||||
future_records = records[split_idx:]
|
||||
|
||||
if len(past_records) > 0:
|
||||
current_time = past_records[-1][0]
|
||||
else:
|
||||
current_time = 0
|
||||
|
||||
if len(past_records) > self.max_len:
|
||||
past_records = past_records[-self.max_len:]
|
||||
|
||||
event_seq = [1] + [item[1] for item in past_records] # 1 is CLS token
|
||||
time_seq = [current_time] + [item[0] for item in past_records]
|
||||
|
||||
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
||||
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
||||
|
||||
labels = torch.zeros(self.n_disease, 2, dtype=torch.float32)
|
||||
global_censoring_time = records[-1][0] - current_time
|
||||
|
||||
labels[:, 1] = global_censoring_time
|
||||
|
||||
loss_mask = torch.ones(self.n_disease, dtype=torch.float32)
|
||||
|
||||
death_time = 0.0
|
||||
|
||||
for t, token in past_records:
|
||||
if token > 2:
|
||||
disease_idx = token - 3
|
||||
if 0 <= disease_idx < self.n_disease:
|
||||
loss_mask[disease_idx] = 0.0
|
||||
|
||||
for t, token in future_records:
|
||||
rel_time = t - current_time
|
||||
if rel_time < 1e-6:
|
||||
rel_time = 1e-6
|
||||
|
||||
if token == self.death_token:
|
||||
death_time = rel_time
|
||||
break
|
||||
|
||||
if token > 2:
|
||||
disease_idx = token - 3
|
||||
if 0 <= disease_idx < self.n_disease:
|
||||
if labels[disease_idx, 0] == 0.0:
|
||||
labels[disease_idx, 0] = 1.0
|
||||
labels[disease_idx, 1] = rel_time
|
||||
|
||||
if self.cont_features is not None:
|
||||
# 提取该病人的静态特征向量
|
||||
cur_cont = self.cont_features[idx].to(
|
||||
dtype=torch.float) # (n_cont,)
|
||||
cur_cate = self.cate_features[idx].to(
|
||||
dtype=torch.long) # (n_cate,)
|
||||
else:
|
||||
cur_cont = None
|
||||
cur_cate = None
|
||||
sex = int(self._sex[idx])
|
||||
|
||||
return {
|
||||
'event_seq': event_tensor,
|
||||
'time_seq': time_tensor,
|
||||
'labels': labels,
|
||||
'loss_mask': loss_mask,
|
||||
'sex': sex,
|
||||
'death_time': death_time,
|
||||
'global_censoring_time': global_censoring_time,
|
||||
'cont_features': cur_cont,
|
||||
'cate_features': cur_cate,
|
||||
}
|
||||
|
||||
|
||||
def health_collate_fn(batch):
    """Collate a list of HealthDataset samples into batched tensors.

    Variable-length event/time sequences are right-padded (padding id 0
    for events, 0.0 for times); fixed-size per-sample tensors are stacked
    and scalar fields are gathered into 1-D tensors. Static tabular
    features are optional: they must be present for every sample or for
    none of them.
    """
    # Pad the variable-length sequences to the batch maximum.
    events_padded = pad_sequence(
        [sample['event_seq'] for sample in batch],
        batch_first=True, padding_value=0,
    ).long()
    times_padded = pad_sequence(
        [sample['time_seq'] for sample in batch],
        batch_first=True, padding_value=0.0,
    ).float()

    # Real tokens (CLS and events) are non-zero; padding is 0.
    attn_mask = events_padded != 0

    # Stack the fixed-size per-sample tensors.
    labels = torch.stack(
        [sample['labels'] for sample in batch], dim=0).float()      # (B, n_disease, 2)
    loss_mask = torch.stack(
        [sample['loss_mask'] for sample in batch], dim=0).float()   # (B, n_disease)

    # Gather scalar fields into (B,) tensors.
    sex = torch.tensor(
        [int(sample['sex']) for sample in batch], dtype=torch.long)
    death_time = torch.tensor(
        [float(sample['death_time']) for sample in batch],
        dtype=torch.float32)
    global_censoring_time = torch.tensor(
        [float(sample['global_censoring_time']) for sample in batch],
        dtype=torch.float32)

    # Optional static tabular features: all-or-nothing per batch.
    cont_list = [sample['cont_features'] for sample in batch]
    cate_list = [sample['cate_features'] for sample in batch]

    has_cont = any(x is not None for x in cont_list)
    has_cate = any(x is not None for x in cate_list)

    if has_cont and any(x is None for x in cont_list):
        raise ValueError("Mixed None/non-None cont_features in batch.")
    if has_cate and any(x is None for x in cate_list):
        raise ValueError("Mixed None/non-None cate_features in batch.")

    if has_cont:
        cont_features = torch.stack(cont_list, dim=0).float()  # (B, n_cont)
    else:
        cont_features = None
    if has_cate:
        cate_features = torch.stack(cate_list, dim=0).long()   # (B, n_cate)
    else:
        cate_features = None

    return {
        'event_seq': events_padded,
        'time_seq': times_padded,
        'attn_mask': attn_mask,
        'labels': labels,
        'loss_mask': loss_mask,
        'sex': sex,
        'death_time': death_time,
        'global_censoring_time': global_censoring_time,
        'cont_features': cont_features,
        'cate_features': cate_features,
    }
|
||||
Reference in New Issue
Block a user