Implement DelphiBERT model and data preparation scripts for tabular time series analysis

- Added `model.py` containing the DelphiBERT architecture, including TabularEncoder and AutoDiscretization classes for handling tabular features.
- Introduced `prepare_data.R` for merging UK Biobank disease and other data on `eid`, selecting only the columns listed in `field_id.txt`.
- Created `prepare_data.py` to process UK Biobank data, including mapping field IDs, handling date variables, and preparing tabular features and event data for model training.
2026-01-20 23:33:30 +08:00
parent bd1ddf936a
commit f729f05190
8 changed files with 3411 additions and 0 deletions

age_encoder.py Normal file

@@ -0,0 +1,55 @@
import torch
import torch.nn as nn
class AgeSinusoidalEncoder(nn.Module):
"""
Sinusoidal encoder for age.
Args:
n_embd (int): Embedding dimension. Must be even.
"""
def __init__(self, n_embd: int):
super().__init__()
if n_embd % 2 != 0:
raise ValueError("n_embd must be even for sinusoidal encoding.")
self.n_embd = n_embd
i = torch.arange(0, self.n_embd, 2, dtype=torch.float32)
divisor = torch.pow(10000, i / self.n_embd)
self.register_buffer('divisor', divisor)
def forward(self, ages: torch.Tensor) -> torch.Tensor:
t_years = ages / 365.25  # ages are given in days
# Broadcast (B, L, 1) against (1, 1, D/2) to get (B, L, D/2)
args = t_years.unsqueeze(-1) / self.divisor.view(1, 1, -1)
# Interleave cos and sin along the last dimension
output = torch.zeros(
ages.shape[0], ages.shape[1], self.n_embd, device=ages.device)
output[:, :, 0::2] = torch.cos(args)
output[:, :, 1::2] = torch.sin(args)
return output
class AgeMLPEncoder(nn.Module):
"""
MLP encoder for age, using sinusoidal encoding as input.
Args:
n_embd (int): Embedding dimension.
"""
def __init__(self, n_embd: int):
super().__init__()
self.sin_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
self.mlp = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.GELU(),
nn.Linear(4 * n_embd, n_embd),
)
def forward(self, ages: torch.Tensor) -> torch.Tensor:
x = self.sin_encoder(ages)
output = self.mlp(x)
return output
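A quick way to sanity-check the encoder (this sketch is not part of the commit; it only assumes ages are given in days, as elsewhere in the repo):

```python
import torch
from age_encoder import AgeSinusoidalEncoder

enc = AgeSinusoidalEncoder(n_embd=8)
ages = torch.tensor([[0.0, 365.25, 3652.5]])  # (B=1, L=3), ages in days
emb = enc(ages)
print(emb.shape)  # torch.Size([1, 3, 8])
# At age 0 every cos channel is 1 and every sin channel is 0.
assert torch.allclose(emb[0, 0, 0::2], torch.ones(4))
assert torch.allclose(emb[0, 0, 1::2], torch.zeros(4))
```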

backbones.py Normal file

@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class SelfAttention(nn.Module):
"""
Multi-head self-attention mechanism.
Args:
n_embd (int): Embedding dimension.
n_head (int): Number of attention heads.
attn_pdrop (float): Attention dropout probability.
"""
def __init__(
self,
n_embd: int,
n_head: int,
attn_pdrop: float = 0.1,
):
super().__init__()
assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
self.n_head = n_head
self.head_dim = n_embd // n_head
self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=False)
self.o_proj = nn.Linear(n_embd, n_embd, bias=False)
self.attn_pdrop = attn_pdrop
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, L, D = x.shape
qkv = self.qkv_proj(x) # (B, L, 3D)
q, k, v = qkv.chunk(3, dim=-1)
def reshape_heads(t):
# (B, H, L, d)
return t.view(B, L, self.n_head, self.head_dim).transpose(1, 2)
q = reshape_heads(q)
k = reshape_heads(k)
v = reshape_heads(v)
dropout_p = self.attn_pdrop if self.training else 0.0
attn = F.scaled_dot_product_attention(
q, k, v,
dropout_p=dropout_p,
attn_mask=attn_mask,
) # (B, H, L, d)
attn = attn.transpose(1, 2).contiguous().view(B, L, D) # (B, L, D)
return self.o_proj(attn)
class Block(nn.Module):
"""
Transformer block consisting of self-attention and MLP.
Args:
n_embd (int): Embedding dimension.
n_head (int): Number of attention heads.
pdrop (float): Dropout probability.
"""
def __init__(
self,
n_embd: int,
n_head: int,
pdrop: float = 0.0,
):
super().__init__()
attn_pdrop = pdrop
self.norm_1 = nn.LayerNorm(n_embd)
self.attn = SelfAttention(
n_embd=n_embd,
n_head=n_head,
attn_pdrop=attn_pdrop,
)
self.norm_2 = nn.LayerNorm(n_embd)
self.mlp = nn.ModuleDict(dict(
c_fc=nn.Linear(n_embd, 4 * n_embd),
c_proj=nn.Linear(4 * n_embd, n_embd),
act=nn.GELU(),
dropout=nn.Dropout(pdrop),
))
m = self.mlp
self.mlpf = lambda x: m.dropout(
m.c_proj(m.act(m.c_fc(x))))
self.resid_dropout = nn.Dropout(pdrop)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Attention
h = self.norm_1(x)
h = self.attn(h, attn_mask=attn_mask)
x = x + self.resid_dropout(h)
# MLP
h = self.norm_2(x)
h = self.mlpf(h)
x = x + self.resid_dropout(h)
return x
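A minimal usage sketch (not part of the commit) for stacking these blocks; note that for a boolean mask, `F.scaled_dot_product_attention` attends where the mask is True, so the mask below must be True at real tokens:

```python
import torch
from backbones import Block

B, L, D, H = 2, 5, 32, 4
blocks = torch.nn.ModuleList([Block(n_embd=D, n_head=H) for _ in range(2)])
x = torch.randn(B, L, D)
real = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)  # True = real token
attn_mask = real.view(B, 1, 1, L)  # broadcasts over heads and query positions
for block in blocks:
    x = block(x, attn_mask=attn_mask)
print(x.shape)  # torch.Size([2, 5, 32])
```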

dataset.py Normal file

@@ -0,0 +1,240 @@
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import Any, Dict, List, Sequence, Tuple
class HealthDataset(Dataset):
"""
Dataset for health records.
Args:
data_prefix (str): Prefix for data files.
covariate_list (List[str] | None): List of covariates to include.
"""
def __init__(
self,
data_prefix: str,
max_len: int = 64,
covariate_list: List[str] | None = None,
):
event_data = np.load(f"{data_prefix}_event_data.npy")
basic_info = pd.read_csv(
f"{data_prefix}_basic_info.csv", index_col='eid')
cont_cols = []
cate_cols = []
self.cate_dims = []
self.n_cont = 0
self.n_cate = 0
if covariate_list is not None:
tabular_data = pd.read_csv(
f"{data_prefix}_table.csv", index_col='eid')[covariate_list]
tabular_data = tabular_data.convert_dtypes()
for col in tabular_data.columns:
if pd.api.types.is_float_dtype(tabular_data[col]):
cont_cols.append(col)
elif pd.api.types.is_integer_dtype(tabular_data[col]):
series = tabular_data[col]
unique_vals = series.dropna().unique()
# <= 10 category codes plus the reserved 0 "missing" code => at most 11 unique values
if len(unique_vals) > 11:
cont_cols.append(col)
else:
cate_cols.append(col)
self.cate_dims.append(int(series.max()) + 1)
self.cont_features = tabular_data[cont_cols].to_numpy(
dtype=np.float32).copy()
self.cate_features = tabular_data[cate_cols].to_numpy(
dtype=np.int64).copy()
self.n_cont = self.cont_features.shape[1]
self.n_cate = self.cate_features.shape[1]
self.cont_features = torch.from_numpy(self.cont_features)
self.cate_features = torch.from_numpy(self.cate_features)
else:
tabular_data = None
self.cont_features = None
self.cate_features = None
patient_events = defaultdict(list)
vocab_size = 0
for patient_id, time_in_days, event_code in event_data:
patient_events[patient_id].append((time_in_days, event_code))
if event_code > vocab_size:
vocab_size = event_code
self.n_event = vocab_size - 2 # Exclude padding, CLS, and checkup codes
self.n_disease = self.n_event - 1 # Exclude death code
self.max_len = max_len
self.death_token = vocab_size # death carries the largest event code
self.basic_info = basic_info.convert_dtypes()
self.patient_ids = self.basic_info.index.tolist()
self.patient_events = dict(patient_events)
for patient_id, records in self.patient_events.items():
records.sort(key=lambda x: x[0])
self._doa = self.basic_info.loc[
self.patient_ids, 'date_of_assessment'
].to_numpy(dtype=np.float32)
self._sex = self.basic_info.loc[
self.patient_ids, 'sex'
].to_numpy(dtype=np.int64)
def __len__(self) -> int:
return len(self.patient_ids)
def __getitem__(self, idx: int):
patient_id = self.patient_ids[idx]
records = list(self.patient_events.get(patient_id, []))
if self.cont_features is not None or self.cate_features is not None:
doa = float(self._doa[idx])
# insert checkup event at date of assessment
records.append((doa, 2)) # 2 is the checkup token
records.sort(key=lambda x: x[0])
n_events = len(records)
has_died = (records[-1][1] ==
self.death_token) if n_events > 0 else False
if has_died:
valid_max_split = n_events - 1
else:
valid_max_split = n_events
if valid_max_split < 1:
split_idx = 0
else:
split_idx = np.random.randint(1, valid_max_split + 1)
past_records = records[:split_idx]
future_records = records[split_idx:]
if len(past_records) > 0:
current_time = past_records[-1][0]
else:
current_time = 0
if len(past_records) > self.max_len:
past_records = past_records[-self.max_len:]
event_seq = [1] + [item[1] for item in past_records] # 1 is CLS token
time_seq = [current_time] + [item[0] for item in past_records]
event_tensor = torch.tensor(event_seq, dtype=torch.long)
time_tensor = torch.tensor(time_seq, dtype=torch.float)
labels = torch.zeros(self.n_disease, 2, dtype=torch.float32)
global_censoring_time = (
records[-1][0] - current_time) if n_events > 0 else 0.0
labels[:, 1] = global_censoring_time
loss_mask = torch.ones(self.n_disease, dtype=torch.float32)
death_time = 0.0
for t, token in past_records:
if token > 2:
disease_idx = token - 3
if 0 <= disease_idx < self.n_disease:
loss_mask[disease_idx] = 0.0
for t, token in future_records:
rel_time = t - current_time
if rel_time < 1e-6:
rel_time = 1e-6
if token == self.death_token:
death_time = rel_time
break
if token > 2:
disease_idx = token - 3
if 0 <= disease_idx < self.n_disease:
if labels[disease_idx, 0] == 0.0:
labels[disease_idx, 0] = 1.0
labels[disease_idx, 1] = rel_time
if self.cont_features is not None:
# Extract this patient's static feature vectors
cur_cont = self.cont_features[idx].to(
dtype=torch.float) # (n_cont,)
cur_cate = self.cate_features[idx].to(
dtype=torch.long) # (n_cate,)
else:
cur_cont = None
cur_cate = None
sex = int(self._sex[idx])
return {
'event_seq': event_tensor,
'time_seq': time_tensor,
'labels': labels,
'loss_mask': loss_mask,
'sex': sex,
'death_time': death_time,
'global_censoring_time': global_censoring_time,
'cont_features': cur_cont,
'cate_features': cur_cate,
}
def health_collate_fn(batch):
# 1) pad variable-length sequences
event_seqs = [item['event_seq'] for item in batch] # 1D LongTensor
time_seqs = [item['time_seq'] for item in batch] # 1D FloatTensor
event_seq_padded = pad_sequence(
event_seqs, batch_first=True, padding_value=0).long()
time_seq_padded = pad_sequence(
time_seqs, batch_first=True, padding_value=0.0).float()
attn_mask = (event_seq_padded != 0) # True for real tokens
# 2) stack fixed-size tensors
labels = torch.stack([item['labels'] for item in batch],
dim=0).float() # (B, n_disease, 2)
loss_mask = torch.stack([item['loss_mask']
# (B, n_disease)
for item in batch], dim=0).float()
# 3) scalar fields -> tensor
sex = torch.tensor([int(item['sex'])
for item in batch], dtype=torch.long) # (B,)
death_time = torch.tensor([float(item['death_time'])
for item in batch], dtype=torch.float32) # (B,)
global_censoring_time = torch.tensor(
[float(item['global_censoring_time']) for item in batch],
dtype=torch.float32
) # (B,)
# 4) optional static tabular features
cont_list = [item['cont_features'] for item in batch]
cate_list = [item['cate_features'] for item in batch]
has_cont = any(x is not None for x in cont_list)
has_cate = any(x is not None for x in cate_list)
if has_cont and any(x is None for x in cont_list):
raise ValueError("Mixed None/non-None cont_features in batch.")
if has_cate and any(x is None for x in cate_list):
raise ValueError("Mixed None/non-None cate_features in batch.")
cont_features = torch.stack(cont_list, dim=0).float(
) if has_cont else None # (B, n_cont)
cate_features = torch.stack(cate_list, dim=0).long(
) if has_cate else None # (B, n_cate)
return {
'event_seq': event_seq_padded,
'time_seq': time_seq_padded,
'attn_mask': attn_mask,
'labels': labels,
'loss_mask': loss_mask,
'sex': sex,
'death_time': death_time,
'global_censoring_time': global_censoring_time,
'cont_features': cont_features,
'cate_features': cate_features,
}
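Assuming `prepare_data.py` has produced the `ukb_*` files, the dataset wires into a standard DataLoader like this (a sketch, not part of the commit; the covariate names are placeholders):

```python
from torch.utils.data import DataLoader
from dataset import HealthDataset, health_collate_fn

ds = HealthDataset(
    data_prefix="ukb",  # expects ukb_event_data.npy, ukb_basic_info.csv, ukb_table.csv
    max_len=64,
    covariate_list=["bmi", "smoking_status"],  # hypothetical column names
)
loader = DataLoader(ds, batch_size=32, shuffle=True,
                    collate_fn=health_collate_fn)
batch = next(iter(loader))
print(batch["event_seq"].shape)  # (B, L_max) padded token IDs
print(batch["labels"].shape)     # (B, n_disease, 2): (event indicator, time)
```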

icd10_codes_mod.tsv Normal file (1129 lines; diff suppressed because it is too large)

labels.csv Normal file (1257 lines; diff suppressed because it is too large)

model.py Normal file

@@ -0,0 +1,376 @@
import numpy as np
from typing import Optional, List
from backbones import Block
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
import torch.nn.functional as F
import torch.nn as nn
import torch
class TabularEncoder(nn.Module):
"""
Encoder for tabular features (continuous and categorical).
Args:
n_embd (int): Embedding dimension.
n_cont (int): Number of continuous features.
n_cate (int): Number of categorical features.
cate_dims (List[int]): List of dimensions for each categorical feature.
n_bins (int): Number of soft bins for continuous AutoDiscretization.
"""
def __init__(
self,
n_embd: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
n_bins: int = 16,
):
super().__init__()
self.n_embd = n_embd
self.n_cont = n_cont
self.n_cate = n_cate
# Continuous feature path
# - BatchNorm on raw (NaN-filled) continuous values
# - AutoDiscretization (soft binning) per feature
if n_cont > 0:
self.cont_bn = nn.BatchNorm1d(n_cont)
self.cont_discretizer = AutoDiscretization(
n_features=n_cont,
n_bins=n_bins,
n_embd=n_embd,
)
else:
self.cont_bn = None
self.cont_discretizer = None
if n_cate > 0:
assert len(cate_dims) == n_cate, \
"Length of cate_dims must match n_cate"
self.cate_embds = nn.ModuleList([
nn.Embedding(dim, n_embd) for dim in cate_dims
])
self.cate_mask_embds = nn.ModuleList([
nn.Embedding(2, n_embd) for _ in range(n_cate)
])
else:
self.cate_embds = None
self.cate_mask_embds = None
self.cont_mask_proj = (
nn.Linear(n_cont, n_embd) if n_cont > 0 else None
)
# Fuse aggregated value + aggregated mask via MLP
self.fuse_mlp = nn.Sequential(
nn.Linear(2 * n_embd, 2 * n_embd),
nn.GELU(),
nn.Linear(2 * n_embd, n_embd),
)
self.apply(self._init_weights)
self.out_ln = nn.LayerNorm(n_embd)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
cont_features: Optional[torch.Tensor],
cate_features: Optional[torch.Tensor],
) -> torch.Tensor:
"""Encode tabular features into a per-timestep embedding.
Inputs:
cont_features: (B, L, n_cont) float tensor; NaN indicates missing.
cate_features: (B, L, n_cate) long/int tensor; 0 indicates missing/pad.
Returns:
(B, L, n_embd) encoded embedding.
"""
if self.n_cont == 0 and self.n_cate == 0:
# infer (B, L) from whichever input is not None
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError(
"TabularEncoder received no features but cannot infer (B, L)."
)
return torch.zeros(B, L, self.n_embd, device=device)
value_parts: List[torch.Tensor] = []
mask_parts: List[torch.Tensor] = []
if self.n_cont > 0 and cont_features is not None:
if cont_features.dim() != 3:
raise ValueError(
"cont_features must be 3D tensor (B, L, n_cont)")
B, L, D_cont = cont_features.shape
if D_cont != self.n_cont:
raise ValueError(
f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")
# Missingness mask: 1 for valid, 0 for missing
cont_mask = (~torch.isnan(cont_features)).float() # (B, L, n_cont)
# BatchNorm cannot handle NaNs; fill missing with 0 before BN.
cont_filled = torch.nan_to_num(
cont_features, nan=0.0) # (B, L, n_cont)
# Apply BN over the feature dimension: (B, L, C) -> (B*L, C) -> (B, L, C)
cont_flat = cont_filled.reshape(-1, self.n_cont)
cont_norm_flat = self.cont_bn(cont_flat) # (B*L, n_cont)
# Soft-binning per feature: (B*L, n_cont) -> (B*L, n_cont, n_embd)
cont_emb_flat = self.cont_discretizer(cont_norm_flat)
cont_emb = cont_emb_flat.view(B, L, self.n_cont, self.n_embd)
# Mask-out missing continuous features before aggregating across features
# (B, L, n_cont, n_embd)
cont_emb = cont_emb * cont_mask.unsqueeze(-1)
denom = cont_mask.sum(
dim=-1, keepdim=True).clamp(min=1.0) # (B, L, 1)
h_cont_value = cont_emb.sum(dim=2) / denom # (B, L, n_embd)
value_parts.append(h_cont_value)
# Explicit continuous mask embedding (fused later)
if self.cont_mask_proj is not None:
h_cont_mask = self.cont_mask_proj(cont_mask) # (B, L, n_embd)
mask_parts.append(h_cont_mask)
if self.n_cate > 0 and cate_features is not None:
if cate_features.dim() != 3:
raise ValueError(
"cate_features must be 3D tensor (B, L, n_cate)")
B, L, D_cate = cate_features.shape
if D_cate != self.n_cate:
raise ValueError(
f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}")
for i in range(self.n_cate):
cate_feat = cate_features[:, :, i]
cate_embd = self.cate_embds[i]
cate_mask_embd = self.cate_mask_embds[i]
cate_value = cate_embd(
torch.clamp(cate_feat, min=0))
cate_mask = (cate_feat > 0).long()
cate_mask_value = cate_mask_embd(cate_mask)
value_parts.append(cate_value)
mask_parts.append(cate_mask_value)
if not value_parts:
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError("No features provided to TabularEncoder.")
return torch.zeros(B, L, self.n_embd, device=device)
# Aggregate across feature groups (continuous block counts as one part;
# each categorical feature counts as one part).
h_value = torch.stack(value_parts, dim=0).mean(dim=0) # (B, L, n_embd)
if mask_parts:
h_mask = torch.stack(mask_parts, dim=0).mean(
dim=0) # (B, L, n_embd)
else:
h_mask = torch.zeros_like(h_value)
# Fuse by concatenation + MLP projection
h_fused = torch.cat([h_value, h_mask], dim=-1) # (B, L, 2*n_embd)
h_out = self.fuse_mlp(h_fused) # (B, L, n_embd)
h_out = self.out_ln(h_out)
return h_out
class AutoDiscretization(nn.Module):
"""AutoDiscretization / soft-binning for continuous tabular scalars.
For each feature scalar $x$, compute a soft assignment over `n_bins`:
p = softmax(x * w + b)
Then compute the embedding as a weighted sum of learnable bin embeddings:
emb = sum_k p_k * E_k
Shapes:
Input: (N, n_features)
Output: (N, n_features, n_embd)
"""
def __init__(self, n_features: int, n_bins: int, n_embd: int):
super().__init__()
if n_features <= 0:
raise ValueError("n_features must be > 0")
if n_bins <= 1:
raise ValueError("n_bins must be > 1")
if n_embd <= 0:
raise ValueError("n_embd must be > 0")
self.n_features = n_features
self.n_bins = n_bins
self.n_embd = n_embd
# Per-feature, per-bin affine transform to produce logits
self.weight = nn.Parameter(torch.empty(n_features, n_bins))
self.bias = nn.Parameter(torch.empty(n_features, n_bins))
# Learnable embeddings for each (feature, bin)
self.bin_emb = nn.Parameter(torch.empty(n_features, n_bins, n_embd))
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, mean=0.0, std=0.02)
nn.init.zeros_(self.bias)
nn.init.normal_(self.bin_emb, mean=0.0, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dim() != 2:
raise ValueError(
"AutoDiscretization expects input of shape (N, n_features)")
if x.size(1) != self.n_features:
raise ValueError(
f"Expected x.size(1) == {self.n_features}, got {x.size(1)}"
)
# logits: (N, n_features, n_bins)
logits = x.unsqueeze(-1) * self.weight.unsqueeze(0) + \
self.bias.unsqueeze(0)
probs = torch.softmax(logits, dim=-1)
# Weighted sum over bins -> (N, n_features, n_embd)
emb = (probs.unsqueeze(-1) * self.bin_emb.unsqueeze(0)).sum(dim=-2)
return emb
class DelphiBERT(nn.Module):
"""
DelphiBERT model for tabular time series data.
Args:
n_embd (int): Embedding dimension.
n_head (int): Number of attention heads.
n_layer (int): Number of transformer blocks.
pdrop (float): Dropout probability.
"""
def __init__(
self,
n_disease: int,
n_embd: int,
n_head: int,
n_layer: int,
n_cont: int = 0,
n_cate: int = 0,
cate_dims: Optional[List[int]] = None,
age_encoder_type: str = 'sinusoidal',
pdrop: float = 0.0,
):
super().__init__()
if n_cont > 0 or n_cate > 0:
if cate_dims is None:
raise ValueError(
"cate_dims must be provided if n_cate > 0"
)
self.tabular_encoder = TabularEncoder(
n_embd=n_embd,
n_cont=n_cont,
n_cate=n_cate,
cate_dims=cate_dims,
)
else:
self.tabular_encoder = None
self.vocab_size = n_disease + 4
self.n_disease = n_disease
self.n_embd = n_embd
self.n_head = n_head
self.token_embedding = nn.Embedding(
self.vocab_size, n_embd, padding_idx=0)
if age_encoder_type == 'sinusoidal':
self.age_encoder = AgeSinusoidalEncoder(n_embd)
elif age_encoder_type == 'mlp':
self.age_encoder = AgeMLPEncoder(n_embd)
else:
raise ValueError(
f"Unsupported age_encoder_type: {age_encoder_type}"
)
self.sex_embedding = nn.Embedding(2, n_embd)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
def forward(
self,
event_seq: torch.Tensor,
time_seq: torch.Tensor,
sex: torch.Tensor,
cont_seq: Optional[torch.Tensor] = None,
cate_seq: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass of DelphiBERT.
Inputs:
event_seq: (B, L) long tensor of token IDs.
time_seq: (B, L) float tensor of ages/times.
sex: (B,) long tensor of sex
cont_seq: (B, Lc, n_cont) float tensor of continuous features.
cate_seq: (B, Lc, n_cate) long tensor of categorical features.
Returns:
(B, n_embd) CLS embedding (the output at the first position).
"""
B, L = event_seq.shape
token_emb = self.token_embedding(event_seq) # (B, L, n_embd)
age_emb = self.age_encoder(time_seq) # (B, L, n_embd)
sex_emb = self.sex_embedding(sex.unsqueeze(-1)) # (B, 1, n_embd), broadcasts over L
if self.tabular_encoder is not None and cont_seq is not None and cate_seq is not None:
tabular_emb = self.tabular_encoder(
cont_seq, cate_seq) # (B, L, n_embd)
mask = (event_seq == 2)
Lc = tabular_emb.size(1)
D = tabular_emb.size(2)
occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
tab_idx = tab_idx.masked_fill(~mask, 0) # (B, L)
tab_inject = tabular_emb.gather(
dim=1,
index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
) # (B, L, n_embd)
final_embds = torch.where(
mask.unsqueeze(-1), tab_inject, token_emb)
h = final_embds + age_emb + sex_emb
else:
h = token_emb + age_emb + sex_emb
# For a boolean mask, scaled_dot_product_attention attends where the mask
# is True, so it must be True at real tokens, not at padding.
attn_mask = (event_seq != 0).view(B, 1, 1, L) # (B, 1, 1, L)
for block in self.blocks:
h = block(h, attn_mask=attn_mask)
h = self.ln_f(h)
cls_output = h[:, 0, :]
return cls_output
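A forward-pass sketch on dummy data (not part of the commit), using the token layout from `dataset.py` (0 = padding, 1 = CLS, 2 = checkup, 3+ = disease) with one checkup event per patient carrying the tabular features:

```python
import torch
from model import DelphiBERT

n_disease, B, L = 10, 2, 6
model = DelphiBERT(n_disease=n_disease, n_embd=64, n_head=4, n_layer=2,
                   n_cont=2, n_cate=1, cate_dims=[5])
event_seq = torch.randint(3, n_disease + 3, (B, L))  # disease tokens
event_seq[:, 0] = 1                          # CLS token first
event_seq[:, 2] = 2                          # one checkup token per sequence
event_seq[0, -1] = 0                         # pad the first sample
time_seq = torch.rand(B, L) * 365.25 * 70    # ages in days
sex = torch.randint(0, 2, (B,))
cont_seq = torch.randn(B, 1, 2)              # (B, Lc=1, n_cont)
cate_seq = torch.randint(0, 5, (B, 1, 1))    # (B, Lc=1, n_cate); 0 = missing
cls_emb = model(event_seq, time_seq, sex, cont_seq, cate_seq)
print(cls_emb.shape)  # torch.Size([2, 64]), the CLS embedding
```

The cumsum/gather indexing in `forward` routes the k-th row of `tabular_emb` to the k-th checkup position of each sequence, so sequences without a checkup token are left unchanged.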

prepare_data.R Normal file

@@ -0,0 +1,26 @@
library(data.table)
setDTthreads(40)
library(readr)
field_id <- read.csv("field_id.txt", header = FALSE)
uid <- field_id$V1
big_path <- "/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HHdata_221103_0512.csv"
header_dt <- fread(big_path, nrows = 0) # Read 0 rows => only column names
all_names <- names(header_dt)
keep_names <- intersect(all_names, uid)
ukb_disease <- fread(big_path,
select = keep_names,
showProgress = TRUE)
# Reuse the same field list (uid) for the second file
big_path <- "/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HH_data_220812_0512.csv"
header_dt <- fread(big_path, nrows = 0) # Read 0 rows => only column names
all_names <- names(header_dt)
keep_names <- intersect(all_names, uid)
ukb_others <- fread(big_path,
select = keep_names,
showProgress = TRUE)
# merge disease and other data by "eid"
ukb_data <- merge(ukb_disease, ukb_others, by = "eid", all = TRUE)
fwrite(ukb_data, "ukb_data.csv")
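For reference, a rough pandas equivalent of the same select-and-merge step (a sketch, not part of the commit; it assumes `field_id.txt` includes `eid`, which the merge needs):

```python
import pandas as pd

uid = set(pd.read_csv("field_id.txt", header=None)[0])

def load_selected(path: str) -> pd.DataFrame:
    header = pd.read_csv(path, nrows=0).columns  # column names only
    keep = [c for c in header if c in uid]
    return pd.read_csv(path, usecols=keep, low_memory=False)

ukb_disease = load_selected("/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HHdata_221103_0512.csv")
ukb_others = load_selected("/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HH_data_220812_0512.csv")
ukb_data = ukb_disease.merge(ukb_others, on="eid", how="outer")  # all = TRUE
ukb_data.to_csv("ukb_data.csv", index=False)
```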

prepare_data.py Normal file

@@ -0,0 +1,216 @@
import pandas as pd # Pandas for data manipulation
import tqdm # Progress bar for chunk processing
import numpy as np # Numerical operations
# CSV mapping field IDs to human-readable names
field_map_file = "field_ids_enriched.csv"
# Map original field ID -> new column name
field_dict = {}
tabular_fields = [] # List of tabular feature column names
with open(field_map_file, "r", encoding="utf-8") as f: # Open the field mapping file
next(f) # skip header line
for line in f: # Iterate through lines
parts = line.strip().split(",") # Split by CSV commas
if len(parts) >= 3: # Need at least the id (column 0) and name (column 2) fields
# Original field identifier (e.g., "34-0.0")
field_id = parts[0]
field_name = parts[2] # Human-readable column name
field_dict[field_id] = field_name # Record the mapping
# Track as a potential tabular feature
tabular_fields.append(field_name)
# Exclude raw date parts and target columns
exclude_fields = ['year', 'month', 'Death', 'age_at_assessment']
tabular_fields = [
# Filter out excluded columns
field for field in tabular_fields if field not in exclude_fields]
# TSV mapping field IDs to ICD10-related date columns
field_to_icd_map = "icd10_codes_mod.tsv"
# Date-like variables to be converted to offsets
date_vars = []
with open(field_to_icd_map, "r", encoding="utf-8") as f: # Open ICD10 mapping
for line in f: # Iterate each mapping row
parts = line.strip().split() # Split on whitespace for TSV
if len(parts) >= 6: # Guard against malformed lines
# Map field ID to the date column name
field_dict[parts[0]] = parts[5]
date_vars.append(parts[5]) # Track date column names in order
for j in range(17): # Map up to 17 cancer entry slots (dates and types)
# Cancer diagnosis date slot j
field_dict[f'40005-{j}.0'] = f'cancer_date_{j}'
field_dict[f'40006-{j}.0'] = f'cancer_type_{j}' # Cancer type/code slot j
# Number of ICD-related date columns before adding extras
len_icd = len(date_vars)
date_vars.extend(['Death', 'date_of_assessment'] + # Add outcome date and assessment date
# Add cancer date columns
[f'cancer_date_{j}' for j in range(17)])
labels_file = "labels.csv" # File listing label codes
label_dict = {} # Map code string -> integer label id
with open(labels_file, "r", encoding="utf-8") as f: # Open labels file
for idx, line in enumerate(f): # Enumerate to assign incremental label IDs
parts = line.strip().split(' ') # Split by space
if parts and parts[0]: # Guard against empty lines
# Map code to index (0 for padding, 1 for CLS, 2 for checkup reserved)
label_dict[parts[0]] = idx + 3
event_list = [] # Accumulator for event arrays across chunks
tabular_list = [] # Accumulator for tabular feature DataFrames across chunks
ukb_iterator = pd.read_csv( # Stream UK Biobank data in chunks
"ukb_data.csv",
sep=',',
chunksize=10000, # Stream file in manageable chunks to reduce memory footprint
# First column (participant ID) becomes DataFrame index
index_col=0,
low_memory=False # Disable type inference optimization for consistent dtypes
)
# Iterate chunks with progress
for ukb_chunk in tqdm.tqdm(ukb_iterator, desc="Processing UK Biobank data"):
# Rename columns to friendly names
ukb_chunk = ukb_chunk.rename(columns=field_dict)
# Require sex to be present
ukb_chunk.dropna(subset=['sex'], inplace=True)
# Construct date of birth from year and month (day fixed to 1)
ukb_chunk['day'] = 1
ukb_chunk['dob'] = pd.to_datetime(
# Guard against malformed dates
ukb_chunk[['year', 'month', 'day']], errors='coerce'
)
del ukb_chunk['day']
# Use only date variables that actually exist in the current chunk
present_date_vars = [c for c in date_vars if c in ukb_chunk.columns]
# Convert date-like columns to datetime and compute day offsets from dob
if present_date_vars:
date_cols = ukb_chunk[present_date_vars].apply(
pd.to_datetime, format="%Y-%m-%d", errors='coerce' # Parse dates safely
)
date_cols_days = date_cols.sub(
ukb_chunk['dob'], axis=0) # Timedelta relative to dob
ukb_chunk[present_date_vars] = date_cols_days.apply(
lambda x: x.dt.days) # Store days since dob
# Append tabular features (use only columns that exist)
present_tabular_fields = [
c for c in tabular_fields if c in ukb_chunk.columns]
tabular_list.append(ukb_chunk[present_tabular_fields].copy())
# Process disease events from ICD10-related date columns
# ICD date columns plus 'Death' (assumes all of them are present in the chunk)
icd10_cols = present_date_vars[:len_icd + 1]
# Melt to long form: participant id, event code (column name), and days offset
melted_df = ukb_chunk.reset_index().melt(
id_vars=['eid'],
value_vars=icd10_cols,
var_name='event_code',
value_name='days',
)
# Require non-missing day offsets
melted_df.dropna(subset=['days'], inplace=True)
if not melted_df.empty:
melted_df['label'] = melted_df['event_code'].map(
label_dict) # Map event code to numeric label
# Drop events whose code has no label mapping before the int cast
melted_df.dropna(subset=['label'], inplace=True)
if not melted_df.empty:
event_list.append(
melted_df[['eid', 'days', 'label']]
.astype(int) # Safe now since label and days are non-null
.to_numpy()
)
# Optimized cancer processing without wide_to_long
cancer_frames = []
for j in range(17):
d_col = f'cancer_date_{j}'
t_col = f'cancer_type_{j}'
if d_col in ukb_chunk.columns and t_col in ukb_chunk.columns:
# Filter rows where both date and type are present
mask = ukb_chunk[d_col].notna() & ukb_chunk[t_col].notna()
if mask.any():
subset_idx = ukb_chunk.index[mask]
subset_days = ukb_chunk.loc[mask, d_col]
subset_type = ukb_chunk.loc[mask, t_col]
# Map cancer type to label
# Use first 3 chars
cancer_codes = subset_type.str.slice(0, 3)
labels = cancer_codes.map(label_dict)
# Filter valid labels
valid_label_mask = labels.notna()
if valid_label_mask.any():
# Create array: eid, days, label
# Ensure types are correct for numpy
c_eids = subset_idx[valid_label_mask].values
c_days = subset_days[valid_label_mask].values
c_labels = labels[valid_label_mask].values
# Stack
chunk_cancer_data = np.column_stack(
(c_eids, c_days, c_labels))
cancer_frames.append(chunk_cancer_data)
if cancer_frames:
event_list.append(np.vstack(cancer_frames))
# Combine tabular chunks
final_tabular = pd.concat(tabular_list, axis=0, ignore_index=False)
final_tabular.index.name = 'eid' # Ensure index named consistently
data = np.vstack(event_list) # Stack all event arrays into one
# Sort by participant then day
data = data[np.lexsort((data[:, 1], data[:, 0]))]
# Keep only events with non-negative day offsets
data = data[data[:, 1] >= 0]
# Remove duplicate (participant_id, label) pairs keeping first occurrence.
data = pd.DataFrame(data).drop_duplicates([0, 2]).values
# Store compactly using unsigned 32-bit integers
data = data.astype(np.uint32)
# Select eid in both data and tabular
valid_eids = np.intersect1d(data[:, 0], final_tabular.index)
data = data[np.isin(data[:, 0], valid_eids)]
final_tabular = final_tabular.loc[valid_eids]
final_tabular = final_tabular.convert_dtypes()
# Save [eid, sex, date_of_assessment] for basic info
basic_info = final_tabular[['sex', 'date_of_assessment']]
basic_info.to_csv("ukb_basic_info.csv")
# Drop sex and date_of_assessment from tabular features
final_tabular = final_tabular.drop(columns=['sex', 'date_of_assessment'])
# Process categorical columns in tabular features
# Treat integer columns with few unique values as categorical. For each such
# column: count the valid (non-NaN, non-negative) unique values as C, remap
# them to [1..C], and map NaN/negative values to 0.
for col in final_tabular.select_dtypes(include=['Int64', 'int64']).columns:
# Get unique values efficiently
series = final_tabular[col]
unique_vals = series.dropna().unique()
# Filter negatives from unique values
valid_vals = sorted([v for v in unique_vals if v >= 0])
if len(valid_vals) <= 10: # Threshold for categorical
# Create mapping
val_map = {val: idx + 1 for idx, val in enumerate(valid_vals)}
# Map values. Values not in val_map (negatives, NaNs) become NaN
mapped_col = series.map(val_map)
# Fill NaN with 0 and convert to uint32
final_tabular[col] = mapped_col.fillna(0).astype(np.uint32)
# Save processed tabular features
final_tabular.to_csv("ukb_table.csv")
# Save event data
np.save("ukb_event_data.npy", data)
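A toy check (not part of the commit) of the categorical remapping above: valid values are renumbered 1..C in sorted order, while NaN and negative codes collapse to 0:

```python
import numpy as np
import pandas as pd

s = pd.Series([3, -1, 7, None, 3]).convert_dtypes()  # nullable Int64
valid_vals = sorted(v for v in s.dropna().unique() if v >= 0)
val_map = {val: idx + 1 for idx, val in enumerate(valid_vals)}  # {3: 1, 7: 2}
print(s.map(val_map).fillna(0).astype(np.uint32).tolist())
# -> [1, 0, 2, 0, 1]
```

This also matches the `> 11` unique-value threshold in `dataset.py`: up to 10 categories plus the reserved 0 code.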