Refactor model and training scripts: remove unused imports and add FactorizedHead class for improved modularity
model.py (122 changed lines)
@@ -5,12 +5,6 @@ from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
import torch.nn.functional as F
import torch.nn as nn
import torch
import sys
from pathlib import Path

_PROJECT_ROOT = Path(__file__).resolve().parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))


class TabularEncoder(nn.Module):
@@ -265,28 +259,77 @@ class AutoDiscretization(nn.Module):
        return emb


class FactorizedHead(nn.Module):
    def __init__(
        self,
        n_embd: int,
        n_disease: int,
        n_dim: int,
        rank: int = 16,
    ):
        super().__init__()
        self.n_embd = n_embd
        self.n_disease = n_disease
        self.n_dim = n_dim
        self.rank = rank

        self.disease_base_proj = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd, n_dim),
        )
        self.context_mod_proj = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd, rank, bias=False),
        )
        self.disease_mod_proj = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd, rank * n_dim, bias=False),
        )
        self.delta_scale = nn.Parameter(torch.tensor(1e-3))

        self._init_weights()

    def _init_weights(self):
        # init disease_base_proj: [LayerNorm, Linear]
        nn.init.normal_(self.disease_base_proj[1].weight, std=0.02)
        nn.init.zeros_(self.disease_base_proj[1].bias)

        # init context_mod_proj: [LayerNorm, Linear(bias=False)]
        nn.init.zeros_(self.context_mod_proj[1].weight)

        # init disease_mod_proj: [LayerNorm, Linear(bias=False)]
        nn.init.normal_(self.disease_mod_proj[1].weight, std=0.02)

    def forward(
        self,
        c: torch.Tensor,        # (M, n_embd)
        disease_embedding,      # (n_disease, n_embd)
    ) -> torch.Tensor:
        M = c.shape[0]
        K = disease_embedding.shape[0]
        assert K == self.n_disease
        base_logits = self.disease_base_proj(disease_embedding)   # (K, n_dim)
        base_logits = base_logits.unsqueeze(0).expand(M, -1, -1)  # (M, K, n_dim)
        u = self.context_mod_proj(c)
        v = self.disease_mod_proj(disease_embedding)
        v = v.view(K, self.rank, self.n_dim)
        delta_logits = torch.einsum('mr, krd -> mkd', u, v)

        return base_logits + self.delta_scale * delta_logits

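The new head factorizes the per-disease logits into a base term computed from the disease embeddings plus a low-rank, context-dependent correction; the context modulation weights are zero-initialized and scaled by the learnable delta_scale, so the correction starts at zero. A minimal standalone shape-check sketch of the class above, with hypothetical sizes chosen only for illustration:

import torch

# hypothetical sizes, for illustration only
head = FactorizedHead(n_embd=32, n_disease=8, n_dim=2, rank=4)
c = torch.randn(5, 32)            # (M, n_embd) per-event context vectors
disease_emb = torch.randn(8, 32)  # (n_disease, n_embd) disease embeddings
theta = head(c, disease_emb)
print(theta.shape)                # torch.Size([5, 8, 2]), i.e. (M, n_disease, n_dim)
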
def _build_time_padding_mask(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
) -> torch.Tensor:
    B, L = event_seq.shape
    device = event_seq.device
    key_is_valid = (event_seq != 0)

    cache = getattr(_build_time_padding_mask, "_cache", None)
    if cache is None:
        cache = {}
        setattr(_build_time_padding_mask, "_cache", cache)
    cache_key = (str(device), L)
    causal = cache.get(cache_key)
    if causal is None:
        causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
        cache[cache_key] = causal

    causal = causal.unsqueeze(0).unsqueeze(0)            # (1, 1, L, L)
    key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, L)
    attn_mask = causal | key_pad                         # (B, 1, L, L)
    return attn_mask
    t_i = time_seq.unsqueeze(-1)
    t_j = time_seq.unsqueeze(1)
    time_mask = (t_j <= t_i)  # allow attending only to past or current
    key_is_valid = (event_seq != 0)  # disallow padded positions
    allowed = time_mask & key_is_valid.unsqueeze(1)
    attn_mask = ~allowed  # True means mask for scaled_dot_product_attention
    return attn_mask.unsqueeze(1)  # (B, 1, L, L)

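The branch that returns first combines a cached causal (upper-triangular) constraint with key padding (event id 0 treated as padding) into a boolean (B, 1, L, L) mask where True marks pairs that should not attend; the time-based variant shown after it builds the same shape from explicit timestamps instead. A small sketch of the cached-causal behaviour on toy inputs (values are hypothetical):

import torch

event_seq = torch.tensor([[3, 5, 0]])        # one sequence of length 3, last position padded
time_seq = torch.tensor([[0.0, 1.0, 2.0]])   # not consulted by the cached-causal branch
mask = _build_time_padding_mask(event_seq, time_seq)
print(mask.shape)   # torch.Size([1, 1, 3, 3])
print(mask[0, 0])   # True where attention is blocked: future positions and the padded key
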
class DelphiFork(nn.Module):
@@ -322,6 +365,7 @@ class DelphiFork(nn.Module):
        pdrop: float = 0.0,
        token_pdrop: float = 0.0,
        n_dim: int = 1,
        rank: int = 16,
    ):
        super().__init__()
        self.vocab_size = n_disease + n_tech_tokens
@@ -356,7 +400,12 @@ class DelphiFork(nn.Module):
        self.token_dropout = nn.Dropout(token_pdrop)

        # Head layers
        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
        self.theta_proj = FactorizedHead(
            n_embd=n_embd,
            n_disease=n_disease,
            n_dim=n_dim,
            rank=rank,
        )

    def forward(
        self,
@@ -404,9 +453,11 @@ class DelphiFork(nn.Module):
        if b_prev is not None and t_prev is not None:
            M = b_prev.numel()
            c = x[b_prev, t_prev]  # (M, D)
            disease_embeddings = self.token_embedding.weight[
                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
            ]

            theta = self.theta_proj(c)  # (M, N_disease * n_dim)
            theta = theta.view(M, self.n_disease, self.n_dim)
            theta = self.theta_proj(c, disease_embeddings)
            return theta
        else:
            return x
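Besides returning (M, n_disease, n_dim) directly and removing the reshape, the swap changes how parameters scale: the dense nn.Linear head holds n_embd * n_disease * n_dim weights plus biases, while the factorized head needs roughly n_embd * (n_dim + rank + rank * n_dim) weights and no longer grows with n_disease. A quick count sketch under hypothetical sizes, reusing the FactorizedHead defined above:

from torch import nn

n_embd, n_disease, n_dim, rank = 256, 1000, 1, 16   # hypothetical sizes
dense = nn.Linear(n_embd, n_disease * n_dim)
factorized = FactorizedHead(n_embd=n_embd, n_disease=n_disease, n_dim=n_dim, rank=rank)
print(sum(p.numel() for p in dense.parameters()))        # 257000
print(sum(p.numel() for p in factorized.parameters()))   # 9986 (three LayerNorms, three Linears, delta_scale)
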
@@ -428,6 +479,7 @@ class SapDelphi(nn.Module):
        pdrop: float = 0.0,
        token_pdrop: float = 0.0,
        n_dim: int = 1,
        rank: int = 16,
        pretrained_weights_path: Optional[str] = None,  # new parameter
        freeze_embeddings: bool = False,  # new parameter; defaults to False, i.e. fine-tune the embeddings
    ):
@@ -438,6 +490,7 @@ class SapDelphi(nn.Module):
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_dim = n_dim
        self.rank = rank

        if pretrained_weights_path is not None:
            print(
@@ -459,7 +512,6 @@ class SapDelphi(nn.Module):
            self.emb_proj = nn.Sequential(
                nn.Linear(vocab_dim, n_embd, bias=False),
                nn.LayerNorm(n_embd),
                nn.Dropout(pdrop),
            )
        else:
            self.emb_proj = nn.Identity()
@@ -491,7 +543,12 @@ class SapDelphi(nn.Module):
        self.token_dropout = nn.Dropout(token_pdrop)

        # Head layers
        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
        self.theta_proj = FactorizedHead(
            n_embd=n_embd,
            n_disease=n_disease,
            n_dim=n_dim,
            rank=rank,
        )

    def forward(
        self,
@@ -540,9 +597,12 @@ class SapDelphi(nn.Module):
        if b_prev is not None and t_prev is not None:
            M = b_prev.numel()
            c = x[b_prev, t_prev]  # (M, D)
            disease_embeddings_raw = self.token_embedding.weight[
                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
            ]  # (K, vocab_dim)

            theta = self.theta_proj(c)  # (M, N_disease * n_dim)
            theta = theta.view(M, self.n_disease, self.n_dim)
            disease_embeddings = self.emb_proj(disease_embeddings_raw)
            theta = self.theta_proj(c, disease_embeddings)
            return theta
        else:
            return x

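In SapDelphi the disease rows of the embedding table are first passed through emb_proj (the diff shows either a Linear + LayerNorm + Dropout branch or nn.Identity) so the factorized head receives disease vectors of width n_embd. A schematic sketch of that extra step, with hypothetical sizes chosen only for illustration:

import torch
import torch.nn as nn

vocab_dim, n_embd, n_disease, n_tech_tokens = 48, 32, 8, 2   # hypothetical sizes
token_embedding = nn.Embedding(n_tech_tokens + n_disease, vocab_dim)
emb_proj = nn.Sequential(
    nn.Linear(vocab_dim, n_embd, bias=False),
    nn.LayerNorm(n_embd),
    nn.Dropout(0.0),
)
head = FactorizedHead(n_embd=n_embd, n_disease=n_disease, n_dim=1, rank=4)

c = torch.randn(5, n_embd)                                              # (M, n_embd)
raw = token_embedding.weight[n_tech_tokens:n_tech_tokens + n_disease]   # (K, vocab_dim)
theta = head(c, emb_proj(raw))                                          # (M, n_disease, n_dim)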