Implement DelphiBERT model and data preparation scripts for tabular time series analysis
- Added `model.py` containing the DelphiBERT architecture, including the `TabularEncoder` and `AutoDiscretization` classes for handling tabular features.
- Introduced `prepare_data.R` for merging disease and other data from the UK Biobank, ensuring proper column selection and data integrity.
- Created `prepare_data.py` to process UK Biobank data, including mapping field IDs, handling date variables, and preparing tabular features and event data for model training.
age_encoder.py (new file, 55 lines)
```python
import torch
import torch.nn as nn


class AgeSinusoidalEncoder(nn.Module):
    """
    Sinusoidal encoder for age.

    Args:
        n_embd (int): Embedding dimension. Must be even.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        if n_embd % 2 != 0:
            raise ValueError("n_embd must be even for sinusoidal encoding.")
        self.n_embd = n_embd
        i = torch.arange(0, self.n_embd, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.n_embd)
        self.register_buffer('divisor', divisor)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        t_years = ages / 365.25  # ages are day offsets; convert to years
        # Broadcast (B, L, 1) against (1, 1, D/2) to get (B, L, D/2)
        args = t_years.unsqueeze(-1) / self.divisor.view(1, 1, -1)
        # Interleave cos and sin along the last dimension
        output = torch.zeros(
            ages.shape[0], ages.shape[1], self.n_embd, device=ages.device)
        output[:, :, 0::2] = torch.cos(args)
        output[:, :, 1::2] = torch.sin(args)
        return output


class AgeMLPEncoder(nn.Module):
    """
    MLP encoder for age, using sinusoidal encoding as input.

    Args:
        n_embd (int): Embedding dimension.
    """

    def __init__(self, n_embd: int):
        super().__init__()

        self.sin_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        x = self.sin_encoder(ages)
        output = self.mlp(x)
        return output
```
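A minimal usage sketch (ages given as day offsets, as produced by `dataset.py`; the values below are arbitrary):

```python
import torch
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder

enc = AgeSinusoidalEncoder(n_embd=64)
ages = torch.tensor([[0.0, 365.25 * 40, 365.25 * 65]])  # (B=1, L=3), in days
print(enc(ages).shape)      # torch.Size([1, 3, 64])

mlp_enc = AgeMLPEncoder(n_embd=64)
print(mlp_enc(ages).shape)  # torch.Size([1, 3, 64])
```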
backbones.py (new file, 112 lines)
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class SelfAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    Args:
        n_embd (int): Embedding dimension.
        n_head (int): Number of attention heads.
        attn_pdrop (float): Attention dropout probability.
    """

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        attn_pdrop: float = 0.1,
    ):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.o_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_pdrop = attn_pdrop

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, L, D = x.shape
        qkv = self.qkv_proj(x)  # (B, L, 3D)
        q, k, v = qkv.chunk(3, dim=-1)

        def reshape_heads(t):
            # (B, H, L, d)
            return t.view(B, L, self.n_head, self.head_dim).transpose(1, 2)

        q = reshape_heads(q)
        k = reshape_heads(k)
        v = reshape_heads(v)

        dropout_p = self.attn_pdrop if self.training else 0.0
        # Boolean attn_mask follows SDPA semantics: True = attend, False = masked out
        attn = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=dropout_p,
            attn_mask=attn_mask,
        )  # (B, H, L, d)

        attn = attn.transpose(1, 2).contiguous().view(B, L, D)  # (B, L, D)
        return self.o_proj(attn)


class Block(nn.Module):
    """
    Transformer block consisting of self-attention and MLP.

    Args:
        n_embd (int): Embedding dimension.
        n_head (int): Number of attention heads.
        pdrop (float): Dropout probability.
    """

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        attn_pdrop = pdrop

        self.norm_1 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
        )
        self.norm_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention (pre-norm residual)
        h = self.norm_1(x)
        h = self.attn(h, attn_mask=attn_mask)
        x = x + self.resid_dropout(h)

        # MLP
        h = self.norm_2(x)
        h = self.mlpf(h)
        x = x + self.resid_dropout(h)

        return x
```
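A minimal sketch of driving a `Block` with a padding mask; the boolean mask follows `F.scaled_dot_product_attention` semantics (`True` = may attend), and its `(B, 1, 1, L)` shape matches how `model.py` builds it:

```python
import torch
from backbones import Block

block = Block(n_embd=64, n_head=4, pdrop=0.1)
x = torch.randn(2, 5, 64)  # (B, L, D)
# Mark the last two positions of each sequence as padding (False = masked out)
keep = torch.tensor([[True, True, True, False, False]] * 2)
attn_mask = keep.view(2, 1, 1, 5)  # broadcasts over heads and query positions
out = block(x, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 5, 64])
```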
dataset.py (new file, 240 lines)
```python
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import List


class HealthDataset(Dataset):
    """
    Dataset for health records.

    Args:
        data_prefix (str): Prefix for data files.
        max_len (int): Maximum number of past events kept per sample.
        covariate_list (List[str] | None): List of covariates to include.
    """

    def __init__(
        self,
        data_prefix: str,
        max_len: int = 64,
        covariate_list: List[str] | None = None,
    ):
        event_data = np.load(f"{data_prefix}_event_data.npy")
        basic_info = pd.read_csv(
            f"{data_prefix}_basic_info.csv", index_col='eid')
        cont_cols = []
        cate_cols = []
        self.cate_dims = []
        self.n_cont = 0
        self.n_cate = 0
        if covariate_list is not None:
            tabular_data = pd.read_csv(
                f"{data_prefix}_table.csv", index_col='eid')[covariate_list]
            tabular_data = tabular_data.convert_dtypes()
            for col in tabular_data.columns:
                if pd.api.types.is_float_dtype(tabular_data[col]):
                    cont_cols.append(col)
                elif pd.api.types.is_integer_dtype(tabular_data[col]):
                    series = tabular_data[col]
                    unique_vals = series.dropna().unique()
                    if len(unique_vals) > 11:
                        cont_cols.append(col)
                    else:
                        cate_cols.append(col)
                        self.cate_dims.append(int(series.max()) + 1)
            self.cont_features = tabular_data[cont_cols].to_numpy(
                dtype=np.float32).copy()
            self.cate_features = tabular_data[cate_cols].to_numpy(
                dtype=np.int64).copy()
            self.n_cont = self.cont_features.shape[1]
            self.n_cate = self.cate_features.shape[1]
            self.cont_features = torch.from_numpy(self.cont_features)
            self.cate_features = torch.from_numpy(self.cate_features)
        else:
            tabular_data = None
            self.cont_features = None
            self.cate_features = None

        patient_events = defaultdict(list)
        vocab_size = 0
        for patient_id, time_in_days, event_code in event_data:
            patient_events[patient_id].append((time_in_days, event_code))
            if event_code > vocab_size:
                vocab_size = event_code
        # Event codes run 3..vocab_size; codes 0-2 are reserved (padding, CLS, checkup)
        self.n_event = vocab_size - 2
        self.n_disease = self.n_event - 1  # Exclude the death code
        self.max_len = max_len
        self.death_token = vocab_size
        self.basic_info = basic_info.convert_dtypes()
        self.patient_ids = self.basic_info.index.tolist()
        self.patient_events = dict(patient_events)
        for records in self.patient_events.values():
            records.sort(key=lambda x: x[0])

        self._doa = self.basic_info.loc[
            self.patient_ids, 'date_of_assessment'
        ].to_numpy(dtype=np.float32)
        self._sex = self.basic_info.loc[
            self.patient_ids, 'sex'
        ].to_numpy(dtype=np.int64)

    def __len__(self) -> int:
        return len(self.patient_ids)

    def __getitem__(self, idx: int):
        patient_id = self.patient_ids[idx]
        records = list(self.patient_events.get(patient_id, []))

        if self.cont_features is not None or self.cate_features is not None:
            doa = float(self._doa[idx])
            # Insert a checkup event (token 2) at the date of assessment
            records.append((doa, 2))
            records.sort(key=lambda x: x[0])

        n_events = len(records)
        has_died = (records[-1][1] ==
                    self.death_token) if n_events > 0 else False

        if has_died:
            valid_max_split = n_events - 1
        else:
            valid_max_split = n_events

        if valid_max_split < 1:
            split_idx = 0
        else:
            split_idx = np.random.randint(1, valid_max_split + 1)

        past_records = records[:split_idx]
        future_records = records[split_idx:]

        if len(past_records) > 0:
            current_time = past_records[-1][0]
        else:
            current_time = 0

        if len(past_records) > self.max_len:
            past_records = past_records[-self.max_len:]

        event_seq = [1] + [item[1] for item in past_records]  # 1 is the CLS token
        time_seq = [current_time] + [item[0] for item in past_records]

        event_tensor = torch.tensor(event_seq, dtype=torch.long)
        time_tensor = torch.tensor(time_seq, dtype=torch.float)

        labels = torch.zeros(self.n_disease, 2, dtype=torch.float32)
        global_censoring_time = records[-1][0] - current_time

        labels[:, 1] = global_censoring_time

        loss_mask = torch.ones(self.n_disease, dtype=torch.float32)

        death_time = 0.0

        # Mask out diseases already observed in the past window
        for t, token in past_records:
            if token > 2:
                disease_idx = token - 3
                if 0 <= disease_idx < self.n_disease:
                    loss_mask[disease_idx] = 0.0

        for t, token in future_records:
            rel_time = t - current_time
            if rel_time < 1e-6:
                rel_time = 1e-6

            if token == self.death_token:
                death_time = rel_time
                break

            if token > 2:
                disease_idx = token - 3
                if 0 <= disease_idx < self.n_disease:
                    if labels[disease_idx, 0] == 0.0:
                        labels[disease_idx, 0] = 1.0
                        labels[disease_idx, 1] = rel_time

        if self.cont_features is not None:
            # Extract this patient's static feature vectors
            cur_cont = self.cont_features[idx].to(dtype=torch.float)  # (n_cont,)
            cur_cate = self.cate_features[idx].to(dtype=torch.long)   # (n_cate,)
        else:
            cur_cont = None
            cur_cate = None
        sex = int(self._sex[idx])

        return {
            'event_seq': event_tensor,
            'time_seq': time_tensor,
            'labels': labels,
            'loss_mask': loss_mask,
            'sex': sex,
            'death_time': death_time,
            'global_censoring_time': global_censoring_time,
            'cont_features': cur_cont,
            'cate_features': cur_cate,
        }


def health_collate_fn(batch):
    # 1) pad variable-length sequences
    event_seqs = [item['event_seq'] for item in batch]  # 1D LongTensor
    time_seqs = [item['time_seq'] for item in batch]  # 1D FloatTensor

    event_seq_padded = pad_sequence(
        event_seqs, batch_first=True, padding_value=0).long()
    time_seq_padded = pad_sequence(
        time_seqs, batch_first=True, padding_value=0.0).float()

    attn_mask = (event_seq_padded != 0)  # True for real tokens

    # 2) stack fixed-size tensors
    labels = torch.stack([item['labels'] for item in batch],
                         dim=0).float()  # (B, n_disease, 2)
    loss_mask = torch.stack([item['loss_mask'] for item in batch],
                            dim=0).float()  # (B, n_disease)

    # 3) scalar fields -> tensor
    sex = torch.tensor([int(item['sex'])
                        for item in batch], dtype=torch.long)  # (B,)
    death_time = torch.tensor([float(item['death_time'])
                               for item in batch], dtype=torch.float32)  # (B,)
    global_censoring_time = torch.tensor(
        [float(item['global_censoring_time']) for item in batch],
        dtype=torch.float32
    )  # (B,)

    # 4) optional static tabular features
    cont_list = [item['cont_features'] for item in batch]
    cate_list = [item['cate_features'] for item in batch]

    has_cont = any(x is not None for x in cont_list)
    has_cate = any(x is not None for x in cate_list)

    if has_cont and any(x is None for x in cont_list):
        raise ValueError("Mixed None/non-None cont_features in batch.")
    if has_cate and any(x is None for x in cate_list):
        raise ValueError("Mixed None/non-None cate_features in batch.")

    cont_features = torch.stack(
        cont_list, dim=0).float() if has_cont else None  # (B, n_cont)
    cate_features = torch.stack(
        cate_list, dim=0).long() if has_cate else None  # (B, n_cate)

    return {
        'event_seq': event_seq_padded,
        'time_seq': time_seq_padded,
        'attn_mask': attn_mask,
        'labels': labels,
        'loss_mask': loss_mask,
        'sex': sex,
        'death_time': death_time,
        'global_censoring_time': global_censoring_time,
        'cont_features': cont_features,
        'cate_features': cate_features,
    }
```
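A minimal sketch wiring the dataset into a `DataLoader`. The `"ukb"` prefix matches the files written by `prepare_data.py`; the covariate names are hypothetical placeholders:

```python
from torch.utils.data import DataLoader
from dataset import HealthDataset, health_collate_fn

# 'bmi' and 'smoking_status' are hypothetical covariate column names.
ds = HealthDataset("ukb", max_len=64,
                   covariate_list=['bmi', 'smoking_status'])
loader = DataLoader(ds, batch_size=32, shuffle=True,
                    collate_fn=health_collate_fn)
batch = next(iter(loader))
print(batch['event_seq'].shape)  # (B, L), L = longest sequence in the batch
print(batch['labels'].shape)     # (B, n_disease, 2)
```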
icd10_codes_mod.tsv (new file, 1129 lines; diff suppressed because it is too large)
labels.csv (new file, 1257 lines; diff suppressed because it is too large)
model.py (new file, 376 lines)
```python
import torch
import torch.nn as nn
from typing import Optional, List

from backbones import Block
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder


class TabularEncoder(nn.Module):
    """
    Encoder for tabular features (continuous and categorical).

    Args:
        n_embd (int): Embedding dimension.
        n_cont (int): Number of continuous features.
        n_cate (int): Number of categorical features.
        cate_dims (List[int]): List of dimensions for each categorical feature.
        n_bins (int): Number of soft bins for continuous AutoDiscretization.
    """

    def __init__(
        self,
        n_embd: int,
        n_cont: int,
        n_cate: int,
        cate_dims: List[int],
        n_bins: int = 16,
    ):
        super().__init__()
        self.n_embd = n_embd
        self.n_cont = n_cont
        self.n_cate = n_cate

        # Continuous feature path:
        # - BatchNorm on raw (NaN-filled) continuous values
        # - AutoDiscretization (soft binning) per feature
        if n_cont > 0:
            self.cont_bn = nn.BatchNorm1d(n_cont)
            self.cont_discretizer = AutoDiscretization(
                n_features=n_cont,
                n_bins=n_bins,
                n_embd=n_embd,
            )
        else:
            self.cont_bn = None
            self.cont_discretizer = None

        if n_cate > 0:
            assert len(cate_dims) == n_cate, \
                "Length of cate_dims must match n_cate"
            self.cate_embds = nn.ModuleList([
                nn.Embedding(dim, n_embd) for dim in cate_dims
            ])
            self.cate_mask_embds = nn.ModuleList([
                nn.Embedding(2, n_embd) for _ in range(n_cate)
            ])
        else:
            self.cate_embds = None
            self.cate_mask_embds = None

        self.cont_mask_proj = (
            nn.Linear(n_cont, n_embd) if n_cont > 0 else None
        )

        # Fuse aggregated value + aggregated mask via MLP
        self.fuse_mlp = nn.Sequential(
            nn.Linear(2 * n_embd, 2 * n_embd),
            nn.GELU(),
            nn.Linear(2 * n_embd, n_embd),
        )

        self.apply(self._init_weights)
        self.out_ln = nn.LayerNorm(n_embd)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        cont_features: Optional[torch.Tensor],
        cate_features: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Encode tabular features into a per-timestep embedding.

        Inputs:
            cont_features: (B, L, n_cont) float tensor; NaN indicates missing.
            cate_features: (B, L, n_cate) long/int tensor; 0 indicates missing/pad.

        Returns:
            (B, L, n_embd) encoded embedding.
        """

        if self.n_cont == 0 and self.n_cate == 0:
            # Infer (B, L) from whichever input is not None
            if cont_features is not None:
                B, L = cont_features.shape[:2]
                device = cont_features.device
            elif cate_features is not None:
                B, L = cate_features.shape[:2]
                device = cate_features.device
            else:
                raise ValueError(
                    "TabularEncoder received no features but cannot infer (B, L)."
                )
            return torch.zeros(B, L, self.n_embd, device=device)

        value_parts: List[torch.Tensor] = []
        mask_parts: List[torch.Tensor] = []

        if self.n_cont > 0 and cont_features is not None:
            if cont_features.dim() != 3:
                raise ValueError(
                    "cont_features must be 3D tensor (B, L, n_cont)")
            B, L, D_cont = cont_features.shape
            if D_cont != self.n_cont:
                raise ValueError(
                    f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")

            # Missingness mask: 1 for valid, 0 for missing
            cont_mask = (~torch.isnan(cont_features)).float()  # (B, L, n_cont)

            # BatchNorm cannot handle NaNs; fill missing with 0 before BN.
            cont_filled = torch.nan_to_num(
                cont_features, nan=0.0)  # (B, L, n_cont)

            # Apply BN over the feature dimension: (B, L, C) -> (B*L, C) -> (B, L, C)
            cont_flat = cont_filled.reshape(-1, self.n_cont)
            cont_norm_flat = self.cont_bn(cont_flat)  # (B*L, n_cont)

            # Soft-binning per feature: (B*L, n_cont) -> (B*L, n_cont, n_embd)
            cont_emb_flat = self.cont_discretizer(cont_norm_flat)
            cont_emb = cont_emb_flat.view(B, L, self.n_cont, self.n_embd)

            # Mask out missing continuous features before aggregating across features
            cont_emb = cont_emb * cont_mask.unsqueeze(-1)  # (B, L, n_cont, n_embd)
            denom = cont_mask.sum(
                dim=-1, keepdim=True).clamp(min=1.0)  # (B, L, 1)
            h_cont_value = cont_emb.sum(dim=2) / denom  # (B, L, n_embd)
            value_parts.append(h_cont_value)

            # Explicit continuous mask embedding (fused later)
            if self.cont_mask_proj is not None:
                h_cont_mask = self.cont_mask_proj(cont_mask)  # (B, L, n_embd)
                mask_parts.append(h_cont_mask)

        if self.n_cate > 0 and cate_features is not None:
            if cate_features.dim() != 3:
                raise ValueError(
                    "cate_features must be 3D tensor (B, L, n_cate)")
            B, L, D_cate = cate_features.shape
            if D_cate != self.n_cate:
                raise ValueError(
                    f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}")

            for i in range(self.n_cate):
                cate_feat = cate_features[:, :, i]
                cate_embd = self.cate_embds[i]
                cate_mask_embd = self.cate_mask_embds[i]

                cate_value = cate_embd(torch.clamp(cate_feat, min=0))
                cate_mask = (cate_feat > 0).long()  # 0 is the missing/pad code
                cate_mask_value = cate_mask_embd(cate_mask)

                value_parts.append(cate_value)
                mask_parts.append(cate_mask_value)

        if not value_parts:
            if cont_features is not None:
                B, L = cont_features.shape[:2]
                device = cont_features.device
            elif cate_features is not None:
                B, L = cate_features.shape[:2]
                device = cate_features.device
            else:
                raise ValueError("No features provided to TabularEncoder.")
            return torch.zeros(B, L, self.n_embd, device=device)

        # Aggregate across feature groups (the continuous block counts as one part;
        # each categorical feature counts as one part).
        h_value = torch.stack(value_parts, dim=0).mean(dim=0)  # (B, L, n_embd)

        if mask_parts:
            h_mask = torch.stack(mask_parts, dim=0).mean(dim=0)  # (B, L, n_embd)
        else:
            h_mask = torch.zeros_like(h_value)

        # Fuse by concatenation + MLP projection
        h_fused = torch.cat([h_value, h_mask], dim=-1)  # (B, L, 2*n_embd)
        h_out = self.fuse_mlp(h_fused)  # (B, L, n_embd)
        h_out = self.out_ln(h_out)
        return h_out


class AutoDiscretization(nn.Module):
    """AutoDiscretization / soft-binning for continuous tabular scalars.

    For each feature scalar x, compute a soft assignment over `n_bins`:
        p = softmax(x * w + b)
    Then compute the embedding as a weighted sum of learnable bin embeddings:
        emb = sum_k p_k * E_k

    Shapes:
        Input: (N, n_features)
        Output: (N, n_features, n_embd)
    """

    def __init__(self, n_features: int, n_bins: int, n_embd: int):
        super().__init__()
        if n_features <= 0:
            raise ValueError("n_features must be > 0")
        if n_bins <= 1:
            raise ValueError("n_bins must be > 1")
        if n_embd <= 0:
            raise ValueError("n_embd must be > 0")

        self.n_features = n_features
        self.n_bins = n_bins
        self.n_embd = n_embd

        # Per-feature, per-bin affine transform to produce logits
        self.weight = nn.Parameter(torch.empty(n_features, n_bins))
        self.bias = nn.Parameter(torch.empty(n_features, n_bins))

        # Learnable embeddings for each (feature, bin)
        self.bin_emb = nn.Parameter(torch.empty(n_features, n_bins, n_embd))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.normal_(self.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.bin_emb, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 2:
            raise ValueError(
                "AutoDiscretization expects input of shape (N, n_features)")
        if x.size(1) != self.n_features:
            raise ValueError(
                f"Expected x.size(1) == {self.n_features}, got {x.size(1)}"
            )

        # logits: (N, n_features, n_bins)
        logits = x.unsqueeze(-1) * self.weight.unsqueeze(0) + \
            self.bias.unsqueeze(0)
        probs = torch.softmax(logits, dim=-1)

        # Weighted sum over bins -> (N, n_features, n_embd)
        emb = (probs.unsqueeze(-1) * self.bin_emb.unsqueeze(0)).sum(dim=-2)
        return emb
```
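A small numeric sketch of the soft-binning, with arbitrary inputs:

```python
import torch
from model import AutoDiscretization

disc = AutoDiscretization(n_features=2, n_bins=4, n_embd=8)
x = torch.tensor([[0.5, -1.2],
                  [1.7, 0.0]])  # (N=2, n_features=2), already normalized
emb = disc(x)
print(emb.shape)  # torch.Size([2, 2, 8])
# Each scalar is softly assigned to the 4 bins; its embedding is the
# probability-weighted sum of the 4 learnable bin embeddings.
```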
```python
class DelphiBERT(nn.Module):
    """
    DelphiBERT model for tabular time series data.

    Args:
        n_disease (int): Number of disease tokens.
        n_embd (int): Embedding dimension.
        n_head (int): Number of attention heads.
        n_layer (int): Number of transformer blocks.
        n_cont (int): Number of continuous tabular features.
        n_cate (int): Number of categorical tabular features.
        cate_dims (List[int] | None): Dimension of each categorical feature.
        age_encoder_type (str): 'sinusoidal' or 'mlp'.
        pdrop (float): Dropout probability.
    """

    def __init__(
        self,
        n_disease: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_cont: int = 0,
        n_cate: int = 0,
        cate_dims: Optional[List[int]] = None,
        age_encoder_type: str = 'sinusoidal',
        pdrop: float = 0.0,
    ):
        super().__init__()
        if n_cont > 0 or n_cate > 0:
            if n_cate > 0 and cate_dims is None:
                raise ValueError(
                    "cate_dims must be provided if n_cate > 0"
                )
            self.tabular_encoder = TabularEncoder(
                n_embd=n_embd,
                n_cont=n_cont,
                n_cate=n_cate,
                cate_dims=cate_dims,
            )
        else:
            self.tabular_encoder = None
        # 0 padding, 1 CLS, 2 checkup, 3..n_disease+2 diseases, n_disease+3 death
        self.vocab_size = n_disease + 4
        self.n_disease = n_disease
        self.n_embd = n_embd
        self.n_head = n_head

        self.token_embedding = nn.Embedding(
            self.vocab_size, n_embd, padding_idx=0)
        if age_encoder_type == 'sinusoidal':
            self.age_encoder = AgeSinusoidalEncoder(n_embd)
        elif age_encoder_type == 'mlp':
            self.age_encoder = AgeMLPEncoder(n_embd)
        else:
            raise ValueError(
                f"Unsupported age_encoder_type: {age_encoder_type}"
            )

        self.sex_embedding = nn.Embedding(2, n_embd)

        self.blocks = nn.ModuleList([
            Block(
                n_embd=n_embd,
                n_head=n_head,
                pdrop=pdrop,
            ) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)

    def forward(
        self,
        event_seq: torch.Tensor,
        time_seq: torch.Tensor,
        sex: torch.Tensor,
        cont_seq: Optional[torch.Tensor] = None,
        cate_seq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass of DelphiBERT.

        Inputs:
            event_seq: (B, L) long tensor of token IDs.
            time_seq: (B, L) float tensor of ages/times.
            sex: (B,) long tensor of sex (0 or 1).
            cont_seq: (B, Lc, n_cont) float tensor of continuous features.
            cate_seq: (B, Lc, n_cate) long tensor of categorical features.

        Returns:
            (B, n_embd) embedding of the CLS token.
        """
        B, L = event_seq.shape
        token_emb = self.token_embedding(event_seq)  # (B, L, n_embd)
        age_emb = self.age_encoder(time_seq)  # (B, L, n_embd)
        sex_emb = self.sex_embedding(sex.unsqueeze(-1))  # (B, 1, n_embd)
        if self.tabular_encoder is not None and cont_seq is not None and cate_seq is not None:
            tabular_emb = self.tabular_encoder(
                cont_seq, cate_seq)  # (B, Lc, n_embd)
            # Checkup tokens (code 2) receive tabular embeddings instead of token embeddings
            mask = (event_seq == 2)
            Lc = tabular_emb.size(1)
            D = tabular_emb.size(2)
            # Index of each checkup occurrence within the tabular sequence
            occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
            tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
            tab_idx = tab_idx.masked_fill(~mask, 0)  # (B, L)
            tab_inject = tabular_emb.gather(
                dim=1,
                index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
            )  # (B, L, n_embd)
            final_embds = torch.where(
                mask.unsqueeze(-1), tab_inject, token_emb)
            h = final_embds + age_emb + sex_emb
        else:
            h = token_emb + age_emb + sex_emb

        is_padding = (event_seq == 0)
        # SDPA boolean mask: True = attend, so keep the non-padding positions
        attn_mask = (~is_padding).view(B, 1, 1, L)  # (B, 1, 1, L)

        for block in self.blocks:
            h = block(h, attn_mask=attn_mask)
        h = self.ln_f(h)
        cls_output = h[:, 0, :]
        return cls_output
```
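A minimal end-to-end sketch on dummy tensors; all hyperparameters below are illustrative, not the values used for training:

```python
import torch
from model import DelphiBERT

model = DelphiBERT(n_disease=100, n_embd=64, n_head=4, n_layer=2,
                   n_cont=3, n_cate=2, cate_dims=[5, 7])
B, L, Lc = 2, 10, 1
event_seq = torch.randint(3, 103, (B, L))  # disease tokens
event_seq[:, 0] = 1                        # CLS token
event_seq[:, 5] = 2                        # one checkup token per sequence
time_seq = torch.rand(B, L) * 365.25 * 70  # ages in days
sex = torch.randint(0, 2, (B,))
cont_seq = torch.randn(B, Lc, 3)           # one tabular row per checkup
cate_seq = torch.randint(0, 5, (B, Lc, 2))
cls = model(event_seq, time_seq, sex, cont_seq, cate_seq)
print(cls.shape)  # torch.Size([2, 64])
```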
prepare_data.R (new file, 26 lines)
```r
library(data.table)
setDTthreads(40)
library(readr)

field_id <- read.csv("field_id.txt", header = FALSE)
uid <- field_id$V1
big_path <- "/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HHdata_221103_0512.csv"
header_dt <- fread(big_path, nrows = 0)  # Read 0 rows => only column names
all_names <- names(header_dt)
keep_names <- intersect(all_names, uid)
ukb_disease <- fread(big_path,
                     select = keep_names,
                     showProgress = TRUE)

field_id <- read.csv("field_id.txt", header = FALSE)
uid <- field_id$V1
big_path <- "/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HH_data_220812_0512.csv"
header_dt <- fread(big_path, nrows = 0)  # Read 0 rows => only column names
all_names <- names(header_dt)
keep_names <- intersect(all_names, uid)
ukb_others <- fread(big_path,
                    select = keep_names,
                    showProgress = TRUE)

# Merge disease and other data by "eid"
ukb_data <- merge(ukb_disease, ukb_others, by = "eid", all = TRUE)
fwrite(ukb_data, "ukb_data.csv")
```
prepare_data.py (new file, 216 lines)
```python
import pandas as pd  # Pandas for data manipulation
import tqdm  # Progress bar for chunk processing
import numpy as np  # Numerical operations

# CSV mapping field IDs to human-readable names
field_map_file = "field_ids_enriched.csv"
# Map original field ID -> new column name
field_dict = {}
tabular_fields = []  # List of tabular feature column names
with open(field_map_file, "r", encoding="utf-8") as f:  # Open the field mapping file
    next(f)  # Skip the header line
    for line in f:  # Iterate through lines
        parts = line.strip().split(",")  # Split by CSV commas
        if len(parts) >= 3:  # Ensure we have at least the id and name columns
            # Original field identifier (e.g., "34-0.0")
            field_id = parts[0]
            field_name = parts[2]  # Human-readable column name
            field_dict[field_id] = field_name  # Record the mapping
            # Track as a potential tabular feature
            tabular_fields.append(field_name)
# Exclude raw date parts and target columns
exclude_fields = ['year', 'month', 'Death', 'age_at_assessment']
tabular_fields = [
    # Filter out excluded columns
    field for field in tabular_fields if field not in exclude_fields]

# TSV mapping field IDs to ICD10-related date columns
field_to_icd_map = "icd10_codes_mod.tsv"
# Date-like variables to be converted to offsets
date_vars = []
with open(field_to_icd_map, "r", encoding="utf-8") as f:  # Open the ICD10 mapping
    for line in f:  # Iterate over each mapping row
        parts = line.strip().split()  # Split on whitespace for TSV
        if len(parts) >= 6:  # Guard against malformed lines
            # Map field ID to the date column name
            field_dict[parts[0]] = parts[5]
            date_vars.append(parts[5])  # Track date column names in order

for j in range(17):  # Map up to 17 cancer entry slots (dates and types)
    # Cancer diagnosis date slot j
    field_dict[f'40005-{j}.0'] = f'cancer_date_{j}'
    field_dict[f'40006-{j}.0'] = f'cancer_type_{j}'  # Cancer type/code slot j

# Number of ICD-related date columns before adding extras
len_icd = len(date_vars)
date_vars.extend(['Death', 'date_of_assessment'] +  # Add outcome date and assessment date
                 # Add cancer date columns
                 [f'cancer_date_{j}' for j in range(17)])

labels_file = "labels.csv"  # File listing label codes
label_dict = {}  # Map code string -> integer label id
with open(labels_file, "r", encoding="utf-8") as f:  # Open the labels file
    for idx, line in enumerate(f):  # Enumerate to assign incremental label IDs
        parts = line.strip().split(' ')  # Split by space
        if parts and parts[0]:  # Guard against empty lines
            # Map code to index (0, 1, and 2 are reserved for padding, CLS, and checkup)
            label_dict[parts[0]] = idx + 3

event_list = []  # Accumulator for event arrays across chunks
tabular_list = []  # Accumulator for tabular feature DataFrames across chunks
ukb_iterator = pd.read_csv(  # Stream UK Biobank data in chunks
    "ukb_data.csv",
    sep=',',
    chunksize=10000,  # Stream the file in manageable chunks to reduce the memory footprint
    # First column (participant ID) becomes the DataFrame index
    index_col=0,
    low_memory=False  # Disable type-inference optimization for consistent dtypes
)
# Iterate over chunks with a progress bar
for ukb_chunk in tqdm.tqdm(ukb_iterator, desc="Processing UK Biobank data"):
    # Rename columns to friendly names
    ukb_chunk = ukb_chunk.rename(columns=field_dict)
    # Require sex to be present
    ukb_chunk.dropna(subset=['sex'], inplace=True)

    # Construct date of birth from year and month (day fixed to 1)
    ukb_chunk['day'] = 1
    ukb_chunk['dob'] = pd.to_datetime(
        # Guard against malformed dates
        ukb_chunk[['year', 'month', 'day']], errors='coerce'
    )
    del ukb_chunk['day']

    # Use only date variables that actually exist in the current chunk
    present_date_vars = [c for c in date_vars if c in ukb_chunk.columns]

    # Convert date-like columns to datetime and compute day offsets from dob
    if present_date_vars:
        date_cols = ukb_chunk[present_date_vars].apply(
            pd.to_datetime, format="%Y-%m-%d", errors='coerce'  # Parse dates safely
        )
        date_cols_days = date_cols.sub(
            ukb_chunk['dob'], axis=0)  # Timedelta relative to dob
        ukb_chunk[present_date_vars] = date_cols_days.apply(
            lambda x: x.dt.days)  # Store days since dob

    # Append tabular features (use only columns that exist)
    present_tabular_fields = [
        c for c in tabular_fields if c in ukb_chunk.columns]
    tabular_list.append(ukb_chunk[present_tabular_fields].copy())

    # Process disease events from ICD10-related date columns
    # Take the ICD date columns plus 'Death' if present, by order
    icd10_cols = present_date_vars[:len_icd + 1]
    # Melt to long form: participant id, event code (column name), and days offset
    melted_df = ukb_chunk.reset_index().melt(
        id_vars=['eid'],
        value_vars=icd10_cols,
        var_name='event_code',
        value_name='days',
    )
    # Require non-missing day offsets
    melted_df.dropna(subset=['days'], inplace=True)
    if not melted_df.empty:
        melted_df['label'] = melted_df['event_code'].map(
            label_dict)  # Map event code to numeric label
        # Ensure labels exist before the int cast
        melted_df.dropna(subset=['label'], inplace=True)
        if not melted_df.empty:
            event_list.append(
                melted_df[['eid', 'days', 'label']]
                .astype(int)  # Safe now, since label and days are non-null
                .to_numpy()
            )

    # Optimized cancer processing without wide_to_long
    cancer_frames = []
    for j in range(17):
        d_col = f'cancer_date_{j}'
        t_col = f'cancer_type_{j}'
        if d_col in ukb_chunk.columns and t_col in ukb_chunk.columns:
            # Filter rows where both date and type are present
            mask = ukb_chunk[d_col].notna() & ukb_chunk[t_col].notna()
            if mask.any():
                subset_idx = ukb_chunk.index[mask]
                subset_days = ukb_chunk.loc[mask, d_col]
                subset_type = ukb_chunk.loc[mask, t_col]

                # Map cancer type to label using the first 3 characters of the code
                cancer_codes = subset_type.str.slice(0, 3)
                labels = cancer_codes.map(label_dict)

                # Filter valid labels
                valid_label_mask = labels.notna()
                if valid_label_mask.any():
                    # Create an array of (eid, days, label) with types safe for numpy
                    c_eids = subset_idx[valid_label_mask].values
                    c_days = subset_days[valid_label_mask].values
                    c_labels = labels[valid_label_mask].values

                    chunk_cancer_data = np.column_stack(
                        (c_eids, c_days, c_labels))
                    cancer_frames.append(chunk_cancer_data)

    if cancer_frames:
        event_list.append(np.vstack(cancer_frames))

# Combine tabular chunks
final_tabular = pd.concat(tabular_list, axis=0, ignore_index=False)
final_tabular.index.name = 'eid'  # Ensure the index is named consistently
data = np.vstack(event_list)  # Stack all event arrays into one

# Sort by participant, then by day
data = data[np.lexsort((data[:, 1], data[:, 0]))]

# Keep only events with non-negative day offsets
data = data[data[:, 1] >= 0]

# Remove duplicate (participant_id, label) pairs, keeping the first occurrence
data = pd.DataFrame(data).drop_duplicates([0, 2]).values

# Store compactly using unsigned 32-bit integers
data = data.astype(np.uint32)

# Keep only eids present in both the event data and the tabular data
valid_eids = np.intersect1d(data[:, 0], final_tabular.index)
data = data[np.isin(data[:, 0], valid_eids)]
final_tabular = final_tabular.loc[valid_eids]
final_tabular = final_tabular.convert_dtypes()

# Save [eid, sex, date_of_assessment] as basic info
basic_info = final_tabular[['sex', 'date_of_assessment']]
basic_info.to_csv("ukb_basic_info.csv")

# Drop sex and date_of_assessment from the tabular features
final_tabular = final_tabular.drop(columns=['sex', 'date_of_assessment'])

# Process categorical columns in the tabular features.
# If an integer column has few unique values, treat it as categorical:
# count the unique non-negative values as C, then remap the original values
# to [1..C], with NaN and negative values mapped to 0.
for col in final_tabular.select_dtypes(include=['Int64', 'int64']).columns:
    series = final_tabular[col]
    unique_vals = series.dropna().unique()

    # Filter negatives from the unique values
    valid_vals = sorted([v for v in unique_vals if v >= 0])

    if len(valid_vals) <= 10:  # Threshold for treating a column as categorical
        # Create the mapping
        val_map = {val: idx + 1 for idx, val in enumerate(valid_vals)}

        # Map values; values not in val_map (negatives, NaNs) become NaN
        mapped_col = series.map(val_map)

        # Fill NaN with 0 and convert to uint32
        final_tabular[col] = mapped_col.fillna(0).astype(np.uint32)

# Save the processed tabular features
final_tabular.to_csv("ukb_table.csv")

# Save the event data
np.save("ukb_event_data.npy", data)
```
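A small sketch for sanity-checking the saved artifacts:

```python
import numpy as np
import pandas as pd

events = np.load("ukb_event_data.npy")  # (n_events, 3): eid, days, label
table = pd.read_csv("ukb_table.csv", index_col='eid')
info = pd.read_csv("ukb_basic_info.csv", index_col='eid')
print(events.shape, events.dtype)       # (..., 3) uint32
print(info.columns.tolist())            # ['sex', 'date_of_assessment']
assert set(events[:, 0]).issubset(set(table.index))
```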