Refactor model and training scripts: remove unused imports and add FactorizedHead class for improved modularity

2026-01-09 12:01:52 +08:00
parent c70c3cd71e
commit aff0fe480b
2 changed files with 91 additions and 36 deletions

model.py

@@ -5,12 +5,6 @@ from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
import torch.nn.functional as F
import torch.nn as nn
import torch
import sys
from pathlib import Path
_PROJECT_ROOT = Path(__file__).resolve().parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
class TabularEncoder(nn.Module):
@@ -265,28 +259,77 @@ class AutoDiscretization(nn.Module):
        return emb
class FactorizedHead(nn.Module):
    def __init__(
        self,
        n_embd: int,
        n_disease: int,
        n_dim: int,
        rank: int = 16,
    ):
        super().__init__()
        self.n_embd = n_embd
        self.n_disease = n_disease
        self.n_dim = n_dim
        self.rank = rank
        self.disease_base_proj = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd, n_dim),
        )
        self.context_mod_proj = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd, rank, bias=False),
        )
        self.disease_mod_proj = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd, rank * n_dim, bias=False),
        )
        self.delta_scale = nn.Parameter(torch.tensor(1e-3))
        self._init_weights()

    def _init_weights(self):
        # init disease_base_proj: [LayerNorm, Linear]
        nn.init.normal_(self.disease_base_proj[1].weight, std=0.02)
        nn.init.zeros_(self.disease_base_proj[1].bias)
        # init context_mod_proj: [LayerNorm, Linear(bias=False)]
        nn.init.zeros_(self.context_mod_proj[1].weight)
        # init disease_mod_proj: [LayerNorm, Linear(bias=False)]
        nn.init.normal_(self.disease_mod_proj[1].weight, std=0.02)

    def forward(
        self,
        c: torch.Tensor,  # (M, n_embd)
        disease_embedding,  # (n_disease, n_embd)
    ) -> torch.Tensor:
        M = c.shape[0]
        K = disease_embedding.shape[0]
        assert K == self.n_disease
        base_logits = self.disease_base_proj(disease_embedding)  # (K, n_dim)
        base_logits = base_logits.unsqueeze(0).expand(M, -1, -1)  # (M, K, n_dim)
        u = self.context_mod_proj(c)
        v = self.disease_mod_proj(disease_embedding)
        v = v.view(K, self.rank, self.n_dim)
        delta_logits = torch.einsum('mr, krd -> mkd', u, v)
        return base_logits + self.delta_scale * delta_logits
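
Note: the following is a minimal usage sketch, not part of the diff. It assumes model.py is importable and uses made-up sizes; it only checks that the head maps a batch of context vectors plus the shared disease embeddings to per-disease, per-dimension logits.

# Usage sketch for FactorizedHead (toy sizes; n_embd/n_disease/n_dim/rank are assumptions)
import torch
from model import FactorizedHead  # assumes model.py is on the import path

head = FactorizedHead(n_embd=128, n_disease=1000, n_dim=3, rank=16)
c = torch.randn(32, 128)              # (M, n_embd) context vectors from the transformer
disease_emb = torch.randn(1000, 128)  # (n_disease, n_embd) disease token embeddings
theta = head(c, disease_emb)
print(theta.shape)                    # torch.Size([32, 1000, 3]) == (M, n_disease, n_dim)

Because context_mod_proj is zero-initialized and delta_scale starts at 1e-3, the low-rank delta term is zero at initialization, so the head initially produces a context-independent baseline per disease.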
def _build_time_padding_mask(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
) -> torch.Tensor:
    B, L = event_seq.shape
    device = event_seq.device
    key_is_valid = (event_seq != 0)
    cache = getattr(_build_time_padding_mask, "_cache", None)
    if cache is None:
        cache = {}
        setattr(_build_time_padding_mask, "_cache", cache)
    cache_key = (str(device), L)
    causal = cache.get(cache_key)
    if causal is None:
        causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
        cache[cache_key] = causal
    causal = causal.unsqueeze(0).unsqueeze(0)  # (1,1,L,L)
    key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2)  # (B,1,1,L)
    attn_mask = causal | key_pad  # (B,1,L,L)
    return attn_mask
    t_i = time_seq.unsqueeze(-1)
    t_j = time_seq.unsqueeze(1)
    time_mask = (t_j <= t_i)  # allow attending only to past or current
    key_is_valid = (event_seq != 0)  # disallow padded positions
    allowed = time_mask & key_is_valid.unsqueeze(1)
    attn_mask = ~allowed  # True means mask for scaled_dot_product_attention
    return attn_mask.unsqueeze(1)  # (B, 1, L, L)
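
Note: a sketch (not part of the diff) of how a mask with this convention, where True means blocked, can be applied to raw attention scores. PyTorch's F.scaled_dot_product_attention uses the opposite boolean convention (True means the position may attend), so the mask would need inverting before being passed there as a boolean attn_mask.

# Sketch: applying the (B, 1, L, L) boolean mask (True = blocked) to attention scores.
# Sizes and the "0 is the padding token" convention are assumptions for illustration.
import torch
from model import _build_time_padding_mask  # assumes model.py is importable

B, H, L, D = 2, 4, 10, 16
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
event_seq = torch.randint(1, 5, (B, L))
event_seq[:, -2:] = 0                                   # pretend the last two positions are padding
time_seq = torch.arange(L, dtype=torch.float32).expand(B, L)

mask = _build_time_padding_mask(event_seq, time_seq)    # (B, 1, L, L), True = blocked
scores = (q @ k.transpose(-2, -1)) / D ** 0.5           # (B, H, L, L)
scores = scores.masked_fill(mask, float("-inf"))
out = scores.softmax(dim=-1) @ v                        # (B, H, L, D)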
class DelphiFork(nn.Module):
@@ -322,6 +365,7 @@ class DelphiFork(nn.Module):
        pdrop: float = 0.0,
        token_pdrop: float = 0.0,
        n_dim: int = 1,
        rank: int = 16,
    ):
        super().__init__()
        self.vocab_size = n_disease + n_tech_tokens
@@ -356,7 +400,12 @@ class DelphiFork(nn.Module):
        self.token_dropout = nn.Dropout(token_pdrop)
        # Head layers
        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
        self.theta_proj = FactorizedHead(
            n_embd=n_embd,
            n_disease=n_disease,
            n_dim=n_dim,
            rank=rank,
        )
    def forward(
        self,
@@ -404,9 +453,11 @@ class DelphiFork(nn.Module):
        if b_prev is not None and t_prev is not None:
            M = b_prev.numel()
            c = x[b_prev, t_prev]  # (M, D)
            disease_embeddings = self.token_embedding.weight[
                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
            ]
            theta = self.theta_proj(c)  # (M, N_disease * n_dim)
            theta = theta.view(M, self.n_disease, self.n_dim)
            theta = self.theta_proj(c, disease_embeddings)
            return theta
        else:
            return x
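
Note: a rough parameter-count comparison (not part of the diff) between the dense head that was removed and the FactorizedHead that replaces it, using illustrative sizes; the real n_embd, n_disease, n_dim and rank may differ.

# Sketch: head parameter counts for assumed sizes n_embd=120, n_disease=1270, n_dim=3, rank=16.
import torch.nn as nn
from model import FactorizedHead  # assumes model.py is importable

n_embd, n_disease, n_dim, rank = 120, 1270, 3, 16

def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

dense = nn.Linear(n_embd, n_disease * n_dim)  # the old head
fact = FactorizedHead(n_embd=n_embd, n_disease=n_disease, n_dim=n_dim, rank=rank)
print(n_params(dense))  # 461010: weights scale with n_embd * n_disease * n_dim
print(n_params(fact))   # 8764: three LayerNorms, three small projections, delta_scale

Unlike the dense head, whose cost grows with n_embd * n_disease * n_dim, the factorized head ties per-disease outputs to the shared disease token embeddings.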
@@ -428,6 +479,7 @@ class SapDelphi(nn.Module):
        pdrop: float = 0.0,
        token_pdrop: float = 0.0,
        n_dim: int = 1,
        rank: int = 16,
        pretrained_weights_path: Optional[str] = None,  # new parameter
        freeze_embeddings: bool = False,  # new parameter; defaults to False, i.e. fine-tune the embeddings
    ):
@@ -438,6 +490,7 @@ class SapDelphi(nn.Module):
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_dim = n_dim
        self.rank = rank
        if pretrained_weights_path is not None:
            print(
@@ -459,7 +512,6 @@ class SapDelphi(nn.Module):
            self.emb_proj = nn.Sequential(
                nn.Linear(vocab_dim, n_embd, bias=False),
                nn.LayerNorm(n_embd),
                nn.Dropout(pdrop),
            )
        else:
            self.emb_proj = nn.Identity()
@@ -491,7 +543,12 @@ class SapDelphi(nn.Module):
        self.token_dropout = nn.Dropout(token_pdrop)
        # Head layers
        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
        self.theta_proj = FactorizedHead(
            n_embd=n_embd,
            n_disease=n_disease,
            n_dim=n_dim,
            rank=rank,
        )
    def forward(
        self,
@@ -540,9 +597,12 @@ class SapDelphi(nn.Module):
        if b_prev is not None and t_prev is not None:
            M = b_prev.numel()
            c = x[b_prev, t_prev]  # (M, D)
            disease_embeddings_raw = self.token_embedding.weight[
                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
            ]  # (K, vocab_dim)
            theta = self.theta_proj(c)  # (M, N_disease * n_dim)
            theta = theta.view(M, self.n_disease, self.n_dim)
            disease_embeddings = self.emb_proj(disease_embeddings_raw)
            theta = self.theta_proj(c, disease_embeddings)
            return theta
        else:
            return x
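
Note: in SapDelphi the disease rows of token_embedding live in vocab_dim space (the pretrained width), so they pass through emb_proj before reaching the head. A small sketch of that projection step, with assumed sizes, not part of the diff:

# Sketch: projecting raw disease embeddings from vocab_dim to n_embd before the head.
# vocab_dim / n_embd / n_disease values are assumptions for illustration.
import torch
import torch.nn as nn

vocab_dim, n_embd, n_disease = 768, 120, 1270
emb_proj = nn.Sequential(nn.Linear(vocab_dim, n_embd, bias=False), nn.LayerNorm(n_embd))
raw = torch.randn(n_disease, vocab_dim)     # token_embedding rows at the pretrained width
disease_embeddings = emb_proj(raw)          # (n_disease, n_embd), what FactorizedHead consumes
print(disease_embeddings.shape)             # torch.Size([1270, 120])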


@@ -16,11 +16,6 @@ import math
import sys
from dataclasses import asdict, dataclass, field
from typing import Literal, Sequence
from pathlib import Path
_PROJECT_ROOT = Path(__file__).resolve().parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
@dataclass