- Added `model.py` containing the DelphiBERT architecture, including the `TabularEncoder` and `AutoDiscretization` classes for handling tabular features.
- Introduced `prepare_data.R` for merging disease and other data from UK Biobank, ensuring proper column selection and data integrity.
- Created `prepare_data.py` to process UK Biobank data, including mapping field IDs, handling date variables, and preparing tabular features and event data for model training.
`model.py` (377 lines, 13 KiB, Python):
```python
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block

class TabularEncoder(nn.Module):
    """
    Encoder for tabular features (continuous and categorical).

    Args:
        n_embd (int): Embedding dimension.
        n_cont (int): Number of continuous features.
        n_cate (int): Number of categorical features.
        cate_dims (List[int]): List of dimensions for each categorical feature.
        n_bins (int): Number of soft bins for continuous AutoDiscretization.
    """

    def __init__(
        self,
        n_embd: int,
        n_cont: int,
        n_cate: int,
        cate_dims: List[int],
        n_bins: int = 16,
    ):
        super().__init__()
        self.n_embd = n_embd
        self.n_cont = n_cont
        self.n_cate = n_cate

        # Continuous feature path:
        # - BatchNorm on raw (NaN-filled) continuous values
        # - AutoDiscretization (soft binning) per feature
        if n_cont > 0:
            self.cont_bn = nn.BatchNorm1d(n_cont)
            self.cont_discretizer = AutoDiscretization(
                n_features=n_cont,
                n_bins=n_bins,
                n_embd=n_embd,
            )
        else:
            self.cont_bn = None
            self.cont_discretizer = None

        if n_cate > 0:
            assert len(cate_dims) == n_cate, \
                "Length of cate_dims must match n_cate"
            self.cate_embds = nn.ModuleList([
                nn.Embedding(dim, n_embd) for dim in cate_dims
            ])
            self.cate_mask_embds = nn.ModuleList([
                nn.Embedding(2, n_embd) for _ in range(n_cate)
            ])
        else:
            self.cate_embds = None
            self.cate_mask_embds = None

        self.cont_mask_proj = (
            nn.Linear(n_cont, n_embd) if n_cont > 0 else None
        )

        # Fuse aggregated value + aggregated mask via MLP
        self.fuse_mlp = nn.Sequential(
            nn.Linear(2 * n_embd, 2 * n_embd),
            nn.GELU(),
            nn.Linear(2 * n_embd, n_embd),
        )

        self.apply(self._init_weights)
        self.out_ln = nn.LayerNorm(n_embd)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        cont_features: Optional[torch.Tensor],
        cate_features: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Encode tabular features into a per-timestep embedding.

        Inputs:
            cont_features: (B, L, n_cont) float tensor; NaN indicates missing.
            cate_features: (B, L, n_cate) long/int tensor; 0 indicates missing/pad.

        Returns:
            (B, L, n_embd) encoded embedding.
        """
        if self.n_cont == 0 and self.n_cate == 0:
            # Infer (B, L) from whichever input is not None
            if cont_features is not None:
                B, L = cont_features.shape[:2]
                device = cont_features.device
            elif cate_features is not None:
                B, L = cate_features.shape[:2]
                device = cate_features.device
            else:
                raise ValueError(
                    "TabularEncoder received no features but cannot infer (B, L)."
                )
            return torch.zeros(B, L, self.n_embd, device=device)

        value_parts: List[torch.Tensor] = []
        mask_parts: List[torch.Tensor] = []

        if self.n_cont > 0 and cont_features is not None:
            if cont_features.dim() != 3:
                raise ValueError(
                    "cont_features must be 3D tensor (B, L, n_cont)")
            B, L, D_cont = cont_features.shape
            if D_cont != self.n_cont:
                raise ValueError(
                    f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")

            # Missingness mask: 1 for valid, 0 for missing
            cont_mask = (~torch.isnan(cont_features)).float()  # (B, L, n_cont)

            # BatchNorm cannot handle NaNs; fill missing with 0 before BN.
            cont_filled = torch.nan_to_num(cont_features, nan=0.0)  # (B, L, n_cont)

            # Apply BN over the feature dimension: (B, L, C) -> (B*L, C) -> (B, L, C)
            cont_flat = cont_filled.reshape(-1, self.n_cont)
            cont_norm_flat = self.cont_bn(cont_flat)  # (B*L, n_cont)

            # Soft-binning per feature: (B*L, n_cont) -> (B*L, n_cont, n_embd)
            cont_emb_flat = self.cont_discretizer(cont_norm_flat)
            cont_emb = cont_emb_flat.view(B, L, self.n_cont, self.n_embd)

            # Mask out missing continuous features before aggregating across features
            cont_emb = cont_emb * cont_mask.unsqueeze(-1)  # (B, L, n_cont, n_embd)
            denom = cont_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)  # (B, L, 1)
            h_cont_value = cont_emb.sum(dim=2) / denom  # (B, L, n_embd)
            value_parts.append(h_cont_value)

            # Explicit continuous mask embedding (fused later)
            if self.cont_mask_proj is not None:
                h_cont_mask = self.cont_mask_proj(cont_mask)  # (B, L, n_embd)
                mask_parts.append(h_cont_mask)

        if self.n_cate > 0 and cate_features is not None:
            if cate_features.dim() != 3:
                raise ValueError(
                    "cate_features must be 3D tensor (B, L, n_cate)")
            B, L, D_cate = cate_features.shape
            if D_cate != self.n_cate:
                raise ValueError(
                    f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}")

            for i in range(self.n_cate):
                cate_feat = cate_features[:, :, i]
                cate_embd = self.cate_embds[i]
                cate_mask_embd = self.cate_mask_embds[i]

                cate_value = cate_embd(torch.clamp(cate_feat, min=0))
                cate_mask = (cate_feat > 0).long()
                cate_mask_value = cate_mask_embd(cate_mask)

                value_parts.append(cate_value)
                mask_parts.append(cate_mask_value)

        if not value_parts:
            if cont_features is not None:
                B, L = cont_features.shape[:2]
                device = cont_features.device
            elif cate_features is not None:
                B, L = cate_features.shape[:2]
                device = cate_features.device
            else:
                raise ValueError("No features provided to TabularEncoder.")
            return torch.zeros(B, L, self.n_embd, device=device)

        # Aggregate across feature groups (continuous block counts as one part;
        # each categorical feature counts as one part).
        h_value = torch.stack(value_parts, dim=0).mean(dim=0)  # (B, L, n_embd)

        if mask_parts:
            h_mask = torch.stack(mask_parts, dim=0).mean(dim=0)  # (B, L, n_embd)
        else:
            h_mask = torch.zeros_like(h_value)

        # Fuse by concatenation + MLP projection
        h_fused = torch.cat([h_value, h_mask], dim=-1)  # (B, L, 2*n_embd)
        h_out = self.fuse_mlp(h_fused)  # (B, L, n_embd)
        h_out = self.out_ln(h_out)
        return h_out

class AutoDiscretization(nn.Module):
    """AutoDiscretization / soft-binning for continuous tabular scalars.

    For each feature scalar x, compute a soft assignment over `n_bins`:
        p = softmax(x * w + b)
    Then compute the embedding as a weighted sum of learnable bin embeddings:
        emb = sum_k p_k * E_k

    Shapes:
        Input: (N, n_features)
        Output: (N, n_features, n_embd)
    """

    def __init__(self, n_features: int, n_bins: int, n_embd: int):
        super().__init__()
        if n_features <= 0:
            raise ValueError("n_features must be > 0")
        if n_bins <= 1:
            raise ValueError("n_bins must be > 1")
        if n_embd <= 0:
            raise ValueError("n_embd must be > 0")

        self.n_features = n_features
        self.n_bins = n_bins
        self.n_embd = n_embd

        # Per-feature, per-bin affine transform to produce logits
        self.weight = nn.Parameter(torch.empty(n_features, n_bins))
        self.bias = nn.Parameter(torch.empty(n_features, n_bins))

        # Learnable embeddings for each (feature, bin)
        self.bin_emb = nn.Parameter(torch.empty(n_features, n_bins, n_embd))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.normal_(self.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.bin_emb, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 2:
            raise ValueError(
                "AutoDiscretization expects input of shape (N, n_features)")
        if x.size(1) != self.n_features:
            raise ValueError(
                f"Expected x.size(1) == {self.n_features}, got {x.size(1)}"
            )

        # logits: (N, n_features, n_bins)
        logits = x.unsqueeze(-1) * self.weight.unsqueeze(0) + self.bias.unsqueeze(0)
        probs = torch.softmax(logits, dim=-1)

        # Weighted sum over bins -> (N, n_features, n_embd)
        emb = (probs.unsqueeze(-1) * self.bin_emb.unsqueeze(0)).sum(dim=-2)
        return emb

class DelphiBERT(nn.Module):
    """
    DelphiBERT model for tabular time series data.

    Args:
        n_disease (int): Number of disease tokens; total vocabulary size is n_disease + 4.
        n_embd (int): Embedding dimension.
        n_head (int): Number of attention heads.
        n_layer (int): Number of transformer blocks.
        n_cont (int): Number of continuous tabular features.
        n_cate (int): Number of categorical tabular features.
        cate_dims (Optional[List[int]]): Dimensions of each categorical feature.
        age_encoder_type (str): Age encoder to use, 'sinusoidal' or 'mlp'.
        pdrop (float): Dropout probability.
    """

    def __init__(
        self,
        n_disease: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_cont: int = 0,
        n_cate: int = 0,
        cate_dims: Optional[List[int]] = None,
        age_encoder_type: str = 'sinusoidal',
        pdrop: float = 0.0,
    ):
        super().__init__()
        if n_cont > 0 or n_cate > 0:
            if n_cate > 0 and cate_dims is None:
                raise ValueError(
                    "cate_dims must be provided if n_cate > 0"
                )
            self.tabular_encoder = TabularEncoder(
                n_embd=n_embd,
                n_cont=n_cont,
                n_cate=n_cate,
                cate_dims=cate_dims if cate_dims is not None else [],
            )
        else:
            self.tabular_encoder = None
        self.vocab_size = n_disease + 4
        self.n_disease = n_disease
        self.n_embd = n_embd
        self.n_head = n_head

        self.token_embedding = nn.Embedding(
            self.vocab_size, n_embd, padding_idx=0)
        if age_encoder_type == 'sinusoidal':
            self.age_encoder = AgeSinusoidalEncoder(n_embd)
        elif age_encoder_type == 'mlp':
            self.age_encoder = AgeMLPEncoder(n_embd)
        else:
            raise ValueError(
                f"Unsupported age_encoder_type: {age_encoder_type}"
            )

        self.sex_embedding = nn.Embedding(2, n_embd)

        self.blocks = nn.ModuleList([
            Block(
                n_embd=n_embd,
                n_head=n_head,
                pdrop=pdrop,
            ) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)

    def forward(
        self,
        event_seq: torch.Tensor,
        time_seq: torch.Tensor,
        sex: torch.Tensor,
        cont_seq: Optional[torch.Tensor] = None,
        cate_seq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass of DelphiBERT.

        Inputs:
            event_seq: (B, L) long tensor of token IDs.
            time_seq: (B, L) float tensor of ages/times.
            sex: (B,) long tensor of sex (0 or 1).
            cont_seq: (B, Lc, n_cont) float tensor of continuous features.
            cate_seq: (B, Lc, n_cate) long tensor of categorical features.

        Returns:
            (B, n_embd) embedding of the first ([CLS]) position.
        """
        B, L = event_seq.shape
        token_emb = self.token_embedding(event_seq)  # (B, L, n_embd)
        age_emb = self.age_encoder(time_seq)  # (B, L, n_embd)
        sex_emb = self.sex_embedding(sex.unsqueeze(-1))  # (B, 1, n_embd)
        if self.tabular_encoder is not None and cont_seq is not None and cate_seq is not None:
            tabular_emb = self.tabular_encoder(cont_seq, cate_seq)  # (B, Lc, n_embd)

            # Inject tabular embeddings at positions with token ID 2:
            # the k-th such position in each row receives the k-th tabular embedding.
            mask = (event_seq == 2)
            Lc = tabular_emb.size(1)
            D = tabular_emb.size(2)
            occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
            tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
            tab_idx = tab_idx.masked_fill(~mask, 0)  # (B, L)
            tab_inject = tabular_emb.gather(
                dim=1,
                index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
            )  # (B, L, n_embd)
            final_embds = torch.where(
                mask.unsqueeze(-1), tab_inject, token_emb)
            h = final_embds + age_emb + sex_emb
        else:
            h = token_emb + age_emb + sex_emb

        is_padding = (event_seq == 0)
        attn_mask = is_padding.view(B, 1, 1, L)  # (B, 1, 1, L), True at padding

        for block in self.blocks:
            h = block(h, attn_mask=attn_mask)
        h = self.ln_f(h)
        cls_output = h[:, 0, :]
        return cls_output
```
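
As a quick smoke test of the forward pass, the sketch below builds a toy batch and runs it through `DelphiBERT`. It assumes `age_encoder.py` and `backbones.py` from this repo are importable with the interfaces used in `model.py` (`AgeSinusoidalEncoder(n_embd)` returning `(B, L, n_embd)` and `Block(n_embd, n_head, pdrop)` preserving shape); all sizes and values are made up for illustration only.

```python
import torch

from model import DelphiBERT

# Hypothetical sizes; the real configuration lives in the training setup.
B, L, Lc = 4, 32, 3          # batch, event-sequence length, number of tabular records
n_cont, n_cate = 5, 2        # tabular feature counts
cate_dims = [7, 13]          # cardinality of each categorical feature (0 = missing)

model = DelphiBERT(
    n_disease=1000,
    n_embd=128,
    n_head=4,
    n_layer=2,
    n_cont=n_cont,
    n_cate=n_cate,
    cate_dims=cate_dims,
)

event_seq = torch.randint(3, 1004, (B, L))   # token IDs; 0 = padding, 2 = tabular slot
event_seq[:, 5] = 2                          # mark one tabular slot per row
time_seq = torch.rand(B, L) * 80.0           # ages in years
sex = torch.randint(0, 2, (B,))

cont_seq = torch.randn(B, Lc, n_cont)
cont_seq[0, 0, 0] = float("nan")             # NaN marks a missing continuous value
cate_seq = torch.randint(0, 7, (B, Lc, n_cate))  # 0 marks a missing categorical value

cls_emb = model(event_seq, time_seq, sex, cont_seq, cate_seq)
print(cls_emb.shape)  # torch.Size([4, 128])
```

The token IDs used here simply follow the checks in `forward`: positions equal to 2 receive the tabular embeddings, and positions equal to 0 are treated as padding in the attention mask.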