DeepHealth/model.py
Jiarui Li 6984b254b3 Add loss functions and model architecture for time-to-event prediction
- Implemented ExponentialNLLLoss and WeibullNLLLoss in losses.py for negative log-likelihood calculations.
- Developed TabularEncoder class in model.py for encoding tabular features.
- Created DelphiFork and SapDelphi classes in model.py for time-to-event prediction using transformer architecture.
- Added data preparation scripts in prepare_data.R and prepare_data.py for processing UK Biobank data, including handling field mappings and event data extraction.
2026-01-07 21:32:00 +08:00


import torch
import torch.nn as nn
import torch.nn.functional as F
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block
from typing import Optional, List
import numpy as np
class TabularEncoder(nn.Module):
"""
Encoder for tabular features (continuous and categorical).
Args:
n_embd (int): Embedding dimension.
n_cont (int): Number of continuous features.
n_cate (int): Number of categorical features.
cate_dims (List[int]): List of dimensions for each categorical feature.
"""
def __init__(
self,
n_embd: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
):
super().__init__()
self.n_embd = n_embd
self.n_cont = n_cont
self.n_cate = n_cate
if n_cont > 0:
hidden = 2 * n_embd
self.cont_mlp = nn.Sequential(
nn.Linear(2 * n_cont, hidden),
nn.GELU(),
nn.Linear(hidden, n_embd),
)
else:
self.cont_mlp = None
if n_cate > 0:
assert len(cate_dims) == n_cate, \
"Length of cate_dims must match n_cate"
self.cate_embds = nn.ModuleList([
nn.Embedding(dim, n_embd) for dim in cate_dims
])
self.cate_mask_embds = nn.ModuleList([
nn.Embedding(2, n_embd) for _ in range(n_cate)
])
else:
self.cate_embds = None
self.cate_mask_embds = None
self.cont_mask_proj = (
nn.Linear(n_cont, n_embd) if n_cont > 0 else None
)
self.film = nn.Sequential(
nn.Linear(n_embd, 2 * n_embd),
nn.GELU(),
nn.Linear(2 * n_embd, 2 * n_embd),
)
self.apply(self._init_weights)
self.out_ln = nn.LayerNorm(n_embd)
# Zero-init the last layer of FiLM to start with identity modulation
with torch.no_grad():
last_linear = self.film[-1]
last_linear.weight.zero_()
last_linear.bias.zero_()
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
cont_features: Optional[torch.Tensor],
cate_features: Optional[torch.Tensor],
) -> torch.Tensor:
if self.n_cont == 0 and self.n_cate == 0:
# infer (B, L) from whichever input is not None
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError(
"TabularEncoder received no features but cannot infer (B, L)."
)
return torch.zeros(B, L, self.n_embd, device=device)
value_parts: List[torch.Tensor] = []
mask_parts: List[torch.Tensor] = []
if self.n_cont > 0 and cont_features is not None:
if cont_features.dim() != 3:
raise ValueError(
"cont_features must be 3D tensor (B, L, n_cont)")
B, L, D_cont = cont_features.shape
if D_cont != self.n_cont:
raise ValueError(
f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")
cont_mask = (~torch.isnan(cont_features)).float()
cont_filled = torch.nan_to_num(cont_features, nan=0.0)
cont_joint = torch.cat([cont_filled, cont_mask], dim=-1)
h_cont_value = self.cont_mlp(cont_joint)
value_parts.append(h_cont_value)
if self.cont_mask_proj is not None:
h_cont_mask = self.cont_mask_proj(cont_mask)
mask_parts.append(h_cont_mask)
if self.n_cate > 0 and cate_features is not None:
if cate_features.dim() != 3:
raise ValueError(
"cate_features must be 3D tensor (B, L, n_cate)")
B, L, D_cate = cate_features.shape
if D_cate != self.n_cate:
raise ValueError(
f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}")
for i in range(self.n_cate):
cate_feat = cate_features[:, :, i]
cate_embd = self.cate_embds[i]
cate_mask_embd = self.cate_mask_embds[i]
cate_value = cate_embd(
torch.clamp(cate_feat, min=0))
cate_mask = (cate_feat > 0).long()
cate_mask_value = cate_mask_embd(cate_mask)
value_parts.append(cate_value)
mask_parts.append(cate_mask_value)
if not value_parts:
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError("No features provided to TabularEncoder.")
return torch.zeros(B, L, self.n_embd, device=device)
h_value = torch.stack(value_parts, dim=0).mean(dim=0)
h_mask = torch.stack(mask_parts, dim=0).mean(dim=0)
h_mask_flat = h_mask.view(-1, self.n_embd)
film_params = self.film(h_mask_flat)
gamma_delta, beta = film_params.chunk(2, dim=-1)
gamma = 1.0 + gamma_delta
h_value_flat = h_value.view(-1, self.n_embd)
h_out = gamma * h_value_flat + beta
h_out = h_out.view(B, L, self.n_embd)
h_out = self.out_ln(h_out)
return h_out
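# Illustrative usage sketch (not part of the original module): TabularEncoder takes
# (B, L, n_cont) continuous features with NaN marking missing values and (B, L, n_cate)
# integer-coded categorical features where codes <= 0 are treated as missing, and
# returns (B, L, n_embd) embeddings. All sizes and values below are made up.
#
#   enc = TabularEncoder(n_embd=64, n_cont=3, n_cate=2, cate_dims=[10, 4])
#   cont = torch.randn(2, 5, 3)
#   cont[0, 0, 1] = float("nan")               # missing continuous value
#   cate = torch.randint(1, 4, (2, 5, 2))      # categorical codes; 0 would mean missing
#   h = enc(cont, cate)                        # -> (2, 5, 64)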
def _build_time_padding_mask(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
) -> torch.Tensor:
t_i = time_seq.unsqueeze(-1)
t_j = time_seq.unsqueeze(1)
time_mask = (t_j <= t_i) # allow attending only to past or current
key_is_valid = (event_seq != 0) # disallow padded positions
allowed = time_mask & key_is_valid.unsqueeze(1)
attn_mask = ~allowed # True means mask for scaled_dot_product_attention
return attn_mask.unsqueeze(1) # (B, 1, L, L)
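# Illustrative sketch (not part of the original module) of the attention mask produced
# above: with one sequence of events [2, 3, 0] (0 = padding) at times [40.0, 52.5, 0.0],
# position 0 may attend only to itself, position 1 to positions 0 and 1, and the padded
# key at position 2 is disallowed for every query. The result has shape (B, 1, L, L),
# with True marking query/key pairs that must not attend.
#
#   ev = torch.tensor([[2, 3, 0]])
#   tm = torch.tensor([[40.0, 52.5, 0.0]])
#   m = _build_time_padding_mask(ev, tm)       # -> (1, 1, 3, 3), dtype torch.bool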
class DelphiFork(nn.Module):
"""
DelphiFork model for time-to-event prediction.
Args:
n_disease (int): Number of disease tokens.
n_tech_tokens (int): Number of technical tokens.
n_embd (int): Embedding dimension.
n_head (int): Number of attention heads.
n_layer (int): Number of transformer layers.
n_cont (int): Number of continuous features.
n_cate (int): Number of categorical features.
cate_dims (List[int]): List of dimensions for each categorical feature.
age_encoder_type (str): Type of age encoder ("sinusoidal" or "mlp").
pdrop (float): Dropout probability.
token_pdrop (float): Token dropout probability.
n_dim (int): Dimension of theta parameters.
"""
def __init__(
self,
n_disease: int,
n_tech_tokens: int,
n_embd: int,
n_head: int,
n_layer: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
age_encoder_type: str = "sinusoidal",
pdrop: float = 0.0,
token_pdrop: float = 0.0,
n_dim: int = 1,
):
super().__init__()
self.vocab_size = n_disease + n_tech_tokens
self.n_tech_tokens = n_tech_tokens
self.n_disease = n_disease
self.n_embd = n_embd
self.n_head = n_head
self.n_dim = n_dim
self.token_embedding = nn.Embedding(
self.vocab_size, n_embd, padding_idx=0)
if age_encoder_type == "sinusoidal":
self.age_encoder = AgeSinusoidalEncoder(n_embd)
elif age_encoder_type == "mlp":
self.age_encoder = AgeMLPEncoder(n_embd)
else:
raise ValueError(
f"Unsupported age_encoder_type: {age_encoder_type}")
self.sex_encoder = nn.Embedding(2, n_embd)
self.tabular_encoder = TabularEncoder(
n_embd, n_cont, n_cate, cate_dims)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.token_dropout = nn.Dropout(token_pdrop)
# Head layers
self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
def forward(
self,
event_seq: torch.Tensor, # (B, L)
time_seq: torch.Tensor, # (B, L)
sex: torch.Tensor, # (B,)
cont_seq: torch.Tensor, # (B, Lc, n_cont)
cate_seq: torch.Tensor, # (B, Lc, n_cate)
b_prev: Optional[torch.Tensor] = None, # (M,)
t_prev: Optional[torch.Tensor] = None, # (M,)
) -> torch.Tensor:
token_embds = self.token_embedding(event_seq) # (B, L, D)
age_embds = self.age_encoder(time_seq) # (B, L, D)
sex_embds = self.sex_encoder(sex.unsqueeze(-1)) # (B, 1, D)
table_embds = self.tabular_encoder(cont_seq, cate_seq) # (B, Lc, D)
mask = (event_seq == 1) # (B, L)
B, L = event_seq.shape
Lc = table_embds.size(1)
D = table_embds.size(2)
        # occ[b, t] = 0-based occurrence index of the masked positions in sequence b.
        # Values at non-mask positions are meaningless and are zeroed out below.
        # Shape (B, L), running values 0, 1, 2, ...
occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
        # Clamp indices beyond Lc - 1 and force non-mask positions to 0 to avoid meaningless gathers.
tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
tab_idx = tab_idx.masked_fill(~mask, 0) # (B, L)
        # Gather along dim=1 from (B, Lc, D) the tabular embedding to inject at each position -> (B, L, D).
tab_inject = table_embds.gather(
dim=1,
index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
)
        # Replace token embeddings only at positions where mask == True.
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
x = final_embds + age_embds + sex_embds # (B, L, D)
x = self.token_dropout(x)
attn_mask = _build_time_padding_mask(
event_seq, time_seq)
for block in self.blocks:
x = block(x, attn_mask=attn_mask)
x = self.ln_f(x)
if b_prev is not None and t_prev is not None:
M = b_prev.numel()
c = x[b_prev, t_prev] # (M, D)
theta = self.theta_proj(c) # (M, N_disease * n_dim)
theta = theta.view(M, self.n_disease, self.n_dim)
return theta
else:
return x
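# Illustrative usage sketch (not part of the original module): one forward pass through
# DelphiFork. All sizes and token values below are made up. Without b_prev/t_prev the
# model returns the full hidden states (B, L, n_embd); with them it returns theta
# parameters of shape (M, n_disease, n_dim) for the selected (batch, position) pairs.
#
#   model = DelphiFork(n_disease=100, n_tech_tokens=3, n_embd=64, n_head=4,
#                      n_layer=2, n_cont=3, n_cate=2, cate_dims=[10, 4])
#   event_seq = torch.randint(2, 103, (2, 8))   # token ids; 0 = padding, 1 = tabular slot
#   time_seq, _ = torch.sort(torch.rand(2, 8) * 70.0, dim=1)   # ages, non-decreasing
#   sex = torch.tensor([0, 1])
#   cont_seq = torch.randn(2, 1, 3)
#   cate_seq = torch.randint(1, 4, (2, 1, 2))
#   h = model(event_seq, time_seq, sex, cont_seq, cate_seq)    # (2, 8, 64)
#   theta = model(event_seq, time_seq, sex, cont_seq, cate_seq,
#                 b_prev=torch.tensor([0, 1]), t_prev=torch.tensor([3, 5]))   # (2, 100, 1)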
class SapDelphi(nn.Module):
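    """
    SapDelphi: same transformer architecture as DelphiFork, but the token embedding
    table can be initialised from a pretrained embedding matrix (e.g. BERT-derived)
    loaded from `pretrained_weights_path` and optionally frozen via `freeze_embeddings`.
    If the pretrained dimension differs from n_embd, a linear projection (`emb_proj`)
    maps the embeddings into the model dimension.
    """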
def __init__(
self,
n_disease: int,
n_tech_tokens: int,
n_embd: int,
n_head: int,
n_layer: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
age_encoder_type: str = "sinusoidal",
pdrop: float = 0.0,
token_pdrop: float = 0.0,
n_dim: int = 1,
        pretrained_weights_path: Optional[str] = None,  # new parameter: path to a pretrained token-embedding matrix
        freeze_embeddings: bool = False,  # new parameter; False (default) means the embeddings are fine-tuned
):
super().__init__()
self.vocab_size = n_disease + n_tech_tokens
self.n_tech_tokens = n_tech_tokens
self.n_disease = n_disease
self.n_embd = n_embd
self.n_head = n_head
self.n_dim = n_dim
if pretrained_weights_path is not None:
print(
f"Loading pretrained embeddings from {pretrained_weights_path}...")
bert_weights = np.load(pretrained_weights_path)
bert_weights = torch.tensor(bert_weights, dtype=torch.float32)
            vocab_dim = bert_weights.shape[1]  # typically 768
pad_emb = torch.zeros(1, vocab_dim)
tech_embs = nn.init.normal_(torch.empty(
n_tech_tokens-1, vocab_dim))
full_emb_weights = torch.cat(
[pad_emb, tech_embs, bert_weights], dim=0)
self.token_embedding = nn.Embedding.from_pretrained(
full_emb_weights, freeze=freeze_embeddings)
print("Pretrained embeddings loaded.")
if vocab_dim != n_embd:
self.emb_proj = nn.Sequential(
nn.Linear(vocab_dim, n_embd, bias=False),
nn.LayerNorm(n_embd),
nn.Dropout(pdrop),
)
else:
self.emb_proj = nn.Identity()
else:
self.token_embedding = nn.Embedding(
self.vocab_size, n_embd, padding_idx=0)
self.emb_proj = nn.Identity()
if age_encoder_type == "sinusoidal":
self.age_encoder = AgeSinusoidalEncoder(n_embd)
elif age_encoder_type == "mlp":
self.age_encoder = AgeMLPEncoder(n_embd)
else:
raise ValueError(
f"Unsupported age_encoder_type: {age_encoder_type}")
self.sex_encoder = nn.Embedding(2, n_embd)
self.tabular_encoder = TabularEncoder(
n_embd, n_cont, n_cate, cate_dims)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.token_dropout = nn.Dropout(token_pdrop)
# Head layers
self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
def forward(
self,
event_seq: torch.Tensor, # (B, L)
time_seq: torch.Tensor, # (B, L)
sex: torch.Tensor, # (B,)
cont_seq: torch.Tensor, # (B, Lc, n_cont)
cate_seq: torch.Tensor, # (B, Lc, n_cate)
b_prev: Optional[torch.Tensor] = None, # (M,)
t_prev: Optional[torch.Tensor] = None, # (M,)
) -> torch.Tensor:
token_embds = self.token_embedding(event_seq) # (B, L, Vocab_dim)
token_embds = self.emb_proj(token_embds) # (B, L, D)
age_embds = self.age_encoder(time_seq) # (B, L, D)
sex_embds = self.sex_encoder(sex.unsqueeze(-1)) # (B, 1, D)
table_embds = self.tabular_encoder(cont_seq, cate_seq) # (B, Lc, D)
mask = (event_seq == 1) # (B, L)
B, L = event_seq.shape
Lc = table_embds.size(1)
D = table_embds.size(2)
        # occ[b, t] = 0-based occurrence index of the masked positions in sequence b.
        # Values at non-mask positions are meaningless and are zeroed out below.
        # Shape (B, L), running values 0, 1, 2, ...
occ = torch.cumsum(mask.to(torch.long), dim=1) - 1
        # Clamp indices beyond Lc - 1 and force non-mask positions to 0 to avoid meaningless gathers.
tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
tab_idx = tab_idx.masked_fill(~mask, 0) # (B, L)
        # Gather along dim=1 from (B, Lc, D) the tabular embedding to inject at each position -> (B, L, D).
tab_inject = table_embds.gather(
dim=1,
index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
)
        # Replace token embeddings only at positions where mask == True.
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
x = final_embds + age_embds + sex_embds # (B, L, D)
x = self.token_dropout(x)
attn_mask = _build_time_padding_mask(
event_seq, time_seq)
for block in self.blocks:
x = block(x, attn_mask=attn_mask)
x = self.ln_f(x)
if b_prev is not None and t_prev is not None:
M = b_prev.numel()
c = x[b_prev, t_prev] # (M, D)
theta = self.theta_proj(c) # (M, N_disease * n_dim)
theta = theta.view(M, self.n_disease, self.n_dim)
return theta
else:
return x
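# Illustrative sketch (not part of the original module) of the pretrained-embedding
# contract implied above: np.load(pretrained_weights_path) should yield an array of
# shape (n_disease, vocab_dim); a zero row for the padding token and n_tech_tokens - 1
# randomly initialised technical-token rows are prepended, giving a table with
# n_tech_tokens + n_disease rows. The file name below is a placeholder.
#
#   # np.save("disease_text_embeddings.npy", emb)   # e.g. emb of shape (100, 768)
#   model = SapDelphi(n_disease=100, n_tech_tokens=3, n_embd=64, n_head=4, n_layer=2,
#                     n_cont=3, n_cate=2, cate_dims=[10, 4],
#                     pretrained_weights_path="disease_text_embeddings.npy",
#                     freeze_embeddings=True)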