Refactor model and training scripts: remove unused imports and add FactorizedHead class for improved modularity

2026-01-09 12:01:52 +08:00
parent c70c3cd71e
commit aff0fe480b
2 changed files with 91 additions and 36 deletions

model.py (122 changed lines)

@@ -5,12 +5,6 @@ from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
 import torch.nn.functional as F
 import torch.nn as nn
 import torch
-import sys
-from pathlib import Path
-
-_PROJECT_ROOT = Path(__file__).resolve().parent
-if str(_PROJECT_ROOT) not in sys.path:
-    sys.path.insert(0, str(_PROJECT_ROOT))
 
 
 class TabularEncoder(nn.Module):
@@ -265,28 +259,77 @@ class AutoDiscretization(nn.Module):
         return emb
 
 
+class FactorizedHead(nn.Module):
+    def __init__(
+        self,
+        n_embd: int,
+        n_disease: int,
+        n_dim: int,
+        rank: int = 16,
+    ):
+        super().__init__()
+        self.n_embd = n_embd
+        self.n_disease = n_disease
+        self.n_dim = n_dim
+        self.rank = rank
+        self.disease_base_proj = nn.Sequential(
+            nn.LayerNorm(n_embd),
+            nn.Linear(n_embd, n_dim),
+        )
+        self.context_mod_proj = nn.Sequential(
+            nn.LayerNorm(n_embd),
+            nn.Linear(n_embd, rank, bias=False),
+        )
+        self.disease_mod_proj = nn.Sequential(
+            nn.LayerNorm(n_embd),
+            nn.Linear(n_embd, rank * n_dim, bias=False),
+        )
+        self.delta_scale = nn.Parameter(torch.tensor(1e-3))
+        self._init_weights()
+
+    def _init_weights(self):
+        # init disease_base_proj: [LayerNorm, Linear]
+        nn.init.normal_(self.disease_base_proj[1].weight, std=0.02)
+        nn.init.zeros_(self.disease_base_proj[1].bias)
+        # init context_mod_proj: [LayerNorm, Linear(bias=False)]
+        nn.init.zeros_(self.context_mod_proj[1].weight)
+        # init disease_mod_proj: [LayerNorm, Linear(bias=False)]
+        nn.init.normal_(self.disease_mod_proj[1].weight, std=0.02)
+
+    def forward(
+        self,
+        c: torch.Tensor,  # (M, n_embd)
+        disease_embedding,  # (n_disease, n_embd)
+    ) -> torch.Tensor:
+        M = c.shape[0]
+        K = disease_embedding.shape[0]
+        assert K == self.n_disease
+        base_logits = self.disease_base_proj(disease_embedding)  # (K, n_dim)
+        base_logits = base_logits.unsqueeze(
+            0).expand(M, -1, -1)  # (M, K, n_dim)
+        u = self.context_mod_proj(c)
+        v = self.disease_mod_proj(disease_embedding)
+        v = v.view(K, self.rank, self.n_dim)
+        delta_logits = torch.einsum('mr, krd -> mkd', u, v)
+        return base_logits + self.delta_scale * delta_logits
+
+
 def _build_time_padding_mask(
     event_seq: torch.Tensor,
     time_seq: torch.Tensor,
 ) -> torch.Tensor:
-    B, L = event_seq.shape
-    device = event_seq.device
-    key_is_valid = (event_seq != 0)
-    cache = getattr(_build_time_padding_mask, "_cache", None)
-    if cache is None:
-        cache = {}
-        setattr(_build_time_padding_mask, "_cache", cache)
-    cache_key = (str(device), L)
-    causal = cache.get(cache_key)
-    if causal is None:
-        causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
-        cache[cache_key] = causal
-    causal = causal.unsqueeze(0).unsqueeze(0)  # (1,1,L,L)
-    key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2)  # (B,1,1,L)
-    attn_mask = causal | key_pad  # (B,1,L,L)
-    return attn_mask
+    t_i = time_seq.unsqueeze(-1)
+    t_j = time_seq.unsqueeze(1)
+    time_mask = (t_j <= t_i)  # allow attending only to past or current
+    key_is_valid = (event_seq != 0)  # disallow padded positions
+    allowed = time_mask & key_is_valid.unsqueeze(1)
+    attn_mask = ~allowed  # True means mask for scaled_dot_product_attention
+    return attn_mask.unsqueeze(1)  # (B, 1, L, L)
 
 
 class DelphiFork(nn.Module):
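
A minimal shape-check sketch of the two pieces this hunk touches, the new FactorizedHead and the rewritten _build_time_padding_mask. It assumes model.py is importable (for example, when run from the project root); the batch, sequence, and head sizes below are made up purely for illustration.

import torch
from model import FactorizedHead, _build_time_padding_mask

# Toy mask: batch of 2, sequence length 4; event id 0 marks padding.
event_seq = torch.tensor([[3, 5, 7, 0],
                          [2, 4, 0, 0]])
time_seq = torch.tensor([[0.0, 1.0, 1.0, 9.0],
                         [0.0, 2.0, 9.0, 9.0]])
mask = _build_time_padding_mask(event_seq, time_seq)
print(mask.shape)  # torch.Size([2, 1, 4, 4]); True entries are masked out, so each
                   # query may attend only to non-padded keys with time <= its own time.

# Toy head: 3 context vectors, 10 diseases, 2 output dims per disease, rank-4 modulation.
head = FactorizedHead(n_embd=16, n_disease=10, n_dim=2, rank=4)
c = torch.randn(3, 16)             # (M, n_embd) per-event context vectors
disease_emb = torch.randn(10, 16)  # (n_disease, n_embd) disease token embeddings
theta = head(c, disease_emb)
print(theta.shape)  # torch.Size([3, 10, 2]) == (M, n_disease, n_dim)

Two properties follow directly from the code above: unlike the strictly positional triu mask it replaces, events that share a timestamp can now attend to each other; and because context_mod_proj's linear weight is zero-initialized and delta_scale starts at 1e-3, the head initially outputs only the context-independent base projection of each disease embedding, with the low-rank context-dependent correction learned during training.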
@@ -322,6 +365,7 @@ class DelphiFork(nn.Module):
         pdrop: float = 0.0,
         token_pdrop: float = 0.0,
         n_dim: int = 1,
+        rank: int = 16,
     ):
         super().__init__()
         self.vocab_size = n_disease + n_tech_tokens
@@ -356,7 +400,12 @@ class DelphiFork(nn.Module):
         self.token_dropout = nn.Dropout(token_pdrop)
 
         # Head layers
-        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
+        self.theta_proj = FactorizedHead(
+            n_embd=n_embd,
+            n_disease=n_disease,
+            n_dim=n_dim,
+            rank=rank,
+        )
 
     def forward(
         self,
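
The swap from nn.Linear(n_embd, n_disease * n_dim) to FactorizedHead is mainly a parameter-count change: the dense head grows with n_embd * n_disease * n_dim, while the factorized head grows with roughly n_embd * (n_dim + rank + rank * n_dim) plus its LayerNorms, independent of n_disease. A rough comparison with hypothetical sizes (not taken from this repository's config):

import torch.nn as nn
from model import FactorizedHead  # assumes model.py is importable

n_embd, n_disease, n_dim, rank = 256, 1000, 1, 16  # hypothetical sizes

def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

dense_head = nn.Linear(n_embd, n_disease * n_dim)
factorized_head = FactorizedHead(n_embd=n_embd, n_disease=n_disease, n_dim=n_dim, rank=rank)
print(n_params(dense_head))       # 257000 (weight matrix + bias)
print(n_params(factorized_head))  # 9986 (three small projections, their LayerNorms, delta_scale)

The factorized head also derives every disease's output from that disease's token embedding, so diseases with nearby embeddings get similar predictions by construction.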
@@ -404,9 +453,11 @@ class DelphiFork(nn.Module):
         if b_prev is not None and t_prev is not None:
             M = b_prev.numel()
             c = x[b_prev, t_prev]  # (M, D)
-            theta = self.theta_proj(c)  # (M, N_disease * n_dim)
-            theta = theta.view(M, self.n_disease, self.n_dim)
+            disease_embeddings = self.token_embedding.weight[
+                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
+            ]
+            theta = self.theta_proj(c, disease_embeddings)
             return theta
         else:
             return x
@@ -428,6 +479,7 @@ class SapDelphi(nn.Module):
         pdrop: float = 0.0,
         token_pdrop: float = 0.0,
         n_dim: int = 1,
+        rank: int = 16,
         pretrained_weights_path: Optional[str] = None,  # newly added parameter
         freeze_embeddings: bool = False,  # newly added parameter; default False means fine-tuning
     ):
@@ -438,6 +490,7 @@ class SapDelphi(nn.Module):
         self.n_embd = n_embd
         self.n_head = n_head
         self.n_dim = n_dim
+        self.rank = rank
 
         if pretrained_weights_path is not None:
             print(
@@ -459,7 +512,6 @@ class SapDelphi(nn.Module):
             self.emb_proj = nn.Sequential(
                 nn.Linear(vocab_dim, n_embd, bias=False),
                 nn.LayerNorm(n_embd),
-                nn.Dropout(pdrop),
             )
         else:
             self.emb_proj = nn.Identity()
@@ -491,7 +543,12 @@ class SapDelphi(nn.Module):
         self.token_dropout = nn.Dropout(token_pdrop)
 
         # Head layers
-        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
+        self.theta_proj = FactorizedHead(
+            n_embd=n_embd,
+            n_disease=n_disease,
+            n_dim=n_dim,
+            rank=rank,
+        )
 
     def forward(
         self,
@@ -540,9 +597,12 @@ class SapDelphi(nn.Module):
         if b_prev is not None and t_prev is not None:
             M = b_prev.numel()
             c = x[b_prev, t_prev]  # (M, D)
-            theta = self.theta_proj(c)  # (M, N_disease * n_dim)
-            theta = theta.view(M, self.n_disease, self.n_dim)
+            disease_embeddings_raw = self.token_embedding.weight[
+                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
+            ]  # (K, vocab_dim)
+            disease_embeddings = self.emb_proj(disease_embeddings_raw)
+            theta = self.theta_proj(c, disease_embeddings)
             return theta
         else:
             return x
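
Unlike DelphiFork, SapDelphi slices disease rows from an embedding table of width vocab_dim, so the rows are passed through self.emb_proj to reach n_embd before the factorized head. A minimal shape sketch, assuming vocab_dim differs from n_embd (so emb_proj is the Linear + LayerNorm branch) and using hypothetical sizes:

import torch
import torch.nn as nn
from model import FactorizedHead  # assumes model.py is importable

vocab_dim, n_embd, n_disease, n_dim, rank = 768, 256, 1000, 1, 16  # hypothetical sizes
emb_proj = nn.Sequential(nn.Linear(vocab_dim, n_embd, bias=False), nn.LayerNorm(n_embd))
head = FactorizedHead(n_embd=n_embd, n_disease=n_disease, n_dim=n_dim, rank=rank)

disease_raw = torch.randn(n_disease, vocab_dim)  # stands in for the sliced token_embedding rows
c = torch.randn(32, n_embd)                      # (M, n_embd) context vectors
theta = head(c, emb_proj(disease_raw))           # (M, n_disease, n_dim)
print(theta.shape)                               # torch.Size([32, 1000, 1])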

(second changed file)

@@ -16,11 +16,6 @@ import math
 import sys
 from dataclasses import asdict, dataclass, field
 from typing import Literal, Sequence
-from pathlib import Path
-
-_PROJECT_ROOT = Path(__file__).resolve().parent
-if str(_PROJECT_ROOT) not in sys.path:
-    sys.path.insert(0, str(_PROJECT_ROOT))
 
 
 @dataclass