import numpy as np from typing import Optional, List from backbones import Block from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder import torch.nn.functional as F import torch.nn as nn import torch class TabularEncoder(nn.Module): """ Encoder for tabular features (continuous and categorical). Args: n_embd (int): Embedding dimension. n_cont (int): Number of continuous features. n_cate (int): Number of categorical features. cate_dims (List[int]): List of dimensions for each categorical feature. n_bins (int): Number of soft bins for continuous AutoDiscretization. """ def __init__( self, n_embd: int, n_cont: int, n_cate: int, cate_dims: List[int], n_bins: int = 16, ): super().__init__() self.n_embd = n_embd self.n_cont = n_cont self.n_cate = n_cate # Continuous feature path # - BatchNorm on raw (NaN-filled) continuous values # - AutoDiscretization (soft binning) per feature if n_cont > 0: self.cont_bn = nn.BatchNorm1d(n_cont) self.cont_discretizer = AutoDiscretization( n_features=n_cont, n_bins=n_bins, n_embd=n_embd, ) else: self.cont_bn = None self.cont_discretizer = None if n_cate > 0: assert len(cate_dims) == n_cate, \ "Length of cate_dims must match n_cate" self.cate_embds = nn.ModuleList([ nn.Embedding(dim, n_embd) for dim in cate_dims ]) self.cate_mask_embds = nn.ModuleList([ nn.Embedding(2, n_embd) for _ in range(n_cate) ]) else: self.cate_embds = None self.cate_mask_embds = None self.cont_mask_proj = ( nn.Linear(n_cont, n_embd) if n_cont > 0 else None ) # Fuse aggregated value + aggregated mask via MLP self.fuse_mlp = nn.Sequential( nn.Linear(2 * n_embd, 2 * n_embd), nn.GELU(), nn.Linear(2 * n_embd, n_embd), ) self.apply(self._init_weights) self.out_ln = nn.LayerNorm(n_embd) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward( self, cont_features: Optional[torch.Tensor], cate_features: Optional[torch.Tensor], ) -> torch.Tensor: """Encode tabular features into a per-timestep embedding. Inputs: cont_features: (B, L, n_cont) float tensor; NaN indicates missing. cate_features: (B, L, n_cate) long/int tensor; 0 indicates missing/pad. Returns: (B, L, n_embd) encoded embedding. """ if self.n_cont == 0 and self.n_cate == 0: # infer (B, L) from whichever input is not None if cont_features is not None: B, L = cont_features.shape[:2] device = cont_features.device elif cate_features is not None: B, L = cate_features.shape[:2] device = cate_features.device else: raise ValueError( "TabularEncoder received no features but cannot infer (B, L)." ) return torch.zeros(B, L, self.n_embd, device=device) value_parts: List[torch.Tensor] = [] mask_parts: List[torch.Tensor] = [] if self.n_cont > 0 and cont_features is not None: if cont_features.dim() != 3: raise ValueError( "cont_features must be 3D tensor (B, L, n_cont)") B, L, D_cont = cont_features.shape if D_cont != self.n_cont: raise ValueError( f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}") # Missingness mask: 1 for valid, 0 for missing cont_mask = (~torch.isnan(cont_features)).float() # (B, L, n_cont) # BatchNorm cannot handle NaNs; fill missing with 0 before BN. cont_filled = torch.nan_to_num( cont_features, nan=0.0) # (B, L, n_cont) # Apply BN over the feature dimension: (B, L, C) -> (B*L, C) -> (B, L, C) cont_flat = cont_filled.reshape(-1, self.n_cont) cont_norm_flat = self.cont_bn(cont_flat) # (B*L, n_cont) # Soft-binning per feature: (B*L, n_cont) -> (B*L, n_cont, n_embd) cont_emb_flat = self.cont_discretizer(cont_norm_flat) cont_emb = cont_emb_flat.view(B, L, self.n_cont, self.n_embd) # Mask-out missing continuous features before aggregating across features # (B, L, n_cont, n_embd) cont_emb = cont_emb * cont_mask.unsqueeze(-1) denom = cont_mask.sum( dim=-1, keepdim=True).clamp(min=1.0) # (B, L, 1) h_cont_value = cont_emb.sum(dim=2) / denom # (B, L, n_embd) value_parts.append(h_cont_value) # Explicit continuous mask embedding (fused later) if self.cont_mask_proj is not None: h_cont_mask = self.cont_mask_proj(cont_mask) # (B, L, n_embd) mask_parts.append(h_cont_mask) if self.n_cate > 0 and cate_features is not None: if cate_features.dim() != 3: raise ValueError( "cate_features must be 3D tensor (B, L, n_cate)") B, L, D_cate = cate_features.shape if D_cate != self.n_cate: raise ValueError( f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}") for i in range(self.n_cate): cate_feat = cate_features[:, :, i] cate_embd = self.cate_embds[i] cate_mask_embd = self.cate_mask_embds[i] cate_value = cate_embd( torch.clamp(cate_feat, min=0)) cate_mask = (cate_feat > 0).long() cate_mask_value = cate_mask_embd(cate_mask) value_parts.append(cate_value) mask_parts.append(cate_mask_value) if not value_parts: if cont_features is not None: B, L = cont_features.shape[:2] device = cont_features.device elif cate_features is not None: B, L = cate_features.shape[:2] device = cate_features.device else: raise ValueError("No features provided to TabularEncoder.") return torch.zeros(B, L, self.n_embd, device=device) # Aggregate across feature groups (continuous block counts as one part; # each categorical feature counts as one part). h_value = torch.stack(value_parts, dim=0).mean(dim=0) # (B, L, n_embd) if mask_parts: h_mask = torch.stack(mask_parts, dim=0).mean( dim=0) # (B, L, n_embd) else: h_mask = torch.zeros_like(h_value) # Fuse by concatenation + MLP projection h_fused = torch.cat([h_value, h_mask], dim=-1) # (B, L, 2*n_embd) h_out = self.fuse_mlp(h_fused) # (B, L, n_embd) h_out = self.out_ln(h_out) return h_out class AutoDiscretization(nn.Module): """AutoDiscretization / soft-binning for continuous tabular scalars. For each feature scalar $x$, compute a soft assignment over `n_bins`: p = softmax(x * w + b) Then compute the embedding as a weighted sum of learnable bin embeddings: emb = sum_k p_k * E_k Shapes: Input: (N, n_features) Output: (N, n_features, n_embd) """ def __init__(self, n_features: int, n_bins: int, n_embd: int): super().__init__() if n_features <= 0: raise ValueError("n_features must be > 0") if n_bins <= 1: raise ValueError("n_bins must be > 1") if n_embd <= 0: raise ValueError("n_embd must be > 0") self.n_features = n_features self.n_bins = n_bins self.n_embd = n_embd # Per-feature, per-bin affine transform to produce logits self.weight = nn.Parameter(torch.empty(n_features, n_bins)) self.bias = nn.Parameter(torch.empty(n_features, n_bins)) # Learnable embeddings for each (feature, bin) self.bin_emb = nn.Parameter(torch.empty(n_features, n_bins, n_embd)) self.reset_parameters() def reset_parameters(self) -> None: nn.init.normal_(self.weight, mean=0.0, std=0.02) nn.init.zeros_(self.bias) nn.init.normal_(self.bin_emb, mean=0.0, std=0.02) def forward(self, x: torch.Tensor) -> torch.Tensor: if x.dim() != 2: raise ValueError( "AutoDiscretization expects input of shape (N, n_features)") if x.size(1) != self.n_features: raise ValueError( f"Expected x.size(1) == {self.n_features}, got {x.size(1)}" ) # logits: (N, n_features, n_bins) logits = x.unsqueeze(-1) * self.weight.unsqueeze(0) + \ self.bias.unsqueeze(0) probs = torch.softmax(logits, dim=-1) # Weighted sum over bins -> (N, n_features, n_embd) emb = (probs.unsqueeze(-1) * self.bin_emb.unsqueeze(0)).sum(dim=-2) return emb class DelphiBERT(nn.Module): """ DelphiBERT model for tabular time series data. Args: n_embd (int): Embedding dimension. n_head (int): Number of attention heads. n_layer (int): Number of transformer blocks. pdrop (float): Dropout probability. """ def __init__( self, n_disease: int, n_embd: int, n_head: int, n_layer: int, n_cont: int = 0, n_cate: int = 0, cate_dims: Optional[List[int]] = None, age_encoder_type: str = 'sinusoidal', pdrop: float = 0.0, ): super().__init__() if n_cont > 0 or n_cate > 0: if cate_dims is None: raise ValueError( "cate_dims must be provided if n_cate > 0" ) self.tabular_encoder = TabularEncoder( n_embd=n_embd, n_cont=n_cont, n_cate=n_cate, cate_dims=cate_dims, ) else: self.tabular_encoder = None self.vocab_size = n_disease + 4 self.n_disease = n_disease self.n_embd = n_embd self.n_head = n_head self.token_embedding = nn.Embedding( self.vocab_size, n_embd, padding_idx=0) if age_encoder_type == 'sinusoidal': self.age_encoder = AgeSinusoidalEncoder(n_embd) elif age_encoder_type == 'mlp': self.age_encoder = AgeMLPEncoder(n_embd) else: raise ValueError( f"Unsupported age_encoder_type: {age_encoder_type}" ) self.sex_embedding = nn.Embedding(2, n_embd) self.blocks = nn.ModuleList([ Block( n_embd=n_embd, n_head=n_head, pdrop=pdrop, ) for _ in range(n_layer) ]) self.ln_f = nn.LayerNorm(n_embd) def forward( self, event_seq: torch.Tensor, time_seq: torch.Tensor, sex: torch.Tensor, cont_seq: Optional[torch.Tensor] = None, cate_seq: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass of DelphiBERT. Inputs: event_seq: (B, L) long tensor of token IDs. time_seq: (B, L) float tensor of ages/times. sex: (B,) long tensor of sex cont_seq: (B, Lc, n_cont) float tensor of continuous features. cate_seq: (B, Lc, n_cate) long tensor of categorical features. Returns: (B, L, n_embd) output embeddings. """ B, L = event_seq.shape token_emb = self.token_embedding(event_seq) # (B, L, n_embd) age_emb = self.age_encoder(time_seq) # (B, L, n_embd) sex_emb = self.sex_embedding(sex.unsqueeze(-1)) # (B, n_embd) if self.tabular_encoder is not None and cont_seq is not None and cate_seq is not None: tabular_emb = self.tabular_encoder( cont_seq, cate_seq) # (B, L, n_embd) mask = (event_seq == 2) Lc = tabular_emb.size(1) D = tabular_emb.size(2) occ = torch.cumsum(mask.to(torch.long), dim=1) - 1 tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0)) tab_idx = tab_idx.masked_fill(~mask, 0) # (B, L) tab_inject = tabular_emb.gather( dim=1, index=tab_idx.unsqueeze(-1).expand(-1, -1, D) ) # (B, L, n_embd) final_embds = torch.where( mask.unsqueeze(-1), tab_inject, token_emb) h = final_embds + age_emb + sex_emb else: h = token_emb + age_emb + sex_emb is_padding = (event_seq == 0) attn_mask = is_padding.view(B, 1, 1, L) # (B, 1, 1, L) for block in self.blocks: h = block(h, attn_mask=attn_mask) h = self.ln_f(h) cls_output = h[:, 0, :] return cls_output