377 lines
13 KiB
Python
377 lines
13 KiB
Python
|
|
import numpy as np
|
||
|
|
from typing import Optional, List
|
||
|
|
from backbones import Block
|
||
|
|
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
|
||
|
|
import torch.nn.functional as F
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch
|
||
|
|
|
||
|
|
|
||
|
|
class TabularEncoder(nn.Module):
    """
    Encoder for tabular features (continuous and categorical).

    Continuous features are batch-normalized (missing values filled with 0),
    soft-binned per feature with ``AutoDiscretization`` and averaged over the
    observed features of each timestep. Each categorical feature has its own
    value embedding table plus a 2-row "observed / missing" embedding table.
    The value embeddings and the missingness embeddings are each averaged over
    feature groups (the continuous block counts as one group, every
    categorical feature as one group), concatenated, fused by an MLP and
    layer-normalized.

    Args:
        n_embd (int): Embedding dimension.
        n_cont (int): Number of continuous features.
        n_cate (int): Number of categorical features.
        cate_dims (List[int]): Number of embedding rows for each categorical
            feature (index 0 is used for missing/pad). May be ``None`` or
            empty when ``n_cate == 0``.
        n_bins (int): Number of soft bins for continuous AutoDiscretization.

    Raises:
        ValueError: If ``n_cate > 0`` and ``cate_dims`` is missing or its
            length does not equal ``n_cate``.
    """

    def __init__(
        self,
        n_embd: int,
        n_cont: int,
        n_cate: int,
        cate_dims: Optional[List[int]],
        n_bins: int = 16,
    ):
        super().__init__()
        self.n_embd = n_embd
        self.n_cont = n_cont
        self.n_cate = n_cate

        # Continuous feature path
        # - BatchNorm on raw (missing -> 0 filled) continuous values
        # - AutoDiscretization (soft binning) per feature
        if n_cont > 0:
            self.cont_bn = nn.BatchNorm1d(n_cont)
            self.cont_discretizer = AutoDiscretization(
                n_features=n_cont,
                n_bins=n_bins,
                n_embd=n_embd,
            )
        else:
            self.cont_bn = None
            self.cont_discretizer = None

        # Categorical feature path: one value table and one 2-row
        # (missing=0 / observed=1) mask table per feature.
        if n_cate > 0:
            # Raise instead of assert so the check survives `python -O`.
            if cate_dims is None or len(cate_dims) != n_cate:
                raise ValueError("Length of cate_dims must match n_cate")
            self.cate_embds = nn.ModuleList([
                nn.Embedding(dim, n_embd) for dim in cate_dims
            ])
            self.cate_mask_embds = nn.ModuleList([
                nn.Embedding(2, n_embd) for _ in range(n_cate)
            ])
        else:
            self.cate_embds = None
            self.cate_mask_embds = None

        # Projects the per-feature continuous observed/missing mask.
        self.cont_mask_proj = (
            nn.Linear(n_cont, n_embd) if n_cont > 0 else None
        )

        # Fuse aggregated value + aggregated mask via MLP
        self.fuse_mlp = nn.Sequential(
            nn.Linear(2 * n_embd, 2 * n_embd),
            nn.GELU(),
            nn.Linear(2 * n_embd, n_embd),
        )

        # Kept before out_ln is created (as originally), so the parameter
        # initialisation order and RNG consumption are unchanged.
        self.apply(self._init_weights)
        self.out_ln = nn.LayerNorm(n_embd)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _empty_output(
        self,
        cont_features: Optional[torch.Tensor],
        cate_features: Optional[torch.Tensor],
        error_msg: str,
    ) -> torch.Tensor:
        """Return zeros (B, L, n_embd), taking (B, L) from whichever input is given."""
        ref = cont_features if cont_features is not None else cate_features
        if ref is None:
            raise ValueError(error_msg)
        B, L = ref.shape[:2]
        return torch.zeros(B, L, self.n_embd, device=ref.device)

    def _encode_continuous(self, cont_features: torch.Tensor):
        """Return (value, mask) embeddings, each (B, L, n_embd), for continuous inputs."""
        if cont_features.dim() != 3:
            raise ValueError(
                "cont_features must be 3D tensor (B, L, n_cont)")
        B, L, D_cont = cont_features.shape
        if D_cont != self.n_cont:
            raise ValueError(
                f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")

        # Missingness mask: 1 for valid, 0 for missing. Non-finite values
        # (NaN and +/-inf) count as missing: an inf filled in as a huge
        # finite number would corrupt the BatchNorm batch statistics.
        valid = torch.isfinite(cont_features)
        cont_mask = valid.float()  # (B, L, n_cont)

        # BatchNorm cannot handle NaN/inf; fill missing with 0 before BN.
        cont_filled = cont_features.masked_fill(~valid, 0.0)

        # BN over the feature dimension: (B, L, C) -> (B*L, C)
        cont_norm_flat = self.cont_bn(cont_filled.reshape(-1, self.n_cont))

        # Soft-binning per feature: (B*L, n_cont) -> (B, L, n_cont, n_embd)
        cont_emb = self.cont_discretizer(cont_norm_flat).reshape(
            B, L, self.n_cont, self.n_embd)

        # Zero out missing features, then average over the observed ones
        # (clamp avoids division by zero when every feature is missing).
        cont_emb = cont_emb * cont_mask.unsqueeze(-1)
        denom = cont_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)  # (B, L, 1)
        h_value = cont_emb.sum(dim=2) / denom  # (B, L, n_embd)

        h_mask = self.cont_mask_proj(cont_mask)  # (B, L, n_embd)
        return h_value, h_mask

    def _encode_categorical(self, cate_features: torch.Tensor):
        """Return per-feature lists of value and mask embeddings, each (B, L, n_embd)."""
        if cate_features.dim() != 3:
            raise ValueError(
                "cate_features must be 3D tensor (B, L, n_cate)")
        D_cate = cate_features.shape[2]
        if D_cate != self.n_cate:
            raise ValueError(
                f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}")

        value_parts: List[torch.Tensor] = []
        mask_parts: List[torch.Tensor] = []
        for i, (value_embd, mask_embd) in enumerate(
                zip(self.cate_embds, self.cate_mask_embds)):
            feat = cate_features[:, :, i]
            # Values <= 0 are treated as missing: they index row 0 of the value
            # table and row 0 ("missing") of the mask table.
            value_parts.append(value_embd(torch.clamp(feat, min=0)))
            mask_parts.append(mask_embd((feat > 0).long()))
        return value_parts, mask_parts

    def forward(
        self,
        cont_features: Optional[torch.Tensor],
        cate_features: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Encode tabular features into a per-timestep embedding.

        Inputs:
            cont_features: (B, L, n_cont) float tensor; NaN (or +/-inf)
                indicates missing. May be None.
            cate_features: (B, L, n_cate) long/int tensor; values <= 0
                indicate missing/pad. May be None.

        Returns:
            (B, L, n_embd) encoded embedding. Zeros when no configured feature
            group received an input.

        Raises:
            ValueError: On wrongly shaped inputs, or when (B, L) cannot be
                inferred because both inputs are None.
        """
        if self.n_cont == 0 and self.n_cate == 0:
            return self._empty_output(
                cont_features, cate_features,
                "TabularEncoder received no features but cannot infer (B, L).")

        value_parts: List[torch.Tensor] = []
        mask_parts: List[torch.Tensor] = []

        if self.n_cont > 0 and cont_features is not None:
            h_cont_value, h_cont_mask = self._encode_continuous(cont_features)
            value_parts.append(h_cont_value)
            mask_parts.append(h_cont_mask)

        if self.n_cate > 0 and cate_features is not None:
            cate_values, cate_masks = self._encode_categorical(cate_features)
            value_parts.extend(cate_values)
            mask_parts.extend(cate_masks)

        if not value_parts:
            return self._empty_output(
                cont_features, cate_features,
                "No features provided to TabularEncoder.")

        # Aggregate across feature groups (continuous block counts as one part;
        # each categorical feature counts as one part). Every value part has a
        # matching mask part, so mask_parts is non-empty here too.
        h_value = torch.stack(value_parts, dim=0).mean(dim=0)  # (B, L, n_embd)
        h_mask = torch.stack(mask_parts, dim=0).mean(dim=0)  # (B, L, n_embd)

        # Fuse by concatenation + MLP projection
        h_fused = torch.cat([h_value, h_mask], dim=-1)  # (B, L, 2*n_embd)
        h_out = self.fuse_mlp(h_fused)  # (B, L, n_embd)
        return self.out_ln(h_out)
|
||
|
|
|
||
|
|
|
||
|
|
class AutoDiscretization(nn.Module):
    """AutoDiscretization / soft-binning for continuous tabular scalars.

    For each feature scalar x_f, compute a soft assignment over ``n_bins``
    bins using a per-feature, per-bin affine map:

        p_f = softmax(x_f * w_f + b_f)

    and return the probability-weighted sum of that feature's learnable bin
    embeddings:

        emb_f = sum_k p_{f,k} * E_{f,k}

    Shapes:
        Input: (N, n_features)
        Output: (N, n_features, n_embd)

    Raises:
        ValueError: On non-positive sizes, ``n_bins <= 1``, or an input whose
            shape is not (N, n_features).
    """

    def __init__(self, n_features: int, n_bins: int, n_embd: int):
        super().__init__()
        if n_features <= 0:
            raise ValueError("n_features must be > 0")
        if n_bins <= 1:
            raise ValueError("n_bins must be > 1")
        if n_embd <= 0:
            raise ValueError("n_embd must be > 0")

        self.n_features = n_features
        self.n_bins = n_bins
        self.n_embd = n_embd

        # Per-feature, per-bin affine transform to produce logits
        self.weight = nn.Parameter(torch.empty(n_features, n_bins))
        self.bias = nn.Parameter(torch.empty(n_features, n_bins))

        # Learnable embeddings for each (feature, bin)
        self.bin_emb = nn.Parameter(torch.empty(n_features, n_bins, n_embd))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Normal(0, 0.02) for weights and bin embeddings; zero biases."""
        nn.init.normal_(self.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.bin_emb, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (N, n_features) scalars to (N, n_features, n_embd) embeddings."""
        if x.dim() != 2:
            raise ValueError(
                "AutoDiscretization expects input of shape (N, n_features)")
        if x.size(1) != self.n_features:
            raise ValueError(
                f"Expected x.size(1) == {self.n_features}, got {x.size(1)}"
            )

        # logits: (N, n_features, n_bins)
        logits = x.unsqueeze(-1) * self.weight.unsqueeze(0) + \
            self.bias.unsqueeze(0)
        probs = torch.softmax(logits, dim=-1)

        # Weighted sum over bins -> (N, n_features, n_embd). einsum contracts
        # the bin axis directly instead of materialising the
        # (N, n_features, n_bins, n_embd) broadcast product first.
        return torch.einsum("nfk,fkd->nfd", probs, self.bin_emb)
|
||
|
|
|
||
|
|
|
||
|
|
class DelphiBERT(nn.Module):
    """
    DelphiBERT model for tabular time series data.

    Token, age and sex embeddings are summed. Positions holding token id ``2``
    instead take tabular-feature embeddings (the k-th such position in a row
    reads the k-th row of the encoded tabular sequence) before the age/sex
    terms are added. The sum is passed through ``n_layer`` transformer blocks
    and a final LayerNorm, and the representation at position 0 is returned.

    Args:
        n_disease (int): Number of disease tokens. The vocabulary has
            ``n_disease + 4`` entries and index 0 is the padding index.
        n_embd (int): Embedding dimension.
        n_head (int): Number of attention heads.
        n_layer (int): Number of transformer blocks.
        n_cont (int): Number of continuous tabular features.
        n_cate (int): Number of categorical tabular features.
        cate_dims (Optional[List[int]]): Embedding rows per categorical
            feature; required when ``n_cate > 0``.
        age_encoder_type (str): ``'sinusoidal'`` or ``'mlp'``.
        pdrop (float): Dropout probability.

    Raises:
        ValueError: If ``n_cate > 0`` without ``cate_dims``, or if
            ``age_encoder_type`` is unsupported.
    """

    def __init__(
        self,
        n_disease: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_cont: int = 0,
        n_cate: int = 0,
        cate_dims: Optional[List[int]] = None,
        age_encoder_type: str = 'sinusoidal',
        pdrop: float = 0.0,
    ):
        super().__init__()
        # cate_dims is only needed for categorical features; a continuous-only
        # model (n_cont > 0, n_cate == 0) must not require it.
        if n_cate > 0 and cate_dims is None:
            raise ValueError(
                "cate_dims must be provided if n_cate > 0"
            )
        if n_cont > 0 or n_cate > 0:
            self.tabular_encoder = TabularEncoder(
                n_embd=n_embd,
                n_cont=n_cont,
                n_cate=n_cate,
                cate_dims=list(cate_dims) if cate_dims is not None else [],
            )
        else:
            self.tabular_encoder = None
        self.vocab_size = n_disease + 4
        self.n_disease = n_disease
        self.n_embd = n_embd
        self.n_head = n_head

        self.token_embedding = nn.Embedding(
            self.vocab_size, n_embd, padding_idx=0)
        if age_encoder_type == 'sinusoidal':
            self.age_encoder = AgeSinusoidalEncoder(n_embd)
        elif age_encoder_type == 'mlp':
            self.age_encoder = AgeMLPEncoder(n_embd)
        else:
            raise ValueError(
                f"Unsupported age_encoder_type: {age_encoder_type}"
            )

        self.sex_embedding = nn.Embedding(2, n_embd)

        self.blocks = nn.ModuleList([
            Block(
                n_embd=n_embd,
                n_head=n_head,
                pdrop=pdrop,
            ) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)

    def _inject_tabular(
        self,
        event_seq: torch.Tensor,
        token_emb: torch.Tensor,
        cont_seq: Optional[torch.Tensor],
        cate_seq: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Replace embeddings at token-id-2 positions with encoded tabular rows."""
        tabular_emb = self.tabular_encoder(cont_seq, cate_seq)  # (B, Lc, n_embd)
        Lc, D = tabular_emb.size(1), tabular_emb.size(2)
        if Lc == 0:
            # Nothing to gather from; keep the token embeddings unchanged.
            return token_emb

        mask = (event_seq == 2)
        # The k-th placeholder in a row (0-based) reads tabular row k.
        # Placeholders beyond Lc - 1 reuse the last row; non-placeholders
        # point at row 0, and torch.where discards them below.
        occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
        tab_idx = occ.clamp(min=0, max=Lc - 1).masked_fill(~mask, 0)  # (B, L)
        tab_inject = tabular_emb.gather(
            dim=1,
            index=tab_idx.unsqueeze(-1).expand(-1, -1, D),
        )  # (B, L, n_embd)
        return torch.where(mask.unsqueeze(-1), tab_inject, token_emb)

    def forward(
        self,
        event_seq: torch.Tensor,
        time_seq: torch.Tensor,
        sex: torch.Tensor,
        cont_seq: Optional[torch.Tensor] = None,
        cate_seq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass of DelphiBERT.

        Inputs:
            event_seq: (B, L) long tensor of token IDs (0 = padding,
                2 = tabular placeholder).
            time_seq: (B, L) float tensor of ages/times.
            sex: (B,) long tensor of sex
            cont_seq: (B, Lc, n_cont) float tensor of continuous features.
            cate_seq: (B, Lc, n_cate) long tensor of categorical features.
                Either or both may be None. Tabular features are used whenever
                the model has a tabular encoder and at least one is given.

        Returns:
            (B, n_embd) final-LayerNorm output at sequence position 0.
        """
        B, L = event_seq.shape
        token_emb = self.token_embedding(event_seq)  # (B, L, n_embd)
        # NOTE(review): assumed (B, L, n_embd); depends on the age encoder.
        age_emb = self.age_encoder(time_seq)
        # (B,) -> (B, 1) -> (B, 1, n_embd); broadcasts over the L axis.
        sex_emb = self.sex_embedding(sex.unsqueeze(-1))

        if self.tabular_encoder is not None and (
                cont_seq is not None or cate_seq is not None):
            token_emb = self._inject_tabular(
                event_seq, token_emb, cont_seq, cate_seq)

        h = token_emb + age_emb + sex_emb

        # True at padding positions; how Block interprets this mask is
        # defined in Block (not visible here).
        is_padding = (event_seq == 0)
        attn_mask = is_padding.view(B, 1, 1, L)  # (B, 1, 1, L)

        for block in self.blocks:
            h = block(h, attn_mask=attn_mask)
        h = self.ln_f(h)
        return h[:, 0, :]
|