From aff0fe480b7fb2f90c1a82d2d6413709224fc3bd Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Fri, 9 Jan 2026 12:01:52 +0800
Subject: [PATCH] Refactor model and training scripts: remove unused imports
 and add FactorizedHead class for improved modularity

---
 model.py | 122 +++++++++++++++++++++++++++++++++++++++++--------------
 train.py |   5 ---
 2 files changed, 91 insertions(+), 36 deletions(-)

diff --git a/model.py b/model.py
index df395b2..09f2c86 100644
--- a/model.py
+++ b/model.py
@@ -5,12 +5,6 @@ from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
 import torch.nn.functional as F
 import torch.nn as nn
 import torch
-import sys
-from pathlib import Path
-
-_PROJECT_ROOT = Path(__file__).resolve().parent
-if str(_PROJECT_ROOT) not in sys.path:
-    sys.path.insert(0, str(_PROJECT_ROOT))
 
 
 class TabularEncoder(nn.Module):
@@ -265,28 +259,77 @@ class AutoDiscretization(nn.Module):
         return emb
 
 
+class FactorizedHead(nn.Module):
+    def __init__(
+        self,
+        n_embd: int,
+        n_disease: int,
+        n_dim: int,
+        rank: int = 16,
+    ):
+        super().__init__()
+        self.n_embd = n_embd
+        self.n_disease = n_disease
+        self.n_dim = n_dim
+        self.rank = rank
+
+        self.disease_base_proj = nn.Sequential(
+            nn.LayerNorm(n_embd),
+            nn.Linear(n_embd, n_dim),
+        )
+        self.context_mod_proj = nn.Sequential(
+            nn.LayerNorm(n_embd),
+            nn.Linear(n_embd, rank, bias=False),
+        )
+        self.disease_mod_proj = nn.Sequential(
+            nn.LayerNorm(n_embd),
+            nn.Linear(n_embd, rank * n_dim, bias=False),
+        )
+        self.delta_scale = nn.Parameter(torch.tensor(1e-3))
+
+        self._init_weights()
+
+    def _init_weights(self):
+        # init disease_base_proj: [LayerNorm, Linear]
+        nn.init.normal_(self.disease_base_proj[1].weight, std=0.02)
+        nn.init.zeros_(self.disease_base_proj[1].bias)
+
+        # init context_mod_proj: [LayerNorm, Linear(bias=False)]
+        nn.init.zeros_(self.context_mod_proj[1].weight)
+
+        # init disease_mod_proj: [LayerNorm, Linear(bias=False)]
+        nn.init.normal_(self.disease_mod_proj[1].weight, std=0.02)
+
+    def forward(
+        self,
+        c: torch.Tensor,  # (M, n_embd)
+        disease_embedding,  # (n_disease, n_embd)
+    ) -> torch.Tensor:
+        M = c.shape[0]
+        K = disease_embedding.shape[0]
+        assert K == self.n_disease
+        base_logits = self.disease_base_proj(disease_embedding)  # (K, n_dim)
+        base_logits = base_logits.unsqueeze(
+            0).expand(M, -1, -1)  # (M, K, n_dim)
+        u = self.context_mod_proj(c)
+        v = self.disease_mod_proj(disease_embedding)
+        v = v.view(K, self.rank, self.n_dim)
+        delta_logits = torch.einsum('mr, krd -> mkd', u, v)
+
+        return base_logits + self.delta_scale * delta_logits
+
+
 def _build_time_padding_mask(
     event_seq: torch.Tensor,
     time_seq: torch.Tensor,
 ) -> torch.Tensor:
-    B, L = event_seq.shape
-    device = event_seq.device
-    key_is_valid = (event_seq != 0)
-
-    cache = getattr(_build_time_padding_mask, "_cache", None)
-    if cache is None:
-        cache = {}
-        setattr(_build_time_padding_mask, "_cache", cache)
-    cache_key = (str(device), L)
-    causal = cache.get(cache_key)
-    if causal is None:
-        causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
-        cache[cache_key] = causal
-
-    causal = causal.unsqueeze(0).unsqueeze(0)  # (1,1,L,L)
-    key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2)  # (B,1,1,L)
-    attn_mask = causal | key_pad  # (B,1,L,L)
-    return attn_mask
+    t_i = time_seq.unsqueeze(-1)
+    t_j = time_seq.unsqueeze(1)
+    time_mask = (t_j <= t_i)  # allow attending only to past or current events
+    key_is_valid = (event_seq != 0)  # disallow padded positions
+    allowed = time_mask & key_is_valid.unsqueeze(1)
+    attn_mask = ~allowed  # True marks positions that must not be attended
+    return attn_mask.unsqueeze(1)  # (B, 1, L, L)
 
 
 class DelphiFork(nn.Module):
@@ -322,6 +365,7 @@ class DelphiFork(nn.Module):
         pdrop: float = 0.0,
         token_pdrop: float = 0.0,
         n_dim: int = 1,
+        rank: int = 16,
     ):
         super().__init__()
         self.vocab_size = n_disease + n_tech_tokens
@@ -356,7 +400,12 @@ class DelphiFork(nn.Module):
         self.token_dropout = nn.Dropout(token_pdrop)
 
         # Head layers
-        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
+        self.theta_proj = FactorizedHead(
+            n_embd=n_embd,
+            n_disease=n_disease,
+            n_dim=n_dim,
+            rank=rank,
+        )
 
     def forward(
         self,
@@ -404,9 +453,11 @@ class DelphiFork(nn.Module):
         if b_prev is not None and t_prev is not None:
             M = b_prev.numel()
             c = x[b_prev, t_prev]  # (M, D)
+            disease_embeddings = self.token_embedding.weight[
+                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
+            ]
 
-            theta = self.theta_proj(c)  # (M, N_disease * n_dim)
-            theta = theta.view(M, self.n_disease, self.n_dim)
+            theta = self.theta_proj(c, disease_embeddings)
             return theta
         else:
             return x
@@ -428,6 +479,7 @@ class SapDelphi(nn.Module):
         pdrop: float = 0.0,
         token_pdrop: float = 0.0,
         n_dim: int = 1,
+        rank: int = 16,
         pretrained_weights_path: Optional[str] = None,  # newly added parameter
         freeze_embeddings: bool = False,  # newly added parameter; False (default) means fine-tune the embeddings
     ):
@@ -438,6 +490,7 @@ class SapDelphi(nn.Module):
         self.n_embd = n_embd
         self.n_head = n_head
         self.n_dim = n_dim
+        self.rank = rank
 
         if pretrained_weights_path is not None:
             print(
@@ -459,7 +512,6 @@ class SapDelphi(nn.Module):
             self.emb_proj = nn.Sequential(
                 nn.Linear(vocab_dim, n_embd, bias=False),
                 nn.LayerNorm(n_embd),
-                nn.Dropout(pdrop),
             )
         else:
             self.emb_proj = nn.Identity()
@@ -491,7 +543,12 @@ class SapDelphi(nn.Module):
         self.token_dropout = nn.Dropout(token_pdrop)
 
         # Head layers
-        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
+        self.theta_proj = FactorizedHead(
+            n_embd=n_embd,
+            n_disease=n_disease,
+            n_dim=n_dim,
+            rank=rank,
+        )
 
     def forward(
         self,
@@ -540,9 +597,12 @@ class SapDelphi(nn.Module):
         if b_prev is not None and t_prev is not None:
             M = b_prev.numel()
             c = x[b_prev, t_prev]  # (M, D)
+            disease_embeddings_raw = self.token_embedding.weight[
+                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
+            ]  # (K, vocab_dim)
 
-            theta = self.theta_proj(c)  # (M, N_disease * n_dim)
-            theta = theta.view(M, self.n_disease, self.n_dim)
+            disease_embeddings = self.emb_proj(disease_embeddings_raw)
+            theta = self.theta_proj(c, disease_embeddings)
             return theta
         else:
             return x
diff --git a/train.py b/train.py
index 702c0e4..30a4452 100644
--- a/train.py
+++ b/train.py
@@ -16,11 +16,6 @@ import math
 import sys
 from dataclasses import asdict, dataclass, field
 from typing import Literal, Sequence
-from pathlib import Path
-
-_PROJECT_ROOT = Path(__file__).resolve().parent
-if str(_PROJECT_ROOT) not in sys.path:
-    sys.path.insert(0, str(_PROJECT_ROOT))
 
 
 @dataclass
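
Note for reviewers: below is a minimal, standalone sketch of the low-rank factorization that the new FactorizedHead implements, included only for illustration. The sizes and the plain-matrix stand-ins for the three projections (LayerNorms and the nn.Module wrapping omitted) are hypothetical and not taken from the patch or any config.

import torch

M, K, n_embd, n_dim, rank = 4, 10, 32, 1, 16     # illustrative sizes only

c = torch.randn(M, n_embd)                   # per-event context vectors
disease_embedding = torch.randn(K, n_embd)   # shared disease token embeddings

W_base = torch.randn(n_embd, n_dim)          # stand-in for disease_base_proj
W_ctx = torch.randn(n_embd, rank)            # stand-in for context_mod_proj
W_dis = torch.randn(n_embd, rank * n_dim)    # stand-in for disease_mod_proj
delta_scale = 1e-3

base_logits = (disease_embedding @ W_base).unsqueeze(0).expand(M, -1, -1)  # (M, K, n_dim)
u = c @ W_ctx                                          # (M, rank)
v = (disease_embedding @ W_dis).view(K, rank, n_dim)   # (K, rank, n_dim)
delta_logits = torch.einsum('mr,krd->mkd', u, v)       # (M, K, n_dim)

theta = base_logits + delta_scale * delta_logits
print(theta.shape)  # torch.Size([4, 10, 1])

Because context_mod_proj is zero-initialised in the patch, delta_logits starts at zero and the head initially behaves like a context-independent per-disease projection; the low-rank context modulation is learned on top of that baseline.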