# DeepHealth/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block
from typing import Optional, List
import numpy as np
class TabularEncoder(nn.Module):
"""
Encoder for tabular features (continuous and categorical).
Args:
n_embd (int): Embedding dimension.
n_cont (int): Number of continuous features.
n_cate (int): Number of categorical features.
cate_dims (List[int]): List of dimensions for each categorical feature.
n_bins (int): Number of soft bins for continuous AutoDiscretization.
"""
def __init__(
self,
n_embd: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
n_bins: int = 16,
):
super().__init__()
self.n_embd = n_embd
self.n_cont = n_cont
self.n_cate = n_cate
# Continuous feature path
# - BatchNorm on raw (NaN-filled) continuous values
# - AutoDiscretization (soft binning) per feature
if n_cont > 0:
self.cont_bn = nn.BatchNorm1d(n_cont)
self.cont_discretizer = AutoDiscretization(
n_features=n_cont,
n_bins=n_bins,
n_embd=n_embd,
)
else:
self.cont_bn = None
self.cont_discretizer = None
if n_cate > 0:
assert len(cate_dims) == n_cate, \
"Length of cate_dims must match n_cate"
self.cate_embds = nn.ModuleList([
nn.Embedding(dim, n_embd) for dim in cate_dims
])
self.cate_mask_embds = nn.ModuleList([
nn.Embedding(2, n_embd) for _ in range(n_cate)
])
else:
self.cate_embds = None
self.cate_mask_embds = None
self.cont_mask_proj = (
nn.Linear(n_cont, n_embd) if n_cont > 0 else None
)
# Fuse aggregated value + aggregated mask via MLP
self.fuse_mlp = nn.Sequential(
nn.Linear(2 * n_embd, 2 * n_embd),
nn.GELU(),
nn.Linear(2 * n_embd, n_embd),
)
self.apply(self._init_weights)
self.out_ln = nn.LayerNorm(n_embd)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
cont_features: Optional[torch.Tensor],
cate_features: Optional[torch.Tensor],
) -> torch.Tensor:
"""Encode tabular features into a per-timestep embedding.
Inputs:
cont_features: (B, L, n_cont) float tensor; NaN indicates missing.
cate_features: (B, L, n_cate) long/int tensor; 0 indicates missing/pad.
Returns:
(B, L, n_embd) encoded embedding.
"""
if self.n_cont == 0 and self.n_cate == 0:
# infer (B, L) from whichever input is not None
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError(
"TabularEncoder received no features but cannot infer (B, L)."
)
return torch.zeros(B, L, self.n_embd, device=device)
value_parts: List[torch.Tensor] = []
mask_parts: List[torch.Tensor] = []
if self.n_cont > 0 and cont_features is not None:
if cont_features.dim() != 3:
raise ValueError(
"cont_features must be 3D tensor (B, L, n_cont)")
B, L, D_cont = cont_features.shape
if D_cont != self.n_cont:
raise ValueError(
f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")
# Missingness mask: 1 for valid, 0 for missing
cont_mask = (~torch.isnan(cont_features)).float() # (B, L, n_cont)
# BatchNorm cannot handle NaNs; fill missing with 0 before BN.
cont_filled = torch.nan_to_num(
cont_features, nan=0.0) # (B, L, n_cont)
# Apply BN over the feature dimension: (B, L, C) -> (B*L, C) -> (B, L, C)
cont_flat = cont_filled.reshape(-1, self.n_cont)
cont_norm_flat = self.cont_bn(cont_flat) # (B*L, n_cont)
# Soft-binning per feature: (B*L, n_cont) -> (B*L, n_cont, n_embd)
cont_emb_flat = self.cont_discretizer(cont_norm_flat)
cont_emb = cont_emb_flat.view(B, L, self.n_cont, self.n_embd)
# Mask-out missing continuous features before aggregating across features
# (B, L, n_cont, n_embd)
cont_emb = cont_emb * cont_mask.unsqueeze(-1)
denom = cont_mask.sum(
dim=-1, keepdim=True).clamp(min=1.0) # (B, L, 1)
h_cont_value = cont_emb.sum(dim=2) / denom # (B, L, n_embd)
value_parts.append(h_cont_value)
# Explicit continuous mask embedding (fused later)
if self.cont_mask_proj is not None:
h_cont_mask = self.cont_mask_proj(cont_mask) # (B, L, n_embd)
mask_parts.append(h_cont_mask)
if self.n_cate > 0 and cate_features is not None:
if cate_features.dim() != 3:
raise ValueError(
"cate_features must be 3D tensor (B, L, n_cate)")
B, L, D_cate = cate_features.shape
if D_cate != self.n_cate:
raise ValueError(
f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}")
for i in range(self.n_cate):
cate_feat = cate_features[:, :, i]
cate_embd = self.cate_embds[i]
cate_mask_embd = self.cate_mask_embds[i]
cate_value = cate_embd(
torch.clamp(cate_feat, min=0))
cate_mask = (cate_feat > 0).long()
cate_mask_value = cate_mask_embd(cate_mask)
value_parts.append(cate_value)
mask_parts.append(cate_mask_value)
if not value_parts:
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError("No features provided to TabularEncoder.")
return torch.zeros(B, L, self.n_embd, device=device)
# Aggregate across feature groups (continuous block counts as one part;
# each categorical feature counts as one part).
h_value = torch.stack(value_parts, dim=0).mean(dim=0) # (B, L, n_embd)
if mask_parts:
h_mask = torch.stack(mask_parts, dim=0).mean(
dim=0) # (B, L, n_embd)
else:
h_mask = torch.zeros_like(h_value)
# Fuse by concatenation + MLP projection
h_fused = torch.cat([h_value, h_mask], dim=-1) # (B, L, 2*n_embd)
h_out = self.fuse_mlp(h_fused) # (B, L, n_embd)
h_out = self.out_ln(h_out)
return h_out
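# Illustrative usage sketch for TabularEncoder (the concrete sizes here are
# arbitrary, not taken from any training configuration): continuous features
# mark missing values with NaN, categorical features use index 0 as missing/pad.
#
#   enc = TabularEncoder(n_embd=64, n_cont=3, n_cate=2, cate_dims=[5, 7])
#   cont = torch.randn(2, 4, 3)             # (B=2, L=4, n_cont=3)
#   cont[0, 0, 1] = float("nan")            # one missing continuous value
#   cate = torch.randint(0, 5, (2, 4, 2))   # (B, L, n_cate); 0 = missing/pad
#   h = enc(cont, cate)                     # -> (2, 4, 64)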
class AutoDiscretization(nn.Module):
"""AutoDiscretization / soft-binning for continuous tabular scalars.
For each feature scalar $x$, compute a soft assignment over `n_bins`:
p = softmax(x * w + b)
Then compute the embedding as a weighted sum of learnable bin embeddings:
emb = sum_k p_k * E_k
Shapes:
Input: (N, n_features)
Output: (N, n_features, n_embd)
"""
def __init__(self, n_features: int, n_bins: int, n_embd: int):
super().__init__()
if n_features <= 0:
raise ValueError("n_features must be > 0")
if n_bins <= 1:
raise ValueError("n_bins must be > 1")
if n_embd <= 0:
raise ValueError("n_embd must be > 0")
self.n_features = n_features
self.n_bins = n_bins
self.n_embd = n_embd
# Per-feature, per-bin affine transform to produce logits
self.weight = nn.Parameter(torch.empty(n_features, n_bins))
self.bias = nn.Parameter(torch.empty(n_features, n_bins))
# Learnable embeddings for each (feature, bin)
self.bin_emb = nn.Parameter(torch.empty(n_features, n_bins, n_embd))
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, mean=0.0, std=0.02)
nn.init.zeros_(self.bias)
nn.init.normal_(self.bin_emb, mean=0.0, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dim() != 2:
raise ValueError(
"AutoDiscretization expects input of shape (N, n_features)")
if x.size(1) != self.n_features:
raise ValueError(
f"Expected x.size(1) == {self.n_features}, got {x.size(1)}"
)
# logits: (N, n_features, n_bins)
logits = x.unsqueeze(-1) * self.weight.unsqueeze(0) + \
self.bias.unsqueeze(0)
probs = torch.softmax(logits, dim=-1)
# Weighted sum over bins -> (N, n_features, n_embd)
emb = (probs.unsqueeze(-1) * self.bin_emb.unsqueeze(0)).sum(dim=-2)
return emb
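# Shape sketch for AutoDiscretization (arbitrary sizes, for illustration only):
# each scalar feature is mapped to softmax logits over n_bins, and the output is
# the probability-weighted mix of that feature's learnable bin embeddings.
#
#   disc = AutoDiscretization(n_features=4, n_bins=16, n_embd=32)
#   x = torch.randn(8, 4)        # (N, n_features)
#   emb = disc(x)                # (8, 4, 32) == (N, n_features, n_embd)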
def _build_time_padding_mask(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
) -> torch.Tensor:
B, L = event_seq.shape
device = event_seq.device
key_is_valid = (event_seq != 0)
cache = getattr(_build_time_padding_mask, "_cache", None)
if cache is None:
cache = {}
setattr(_build_time_padding_mask, "_cache", cache)
cache_key = (str(device), L)
causal = cache.get(cache_key)
if causal is None:
causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
cache[cache_key] = causal
causal = causal.unsqueeze(0).unsqueeze(0) # (1,1,L,L)
key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2) # (B,1,1,L)
attn_mask = causal | key_pad # (B,1,L,L)
return attn_mask
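# Worked example of the mask semantics above (boolean True means "do not
# attend"); time_seq is accepted for interface compatibility but is not used in
# the mask construction:
#
#   m = _build_time_padding_mask(torch.tensor([[5, 0]]), torch.zeros(1, 2))
#   # m.shape == (1, 1, 2, 2)
#   # m[0, 0] == [[False, True],   # query 0: attends to itself; key 1 is padding
#   #             [False, True]]   # query 1: causal + padding leave only key 0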
class DelphiFork(nn.Module):
"""
DelphiFork model for time-to-event prediction.
Args:
n_disease (int): Number of disease tokens.
n_tech_tokens (int): Number of technical tokens.
n_embd (int): Embedding dimension.
n_head (int): Number of attention heads.
n_layer (int): Number of transformer layers.
n_cont (int): Number of continuous features.
n_cate (int): Number of categorical features.
cate_dims (List[int]): List of dimensions for each categorical feature.
age_encoder_type (str): Type of age encoder ("sinusoidal" or "mlp").
pdrop (float): Dropout probability.
token_pdrop (float): Token dropout probability.
n_dim (int): Dimension of theta parameters.
"""
def __init__(
self,
n_disease: int,
n_tech_tokens: int,
n_embd: int,
n_head: int,
n_layer: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
age_encoder_type: str = "sinusoidal",
pdrop: float = 0.0,
token_pdrop: float = 0.0,
n_dim: int = 1,
):
super().__init__()
self.vocab_size = n_disease + n_tech_tokens
self.n_tech_tokens = n_tech_tokens
self.n_disease = n_disease
self.n_embd = n_embd
self.n_head = n_head
self.n_dim = n_dim
self.token_embedding = nn.Embedding(
self.vocab_size, n_embd, padding_idx=0)
if age_encoder_type == "sinusoidal":
self.age_encoder = AgeSinusoidalEncoder(n_embd)
elif age_encoder_type == "mlp":
self.age_encoder = AgeMLPEncoder(n_embd)
else:
raise ValueError(
f"Unsupported age_encoder_type: {age_encoder_type}")
self.sex_encoder = nn.Embedding(2, n_embd)
self.tabular_encoder = TabularEncoder(
n_embd, n_cont, n_cate, cate_dims)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.token_dropout = nn.Dropout(token_pdrop)
# Head layers
self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
def forward(
self,
event_seq: torch.Tensor, # (B, L)
time_seq: torch.Tensor, # (B, L)
sex: torch.Tensor, # (B,)
cont_seq: torch.Tensor, # (B, Lc, n_cont)
cate_seq: torch.Tensor, # (B, Lc, n_cate)
b_prev: Optional[torch.Tensor] = None, # (M,)
t_prev: Optional[torch.Tensor] = None, # (M,)
) -> torch.Tensor:
token_embds = self.token_embedding(event_seq) # (B, L, D)
age_embds = self.age_encoder(time_seq) # (B, L, D)
sex_embds = self.sex_encoder(sex.unsqueeze(-1)) # (B, 1, D)
table_embds = self.tabular_encoder(cont_seq, cate_seq) # (B, Lc, D)
mask = (event_seq == 1) # (B, L)
B, L = event_seq.shape
Lc = table_embds.size(1)
D = table_embds.size(2)
# occ[b, t]: 0-based index of the occurrence (among positions where event token == 1)
# up to and including t; values at non-mask positions are meaningless and are
# forced to 0 below.  Shape (B, L), values 0, 1, 2, ...
occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
# Clamp indices beyond Lc - 1 and force non-mask positions to 0 to avoid
# meaningless or out-of-range gathers.
tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
tab_idx = tab_idx.masked_fill(~mask, 0) # (B, L)
# Gather along dim=1 from (B, Lc, D) the tabular embedding to inject at each
# position -> (B, L, D)
tab_inject = table_embds.gather(
dim=1,
index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
)
# Replace the token embedding only where mask == True
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
x = final_embds + age_embds + sex_embds # (B, L, D)
x = self.token_dropout(x)
attn_mask = _build_time_padding_mask(
event_seq, time_seq)
for block in self.blocks:
x = block(x, attn_mask=attn_mask)
x = self.ln_f(x)
if b_prev is not None and t_prev is not None:
M = b_prev.numel()
c = x[b_prev, t_prev] # (M, D)
theta = self.theta_proj(c) # (M, N_disease * n_dim)
theta = theta.view(M, self.n_disease, self.n_dim)
return theta
else:
return x
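# Illustrative call pattern for DelphiFork (tensor values are hypothetical):
# without (b_prev, t_prev) the forward pass returns per-position hidden states
# of shape (B, L, n_embd); with them it returns the projected per-disease
# parameters theta of shape (M, n_disease, n_dim) at the selected positions.
#
#   h = model(event_seq, time_seq, sex, cont_seq, cate_seq)       # (B, L, D)
#   b_prev = torch.tensor([0, 0, 1])                              # sample indices, M = 3
#   t_prev = torch.tensor([2, 3, 1])                              # position indices
#   theta = model(event_seq, time_seq, sex, cont_seq, cate_seq,
#                 b_prev=b_prev, t_prev=t_prev)                   # (3, n_disease, n_dim)
#
# SapDelphi below follows the same interface; it additionally supports loading
# pretrained token embeddings, projected to n_embd via emb_proj when the
# embedding dimensions differ.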
class SapDelphi(nn.Module):
def __init__(
self,
n_disease: int,
n_tech_tokens: int,
n_embd: int,
n_head: int,
n_layer: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
age_encoder_type: str = "sinusoidal",
pdrop: float = 0.0,
token_pdrop: float = 0.0,
n_dim: int = 1,
pretrained_weights_path: Optional[str] = None,  # added: path to pretrained token-embedding weights
freeze_embeddings: bool = False,  # added: defaults to False, i.e. the embeddings are fine-tuned
):
super().__init__()
self.vocab_size = n_disease + n_tech_tokens
self.n_tech_tokens = n_tech_tokens
self.n_disease = n_disease
self.n_embd = n_embd
self.n_head = n_head
self.n_dim = n_dim
if pretrained_weights_path is not None:
print(
f"Loading pretrained embeddings from {pretrained_weights_path}...")
bert_weights = np.load(pretrained_weights_path)
bert_weights = torch.tensor(bert_weights, dtype=torch.float32)
vocab_dim = bert_weights.shape[1]  # typically 768
pad_emb = torch.zeros(1, vocab_dim)
tech_embs = nn.init.normal_(torch.empty(
n_tech_tokens-1, vocab_dim))
full_emb_weights = torch.cat(
[pad_emb, tech_embs, bert_weights], dim=0)
# Keep index 0 as the padding row, consistent with the randomly initialized branch.
self.token_embedding = nn.Embedding.from_pretrained(
full_emb_weights, freeze=freeze_embeddings, padding_idx=0)
print("Pretrained embeddings loaded.")
if vocab_dim != n_embd:
self.emb_proj = nn.Sequential(
nn.Linear(vocab_dim, n_embd, bias=False),
nn.LayerNorm(n_embd),
nn.Dropout(pdrop),
)
else:
self.emb_proj = nn.Identity()
else:
self.token_embedding = nn.Embedding(
self.vocab_size, n_embd, padding_idx=0)
self.emb_proj = nn.Identity()
if age_encoder_type == "sinusoidal":
self.age_encoder = AgeSinusoidalEncoder(n_embd)
elif age_encoder_type == "mlp":
self.age_encoder = AgeMLPEncoder(n_embd)
else:
raise ValueError(
f"Unsupported age_encoder_type: {age_encoder_type}")
self.sex_encoder = nn.Embedding(2, n_embd)
self.tabular_encoder = TabularEncoder(
n_embd, n_cont, n_cate, cate_dims)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.token_dropout = nn.Dropout(token_pdrop)
# Head layers
self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
def forward(
self,
event_seq: torch.Tensor, # (B, L)
time_seq: torch.Tensor, # (B, L)
sex: torch.Tensor, # (B,)
cont_seq: torch.Tensor, # (B, Lc, n_cont)
cate_seq: torch.Tensor, # (B, Lc, n_cate)
b_prev: Optional[torch.Tensor] = None, # (M,)
t_prev: Optional[torch.Tensor] = None, # (M,)
) -> torch.Tensor:
token_embds = self.token_embedding(event_seq) # (B, L, Vocab_dim)
token_embds = self.emb_proj(token_embds) # (B, L, D)
age_embds = self.age_encoder(time_seq) # (B, L, D)
sex_embds = self.sex_encoder(sex.unsqueeze(-1)) # (B, 1, D)
table_embds = self.tabular_encoder(cont_seq, cate_seq) # (B, Lc, D)
mask = (event_seq == 1) # (B, L)
B, L = event_seq.shape
Lc = table_embds.size(1)
D = table_embds.size(2)
# occ[b, t]: 0-based index of the occurrence (among positions where event token == 1)
# up to and including t; values at non-mask positions are meaningless and are
# forced to 0 below.  Shape (B, L), values 0, 1, 2, ...
occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
# Clamp indices beyond Lc - 1 and force non-mask positions to 0 to avoid
# meaningless or out-of-range gathers.
tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
tab_idx = tab_idx.masked_fill(~mask, 0) # (B, L)
# Gather along dim=1 from (B, Lc, D) the tabular embedding to inject at each
# position -> (B, L, D)
tab_inject = table_embds.gather(
dim=1,
index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
)
# Replace the token embedding only where mask == True
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
x = final_embds + age_embds + sex_embds # (B, L, D)
x = self.token_dropout(x)
attn_mask = _build_time_padding_mask(
event_seq, time_seq)
for block in self.blocks:
x = block(x, attn_mask=attn_mask)
x = self.ln_f(x)
if b_prev is not None and t_prev is not None:
M = b_prev.numel()
c = x[b_prev, t_prev] # (M, D)
theta = self.theta_proj(c) # (M, N_disease * n_dim)
theta = theta.view(M, self.n_disease, self.n_dim)
return theta
else:
return x
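if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the original training pipeline.
    # It assumes backbones.Block and the age encoders behave as they are used
    # above (block(x, attn_mask=...) and age_encoder(time_seq) -> (B, L, D)).
    torch.manual_seed(0)
    model = DelphiFork(
        n_disease=10, n_tech_tokens=3, n_embd=32, n_head=4, n_layer=2,
        n_cont=3, n_cate=2, cate_dims=[5, 7],
    )
    B, L, Lc = 2, 6, 1
    event_seq = torch.zeros(B, L, dtype=torch.long)
    event_seq[:, 0] = 1                                   # token 1 marks the tabular slot
    event_seq[:, 1:4] = torch.randint(3, 13, (B, 3))      # disease tokens
    time_seq = torch.cumsum(torch.rand(B, L), dim=1)      # increasing ages
    sex = torch.tensor([0, 1])
    cont_seq = torch.randn(B, Lc, 3)
    cate_seq = torch.randint(0, 5, (B, Lc, 2))
    out = model(event_seq, time_seq, sex, cont_seq, cate_seq)
    print(out.shape)                                      # expected: torch.Size([2, 6, 32])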