# DeepHealth/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block
from typing import Optional, List
import numpy as np
class TabularEncoder(nn.Module):
"""
Encoder for tabular features (continuous and categorical).
Args:
n_embd (int): Embedding dimension.
n_cont (int): Number of continuous features.
n_cate (int): Number of categorical features.
cate_dims (List[int]): List of dimensions for each categorical feature.
n_bins (int): Number of soft bins for continuous AutoDiscretization.
"""
def __init__(
self,
n_embd: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
n_bins: int = 16,
):
super().__init__()
self.n_embd = n_embd
self.n_cont = n_cont
self.n_cate = n_cate
# Continuous feature path
# - BatchNorm on raw (NaN-filled) continuous values
# - AutoDiscretization (soft binning) per feature
if n_cont > 0:
self.cont_bn = nn.BatchNorm1d(n_cont)
self.cont_discretizer = AutoDiscretization(
n_features=n_cont,
n_bins=n_bins,
n_embd=n_embd,
)
else:
self.cont_bn = None
self.cont_discretizer = None
if n_cate > 0:
assert len(cate_dims) == n_cate, \
"Length of cate_dims must match n_cate"
self.cate_embds = nn.ModuleList([
nn.Embedding(dim, n_embd) for dim in cate_dims
])
self.cate_mask_embds = nn.ModuleList([
nn.Embedding(2, n_embd) for _ in range(n_cate)
])
else:
self.cate_embds = None
self.cate_mask_embds = None
self.cont_mask_proj = (
nn.Linear(n_cont, n_embd) if n_cont > 0 else None
)
# Fuse aggregated value + aggregated mask via MLP
self.fuse_mlp = nn.Sequential(
nn.Linear(2 * n_embd, 2 * n_embd),
nn.GELU(),
nn.Linear(2 * n_embd, n_embd),
)
self.apply(self._init_weights)
self.out_ln = nn.LayerNorm(n_embd)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
cont_features: Optional[torch.Tensor],
cate_features: Optional[torch.Tensor],
) -> torch.Tensor:
"""Encode tabular features into a per-timestep embedding.
Inputs:
cont_features: (B, L, n_cont) float tensor; NaN indicates missing.
cate_features: (B, L, n_cate) long/int tensor; 0 indicates missing/pad.
Returns:
(B, L, n_embd) encoded embedding.
"""
if self.n_cont == 0 and self.n_cate == 0:
# infer (B, L) from whichever input is not None
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError(
"TabularEncoder received no features but cannot infer (B, L)."
)
return torch.zeros(B, L, self.n_embd, device=device)
value_parts: List[torch.Tensor] = []
mask_parts: List[torch.Tensor] = []
if self.n_cont > 0 and cont_features is not None:
if cont_features.dim() != 3:
raise ValueError(
"cont_features must be 3D tensor (B, L, n_cont)")
B, L, D_cont = cont_features.shape
if D_cont != self.n_cont:
raise ValueError(
f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")
# Missingness mask: 1 for valid, 0 for missing
cont_mask = (~torch.isnan(cont_features)).float() # (B, L, n_cont)
# BatchNorm cannot handle NaNs; fill missing with 0 before BN.
cont_filled = torch.nan_to_num(
cont_features, nan=0.0) # (B, L, n_cont)
# Apply BN over the feature dimension: (B, L, C) -> (B*L, C) -> (B, L, C)
cont_flat = cont_filled.reshape(-1, self.n_cont)
cont_norm_flat = self.cont_bn(cont_flat) # (B*L, n_cont)
# Soft-binning per feature: (B*L, n_cont) -> (B*L, n_cont, n_embd)
cont_emb_flat = self.cont_discretizer(cont_norm_flat)
cont_emb = cont_emb_flat.view(B, L, self.n_cont, self.n_embd)
# Mask-out missing continuous features before aggregating across features
# (B, L, n_cont, n_embd)
cont_emb = cont_emb * cont_mask.unsqueeze(-1)
denom = cont_mask.sum(
dim=-1, keepdim=True).clamp(min=1.0) # (B, L, 1)
h_cont_value = cont_emb.sum(dim=2) / denom # (B, L, n_embd)
value_parts.append(h_cont_value)
# Explicit continuous mask embedding (fused later)
if self.cont_mask_proj is not None:
h_cont_mask = self.cont_mask_proj(cont_mask) # (B, L, n_embd)
mask_parts.append(h_cont_mask)
if self.n_cate > 0 and cate_features is not None:
if cate_features.dim() != 3:
raise ValueError(
"cate_features must be 3D tensor (B, L, n_cate)")
B, L, D_cate = cate_features.shape
if D_cate != self.n_cate:
raise ValueError(
f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}")
for i in range(self.n_cate):
cate_feat = cate_features[:, :, i]
cate_embd = self.cate_embds[i]
cate_mask_embd = self.cate_mask_embds[i]
cate_value = cate_embd(
torch.clamp(cate_feat, min=0))
cate_mask = (cate_feat > 0).long()
cate_mask_value = cate_mask_embd(cate_mask)
value_parts.append(cate_value)
mask_parts.append(cate_mask_value)
if not value_parts:
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError("No features provided to TabularEncoder.")
return torch.zeros(B, L, self.n_embd, device=device)
# Aggregate across feature groups (continuous block counts as one part;
# each categorical feature counts as one part).
h_value = torch.stack(value_parts, dim=0).mean(dim=0) # (B, L, n_embd)
if mask_parts:
h_mask = torch.stack(mask_parts, dim=0).mean(
dim=0) # (B, L, n_embd)
else:
h_mask = torch.zeros_like(h_value)
# Fuse by concatenation + MLP projection
h_fused = torch.cat([h_value, h_mask], dim=-1) # (B, L, 2*n_embd)
h_out = self.fuse_mlp(h_fused) # (B, L, n_embd)
h_out = self.out_ln(h_out)
return h_out
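
# Illustrative usage sketch for TabularEncoder. The batch size, sequence
# length, feature counts, and cate_dims below are assumptions chosen for this
# example only; the NaN-as-missing and 0-as-missing conventions follow the
# forward() docstring above. The function is never called at import time.
def _demo_tabular_encoder() -> torch.Tensor:
    B, L, n_cont, n_cate = 2, 5, 3, 2
    encoder = TabularEncoder(n_embd=32, n_cont=n_cont, n_cate=n_cate,
                             cate_dims=[4, 7])
    cont = torch.randn(B, L, n_cont)
    cont[0, 1, 2] = float("nan")            # missing continuous value
    cate = torch.randint(1, 4, (B, L, n_cate))
    cate[1, 3, 0] = 0                       # 0 marks a missing categorical value
    out = encoder(cont, cate)               # (B, L, n_embd)
    assert out.shape == (B, L, 32)
    return out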
class AutoDiscretization(nn.Module):
"""AutoDiscretization / soft-binning for continuous tabular scalars.
For each feature scalar $x$, compute a soft assignment over `n_bins`:
p = softmax(x * w + b)
Then compute the embedding as a weighted sum of learnable bin embeddings:
emb = sum_k p_k * E_k
Shapes:
Input: (N, n_features)
Output: (N, n_features, n_embd)
"""
def __init__(self, n_features: int, n_bins: int, n_embd: int):
super().__init__()
if n_features <= 0:
raise ValueError("n_features must be > 0")
if n_bins <= 1:
raise ValueError("n_bins must be > 1")
if n_embd <= 0:
raise ValueError("n_embd must be > 0")
self.n_features = n_features
self.n_bins = n_bins
self.n_embd = n_embd
# Per-feature, per-bin affine transform to produce logits
self.weight = nn.Parameter(torch.empty(n_features, n_bins))
self.bias = nn.Parameter(torch.empty(n_features, n_bins))
# Learnable embeddings for each (feature, bin)
self.bin_emb = nn.Parameter(torch.empty(n_features, n_bins, n_embd))
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, mean=0.0, std=0.02)
nn.init.zeros_(self.bias)
nn.init.normal_(self.bin_emb, mean=0.0, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dim() != 2:
raise ValueError(
"AutoDiscretization expects input of shape (N, n_features)")
if x.size(1) != self.n_features:
raise ValueError(
f"Expected x.size(1) == {self.n_features}, got {x.size(1)}"
)
# logits: (N, n_features, n_bins)
logits = x.unsqueeze(-1) * self.weight.unsqueeze(0) + \
self.bias.unsqueeze(0)
probs = torch.softmax(logits, dim=-1)
# Weighted sum over bins -> (N, n_features, n_embd)
emb = (probs.unsqueeze(-1) * self.bin_emb.unsqueeze(0)).sum(dim=-2)
return emb
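
# Minimal sketch of the soft-binning computation documented above: each scalar
# produces bin logits x * w + b, a softmax over bins, and a weighted sum of
# the learnable bin embeddings. All sizes below are illustrative assumptions.
def _demo_auto_discretization() -> torch.Tensor:
    disc = AutoDiscretization(n_features=4, n_bins=8, n_embd=16)
    x = torch.randn(10, 4)                  # (N, n_features)
    emb = disc(x)                           # (N, n_features, n_embd)
    assert emb.shape == (10, 4, 16)
    # Bin probabilities sum to one per (sample, feature) pair.
    logits = x.unsqueeze(-1) * disc.weight.unsqueeze(0) + disc.bias.unsqueeze(0)
    probs = torch.softmax(logits, dim=-1)
    assert torch.allclose(probs.sum(dim=-1), torch.ones(10, 4))
    return emb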
def _build_time_padding_mask(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
) -> torch.Tensor:
    """Build the combined causal + key-padding attention mask.

    True entries are blocked from attention. Padding keys are inferred from
    event_seq == 0; time_seq is currently unused and is kept only for
    interface symmetry with the callers.

    Returns:
        (B, 1, L, L) boolean mask.
    """
B, L = event_seq.shape
device = event_seq.device
key_is_valid = (event_seq != 0)
    # Cache the (L, L) causal upper-triangular mask per (device, L) on the
    # function object so it is only built once per shape.
    cache = getattr(_build_time_padding_mask, "_cache", None)
if cache is None:
cache = {}
setattr(_build_time_padding_mask, "_cache", cache)
cache_key = (str(device), L)
causal = cache.get(cache_key)
if causal is None:
causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
cache[cache_key] = causal
causal = causal.unsqueeze(0).unsqueeze(0) # (1,1,L,L)
key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2) # (B,1,1,L)
attn_mask = causal | key_pad # (B,1,L,L)
return attn_mask
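
# Small sketch of the attention-mask semantics: True entries are blocked.
# Future keys are blocked by the causal triangle, and padded keys
# (event_seq == 0) are blocked for every query row. The toy sequence is an
# assumption for illustration only.
def _demo_time_padding_mask() -> torch.Tensor:
    event_seq = torch.tensor([[3, 1, 5, 0]])      # last position is padding
    time_seq = torch.zeros_like(event_seq)        # unused by the mask builder
    attn_mask = _build_time_padding_mask(event_seq, time_seq)  # (1, 1, 4, 4)
    assert attn_mask.shape == (1, 1, 4, 4)
    assert bool(attn_mask[0, 0, 0, 1])            # future key blocked
    assert bool(attn_mask[0, 0, 3, 3])            # padded key blocked everywhere
    return attn_mask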
class DelphiFork(nn.Module):
"""
DelphiFork model for time-to-event prediction.
Args:
n_disease (int): Number of disease tokens.
n_tech_tokens (int): Number of technical tokens.
n_embd (int): Embedding dimension.
n_head (int): Number of attention heads.
n_layer (int): Number of transformer layers.
n_cont (int): Number of continuous features.
n_cate (int): Number of categorical features.
cate_dims (List[int]): List of dimensions for each categorical feature.
age_encoder_type (str): Type of age encoder ("sinusoidal" or "mlp").
pdrop (float): Dropout probability.
token_pdrop (float): Token dropout probability.
n_dim (int): Dimension of theta parameters.
"""
def __init__(
self,
n_disease: int,
n_tech_tokens: int,
n_embd: int,
n_head: int,
n_layer: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
age_encoder_type: str = "sinusoidal",
pdrop: float = 0.0,
token_pdrop: float = 0.0,
n_dim: int = 1,
):
super().__init__()
self.vocab_size = n_disease + n_tech_tokens
self.n_tech_tokens = n_tech_tokens
self.n_disease = n_disease
self.n_embd = n_embd
self.n_head = n_head
self.n_dim = n_dim
self.token_embedding = nn.Embedding(
self.vocab_size, n_embd, padding_idx=0)
if age_encoder_type == "sinusoidal":
self.age_encoder = AgeSinusoidalEncoder(n_embd)
elif age_encoder_type == "mlp":
self.age_encoder = AgeMLPEncoder(n_embd)
else:
raise ValueError(
f"Unsupported age_encoder_type: {age_encoder_type}")
self.sex_encoder = nn.Embedding(2, n_embd)
self.tabular_encoder = TabularEncoder(
n_embd, n_cont, n_cate, cate_dims)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.token_dropout = nn.Dropout(token_pdrop)
# Head layers
self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
def forward(
self,
event_seq: torch.Tensor, # (B, L)
time_seq: torch.Tensor, # (B, L)
sex: torch.Tensor, # (B,)
cont_seq: torch.Tensor, # (B, Lc, n_cont)
cate_seq: torch.Tensor, # (B, Lc, n_cate)
b_prev: Optional[torch.Tensor] = None, # (M,)
t_prev: Optional[torch.Tensor] = None, # (M,)
) -> torch.Tensor:
token_embds = self.token_embedding(event_seq) # (B, L, D)
age_embds = self.age_encoder(time_seq) # (B, L, D)
sex_embds = self.sex_encoder(sex.unsqueeze(-1)) # (B, 1, D)
table_embds = self.tabular_encoder(cont_seq, cate_seq) # (B, Lc, D)
mask = (event_seq == 1) # (B, L)
B, L = event_seq.shape
Lc = table_embds.size(1)
D = table_embds.size(2)
        # occ[b, t] = 0-indexed count of mask positions seen so far; values at
        # non-mask positions are meaningless and are zeroed out below.
        # (B, L): 0, 1, 2, ...
occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
        # Clamp indices beyond Lc - 1 and force non-mask positions to 0 to
        # avoid a meaningless gather.
tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
tab_idx = tab_idx.masked_fill(~mask, 0) # (B, L)
        # Gather along dim=1 from (B, Lc, D) the tabular embedding to inject
        # at each position -> (B, L, D)
tab_inject = table_embds.gather(
dim=1,
index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
)
        # Replace token embeddings only at positions where mask == True
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
x = final_embds + age_embds + sex_embds # (B, L, D)
x = self.token_dropout(x)
attn_mask = _build_time_padding_mask(
event_seq, time_seq)
for block in self.blocks:
x = block(x, attn_mask=attn_mask)
x = self.ln_f(x)
if b_prev is not None and t_prev is not None:
M = b_prev.numel()
c = x[b_prev, t_prev] # (M, D)
theta = self.theta_proj(c) # (M, N_disease * n_dim)
theta = theta.view(M, self.n_disease, self.n_dim)
return theta
else:
return x
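
# Hedged usage sketch for DelphiFork. It assumes backbones.Block and the age
# encoders imported above behave as they do elsewhere in this repository; the
# sizes and the convention that token id 1 marks a tabular assessment slot
# (and 0 marks padding) are illustrative assumptions.
def _demo_delphi_fork() -> torch.Tensor:
    model = DelphiFork(
        n_disease=10, n_tech_tokens=3, n_embd=32, n_head=4, n_layer=2,
        n_cont=3, n_cate=2, cate_dims=[4, 7], n_dim=2)
    B, L, Lc = 2, 6, 2
    event_seq = torch.randint(2, 13, (B, L))      # disease / technical tokens
    event_seq[:, 0] = 1                           # tabular assessment slot
    event_seq[:, -1] = 0                          # padding
    time_seq = torch.cumsum(torch.rand(B, L), dim=1)
    sex = torch.randint(0, 2, (B,))
    cont_seq = torch.randn(B, Lc, 3)
    cate_seq = torch.randint(0, 4, (B, Lc, 2))
    # Without (b_prev, t_prev) the model returns the full hidden states.
    x = model(event_seq, time_seq, sex, cont_seq, cate_seq)
    assert x.shape == (B, L, 32)
    # With (b_prev, t_prev) it returns theta for the selected positions.
    theta = model(event_seq, time_seq, sex, cont_seq, cate_seq,
                  b_prev=torch.tensor([0, 1]), t_prev=torch.tensor([2, 3]))
    assert theta.shape == (2, 10, 2)              # (M, n_disease, n_dim)
    return theta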
class SapDelphi(nn.Module):
    """
    DelphiFork variant whose token embedding table can be initialized from
    pretrained (e.g. BERT-derived) disease embeddings; when the pretrained
    width differs from n_embd, emb_proj maps the embeddings to the model width.
    """
def __init__(
self,
n_disease: int,
n_tech_tokens: int,
n_embd: int,
n_head: int,
n_layer: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
age_encoder_type: str = "sinusoidal",
pdrop: float = 0.0,
token_pdrop: float = 0.0,
n_dim: int = 1,
        pretrained_weights_path: Optional[str] = None,  # path to pretrained token embeddings (.npy)
        freeze_embeddings: bool = False,  # False keeps the embeddings trainable (fine-tuning)
):
super().__init__()
self.vocab_size = n_disease + n_tech_tokens
self.n_tech_tokens = n_tech_tokens
self.n_disease = n_disease
self.n_embd = n_embd
self.n_head = n_head
self.n_dim = n_dim
if pretrained_weights_path is not None:
print(
f"Loading pretrained embeddings from {pretrained_weights_path}...")
bert_weights = np.load(pretrained_weights_path)
bert_weights = torch.tensor(bert_weights, dtype=torch.float32)
            vocab_dim = bert_weights.shape[1]  # typically 768
            # Embedding table layout: row 0 is the padding token (zeros), rows
            # 1..n_tech_tokens-1 are randomly initialized technical tokens,
            # and the remaining rows are the pretrained disease embeddings.
            pad_emb = torch.zeros(1, vocab_dim)
tech_embs = nn.init.normal_(torch.empty(
n_tech_tokens-1, vocab_dim))
full_emb_weights = torch.cat(
[pad_emb, tech_embs, bert_weights], dim=0)
self.token_embedding = nn.Embedding.from_pretrained(
full_emb_weights, freeze=freeze_embeddings)
print("Pretrained embeddings loaded.")
if vocab_dim != n_embd:
self.emb_proj = nn.Sequential(
nn.Linear(vocab_dim, n_embd, bias=False),
nn.LayerNorm(n_embd),
nn.Dropout(pdrop),
)
else:
self.emb_proj = nn.Identity()
else:
self.token_embedding = nn.Embedding(
self.vocab_size, n_embd, padding_idx=0)
self.emb_proj = nn.Identity()
if age_encoder_type == "sinusoidal":
self.age_encoder = AgeSinusoidalEncoder(n_embd)
elif age_encoder_type == "mlp":
self.age_encoder = AgeMLPEncoder(n_embd)
else:
raise ValueError(
f"Unsupported age_encoder_type: {age_encoder_type}")
self.sex_encoder = nn.Embedding(2, n_embd)
self.tabular_encoder = TabularEncoder(
n_embd, n_cont, n_cate, cate_dims)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.token_dropout = nn.Dropout(token_pdrop)
# Head layers
self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
def forward(
self,
event_seq: torch.Tensor, # (B, L)
time_seq: torch.Tensor, # (B, L)
sex: torch.Tensor, # (B,)
cont_seq: torch.Tensor, # (B, Lc, n_cont)
cate_seq: torch.Tensor, # (B, Lc, n_cate)
b_prev: Optional[torch.Tensor] = None, # (M,)
t_prev: Optional[torch.Tensor] = None, # (M,)
) -> torch.Tensor:
token_embds = self.token_embedding(event_seq) # (B, L, Vocab_dim)
token_embds = self.emb_proj(token_embds) # (B, L, D)
age_embds = self.age_encoder(time_seq) # (B, L, D)
sex_embds = self.sex_encoder(sex.unsqueeze(-1)) # (B, 1, D)
table_embds = self.tabular_encoder(cont_seq, cate_seq) # (B, Lc, D)
mask = (event_seq == 1) # (B, L)
B, L = event_seq.shape
Lc = table_embds.size(1)
D = table_embds.size(2)
        # occ[b, t] = 0-indexed count of mask positions seen so far; values at
        # non-mask positions are meaningless and are zeroed out below.
        # (B, L): 0, 1, 2, ...
occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
        # Clamp indices beyond Lc - 1 and force non-mask positions to 0 to
        # avoid a meaningless gather.
tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
tab_idx = tab_idx.masked_fill(~mask, 0) # (B, L)
        # Gather along dim=1 from (B, Lc, D) the tabular embedding to inject
        # at each position -> (B, L, D)
tab_inject = table_embds.gather(
dim=1,
index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
)
        # Replace token embeddings only at positions where mask == True
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
x = final_embds + age_embds + sex_embds # (B, L, D)
x = self.token_dropout(x)
attn_mask = _build_time_padding_mask(
event_seq, time_seq)
for block in self.blocks:
x = block(x, attn_mask=attn_mask)
x = self.ln_f(x)
if b_prev is not None and t_prev is not None:
M = b_prev.numel()
c = x[b_prev, t_prev] # (M, D)
theta = self.theta_proj(c) # (M, N_disease * n_dim)
theta = theta.view(M, self.n_disease, self.n_dim)
return theta
else:
return x
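
# Hedged sketch of the pretrained-embedding path in SapDelphi. A small random
# matrix stands in for real BERT-derived disease embeddings and the temporary
# file name is purely illustrative; the point is how the table is assembled
# (padding row, technical tokens, pretrained rows) and projected from
# vocab_dim to n_embd when the two differ.
def _demo_sap_delphi_pretrained() -> "SapDelphi":
    import os
    import tempfile
    n_disease, n_tech_tokens, vocab_dim = 10, 3, 64
    fake_weights = np.random.randn(n_disease, vocab_dim).astype(np.float32)
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "disease_embeddings.npy")
        np.save(path, fake_weights)
        model = SapDelphi(
            n_disease=n_disease, n_tech_tokens=n_tech_tokens, n_embd=32,
            n_head=4, n_layer=2, n_cont=3, n_cate=2, cate_dims=[4, 7],
            pretrained_weights_path=path, freeze_embeddings=True)
    # Rows: 1 padding + (n_tech_tokens - 1) technical + n_disease pretrained.
    assert model.token_embedding.num_embeddings == model.vocab_size
    assert not isinstance(model.emb_proj, nn.Identity)  # 64 != 32, so projected
    return model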