Refactor loss functions and model architecture: replace PiecewiseExponentialLoss with DiscreteTimeCIFNLLLoss, update Trainer to use SimpleHead, and modify argument parsing for the new loss type.
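The replacement loss is defined outside model.py, so its implementation does not appear in this commit's diff below. Purely for orientation, here is a minimal sketch of what a discrete-time competing-risks NLL over per-disease, per-time-bin logits can look like; every name, shape, and the censoring treatment is an assumption for illustration, not the repository's actual DiscreteTimeCIFNLLLoss.

import torch
import torch.nn.functional as F

def discrete_time_cif_nll(logits, event, bin_idx):
    """Illustrative discrete-time competing-risks NLL (assumed semantics).

    logits:  (M, K, n_bins) scores per disease and time bin (e.g. a SimpleHead output)
    event:   (M,) disease index in [0, K), or -1 for censored observations
    bin_idx: (M,) time bin of the event, or of censoring
    """
    M, K, n_bins = logits.shape
    # Joint categorical over (disease, bin) cells plus one "no event in horizon" class.
    flat = torch.cat([logits.reshape(M, K * n_bins), logits.new_zeros(M, 1)], dim=1)
    log_p = F.log_softmax(flat, dim=1)
    p = log_p.exp()

    observed = event >= 0
    loss = logits.new_zeros(M)

    # Uncensored: negative log-probability of the observed (disease, bin) cell.
    cell = event.clamp(min=0) * n_bins + bin_idx
    loss[observed] = -log_p[observed, cell[observed]]

    # Censored: negative log of the probability mass strictly after the censoring bin,
    # including the "no event" class.
    p_cells = p[:, :-1].reshape(M, K, n_bins)
    after = torch.arange(n_bins, device=logits.device).view(1, 1, -1) > bin_idx.view(M, 1, 1)
    surv = p[:, -1] + (p_cells * after).sum(dim=(1, 2))
    loss[~observed] = -surv[~observed].clamp_min(1e-12).log()

    return loss.mean()

The only thing this sketch pins down is one plausible input contract: per-position logits shaped (n_disease, n_bins), which is the kind of output a SimpleHead with out_dims=[n_disease, n_bins] would produce.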

2026-01-09 18:31:38 +08:00
parent 880fd53a4b
commit 209dde2299
3 changed files with 172 additions and 349 deletions

model.py (145 changed lines)

@@ -259,64 +259,26 @@ class AutoDiscretization(nn.Module):
         return emb
 
-class FactorizedHead(nn.Module):
+class SimpleHead(nn.Module):
     def __init__(
         self,
         n_embd: int,
-        n_disease: int,
-        n_dim: int,
-        rank: int = 16,
+        out_dims: List[int],
     ):
         super().__init__()
         self.n_embd = n_embd
-        self.n_disease = n_disease
-        self.n_dim = n_dim
-        self.rank = rank
-        self.disease_base_proj = nn.Sequential(
-            nn.LayerNorm(n_embd),
-            nn.Linear(n_embd, n_dim),
+        self.out_dims = out_dims
+        total_out_dims = np.prod(out_dims)
+        self.net = nn.Sequential(
+            nn.Linear(n_embd, n_embd),
+            nn.GELU(),
+            nn.Linear(n_embd, total_out_dims),
+            nn.LayerNorm(total_out_dims),
         )
-        self.context_mod_proj = nn.Sequential(
-            nn.LayerNorm(n_embd),
-            nn.Linear(n_embd, rank, bias=False),
-        )
-        self.disease_mod_proj = nn.Sequential(
-            nn.LayerNorm(n_embd),
-            nn.Linear(n_embd, rank * n_dim, bias=False),
-        )
-        self.delta_scale = nn.Parameter(torch.tensor(1e-3))
-        self._init_weights()
-
-    def _init_weights(self):
-        # init disease_base_proj: [LayerNorm, Linear]
-        nn.init.normal_(self.disease_base_proj[1].weight, std=0.02)
-        nn.init.zeros_(self.disease_base_proj[1].bias)
-        # init context_mod_proj: [LayerNorm, Linear(bias=False)]
-        nn.init.zeros_(self.context_mod_proj[1].weight)
-        # init disease_mod_proj: [LayerNorm, Linear(bias=False)]
-        nn.init.normal_(self.disease_mod_proj[1].weight, std=0.02)
-
-    def forward(
-        self,
-        c: torch.Tensor,  # (M, n_embd)
-        disease_embedding,  # (n_disease, n_embd)
-    ) -> torch.Tensor:
-        M = c.shape[0]
-        K = disease_embedding.shape[0]
-        assert K == self.n_disease
-        base_logits = self.disease_base_proj(disease_embedding)  # (K, n_dim)
-        base_logits = base_logits.unsqueeze(
-            0).expand(M, -1, -1)  # (M, K, n_dim)
-        u = self.context_mod_proj(c)
-        v = self.disease_mod_proj(disease_embedding)
-        v = v.view(K, self.rank, self.n_dim)
-        delta_logits = torch.einsum('mr, krd -> mkd', u, v)
-        return base_logits + self.delta_scale * delta_logits
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.net(x)
+        x = x.view(x.size(0), -1)
+        return x.view(-1, *self.out_dims)
 
 def _build_time_padding_mask(
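Unlike the removed FactorizedHead, SimpleHead never sees disease embeddings: it is a plain MLP from n_embd to prod(out_dims) values, reshaped on the way out. A minimal usage sketch with illustrative sizes (the out_dims actually passed by the Trainer are not visible in this file's diff; model.py is assumed importable):

import torch
from model import SimpleHead  # as introduced in this diff

n_embd, n_disease, n_bins = 128, 50, 20          # illustrative sizes only
head = SimpleHead(n_embd=n_embd, out_dims=[n_disease, n_bins])

c = torch.randn(32, n_embd)                      # e.g. hidden states gathered at selected positions
theta = head(c)
print(theta.shape)                               # torch.Size([32, 50, 20])

Because forward ends with view(-1, *self.out_dims), a (B, L, n_embd) input comes back flattened over batch and sequence as (B*L, *out_dims), so callers that need per-position outputs have to keep track of that flattening themselves.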
@@ -363,9 +325,6 @@ class DelphiFork(nn.Module):
         cate_dims: List[int],
         age_encoder_type: str = "sinusoidal",
         pdrop: float = 0.0,
         token_pdrop: float = 0.0,
-        n_dim: int = 1,
-        rank: int = 16,
     ):
         super().__init__()
         self.vocab_size = n_disease + n_tech_tokens
@@ -373,7 +332,6 @@ class DelphiFork(nn.Module):
         self.n_disease = n_disease
         self.n_embd = n_embd
         self.n_head = n_head
-        self.n_dim = n_dim
         self.token_embedding = nn.Embedding(
             self.vocab_size, n_embd, padding_idx=0)
@@ -397,15 +355,21 @@ class DelphiFork(nn.Module):
         ])
         self.ln_f = nn.LayerNorm(n_embd)
         self.token_dropout = nn.Dropout(token_pdrop)
-        # Head layers
-        self.theta_proj = FactorizedHead(
-            n_embd=n_embd,
-            n_disease=n_disease,
-            n_dim=n_dim,
-            rank=rank,
+    def get_disease_embedding(self) -> torch.Tensor:
+        """Get disease token embeddings for head computation.
+
+        Returns:
+            (n_disease, n_embd) tensor of disease token embeddings.
+        """
+        device = self.token_embedding.weight.device
+        disease_ids = torch.arange(
+            self.n_tech_tokens,
+            self.n_tech_tokens + self.n_disease,
+            device=device,
         )
+        disease_embs = self.token_embedding(disease_ids)
+        return disease_embs
 
     def forward(
         self,
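With the head gone, DelphiFork only exposes its disease token embeddings; how they are combined with the transformer output is now the caller's job (per the commit message, the Trainer's). One plausible consumer, written as a purely hypothetical helper:

import torch

def disease_similarity_logits(model, event_seq, time_seq, sex, cont_seq, cate_seq):
    """Hypothetical glue code, not part of this commit.

    Scores every sequence position against every disease token embedding,
    assuming `model` is the DelphiFork (or SapDelphi) defined in this file.
    """
    x = model(event_seq, time_seq, sex, cont_seq, cate_seq)   # (B, L, n_embd)
    disease_embs = model.get_disease_embedding()              # (n_disease, n_embd)
    return torch.einsum('bld,kd->blk', x, disease_embs)       # (B, L, n_disease)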
@@ -414,8 +378,6 @@ class DelphiFork(nn.Module):
         sex: torch.Tensor,  # (B,)
         cont_seq: torch.Tensor,  # (B, Lc, n_cont)
         cate_seq: torch.Tensor,  # (B, Lc, n_cate)
-        b_prev: Optional[torch.Tensor] = None,  # (M,)
-        t_prev: Optional[torch.Tensor] = None,  # (M,)
     ) -> torch.Tensor:
         token_embds = self.token_embedding(event_seq)  # (B, L, D)
         age_embds = self.age_encoder(time_seq)  # (B, L, D)
@@ -443,24 +405,13 @@ class DelphiFork(nn.Module):
         final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
         x = final_embds + age_embds + sex_embds  # (B, L, D)
         x = self.token_dropout(x)
         attn_mask = _build_time_padding_mask(
             event_seq, time_seq)
         for block in self.blocks:
             x = block(x, attn_mask=attn_mask)
         x = self.ln_f(x)
-        if b_prev is not None and t_prev is not None:
-            M = b_prev.numel()
-            c = x[b_prev, t_prev]  # (M, D)
-            disease_embeddings = self.token_embedding.weight[
-                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
-            ]
-            theta = self.theta_proj(c, disease_embeddings)
-            return theta
-        else:
-            return x
+        return x
 
 class SapDelphi(nn.Module):
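forward() now always returns the full (B, L, n_embd) hidden-state tensor, so the b_prev/t_prev gather and the head call that used to live inside the model have to happen in the caller. The gather itself is unchanged, only relocated; a sketch under the assumption that the Trainer still tracks the same index tensors:

import torch
from torch import nn, Tensor

def gather_and_project(x: Tensor, b_prev: Tensor, t_prev: Tensor, head: nn.Module) -> Tensor:
    """Hypothetical caller-side replacement for the branch removed from forward().

    x:      (B, L, n_embd) model output
    b_prev: (M,) batch indices; t_prev: (M,) sequence positions
    head:   e.g. a SimpleHead owned by the Trainer
    """
    c = x[b_prev, t_prev]   # (M, n_embd), the same gather as before
    return head(c)          # (M, *out_dims) when head is a SimpleHead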
@@ -477,9 +428,6 @@ class SapDelphi(nn.Module):
         cate_dims: List[int],
         age_encoder_type: str = "sinusoidal",
         pdrop: float = 0.0,
         token_pdrop: float = 0.0,
-        n_dim: int = 1,
-        rank: int = 16,
         pretrained_weights_path: Optional[str] = None,  # new parameter
         freeze_embeddings: bool = False,  # new parameter; False (default) means the embeddings are fine-tuned
     ):
@@ -489,8 +437,6 @@ class SapDelphi(nn.Module):
         self.n_disease = n_disease
         self.n_embd = n_embd
         self.n_head = n_head
-        self.n_dim = n_dim
-        self.rank = rank
         if pretrained_weights_path is not None:
             print(
@@ -540,15 +486,22 @@ class SapDelphi(nn.Module):
         ])
         self.ln_f = nn.LayerNorm(n_embd)
         self.token_dropout = nn.Dropout(token_pdrop)
-        # Head layers
-        self.theta_proj = FactorizedHead(
-            n_embd=n_embd,
-            n_disease=n_disease,
-            n_dim=n_dim,
-            rank=rank,
+    def get_disease_embedding(self) -> torch.Tensor:
+        """Get disease token embeddings for head computation.
+
+        Returns:
+            (n_disease, n_embd) tensor of disease token embeddings.
+        """
+        device = self.token_embedding.weight.device
+        disease_ids = torch.arange(
+            self.n_tech_tokens,
+            self.n_tech_tokens + self.n_disease,
+            device=device,
         )
+        disease_embs = self.token_embedding(disease_ids)
+        disease_embs = self.emb_proj(disease_embs)
+        return disease_embs
 
     def forward(
         self,
@@ -557,8 +510,6 @@ class SapDelphi(nn.Module):
         sex: torch.Tensor,  # (B,)
         cont_seq: torch.Tensor,  # (B, Lc, n_cont)
         cate_seq: torch.Tensor,  # (B, Lc, n_cate)
-        b_prev: Optional[torch.Tensor] = None,  # (M,)
-        t_prev: Optional[torch.Tensor] = None,  # (M,)
     ) -> torch.Tensor:
         token_embds = self.token_embedding(event_seq)  # (B, L, Vocab_dim)
         token_embds = self.emb_proj(token_embds)  # (B, L, D)
@@ -587,22 +538,10 @@ class SapDelphi(nn.Module):
         final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
         x = final_embds + age_embds + sex_embds  # (B, L, D)
         x = self.token_dropout(x)
         attn_mask = _build_time_padding_mask(
             event_seq, time_seq)
         for block in self.blocks:
             x = block(x, attn_mask=attn_mask)
         x = self.ln_f(x)
-        if b_prev is not None and t_prev is not None:
-            M = b_prev.numel()
-            c = x[b_prev, t_prev]  # (M, D)
-            disease_embeddings_raw = self.token_embedding.weight[
-                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
-            ]  # (K, vocab_dim)
-            disease_embeddings = self.emb_proj(disease_embeddings_raw)
-            theta = self.theta_proj(c, disease_embeddings)
-            return theta
-        else:
-            return x
+        return x
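Taken together with the commit message, one plausible shape for the new training step is sketched below. The head, the loss, and the batch keys are all assumptions; the actual wiring lives in the Trainer and loss modules, which are part of this commit but not of this file's diff.

def training_step(model, head, loss_fn, batch):
    """Hypothetical Trainer step: model -> SimpleHead -> DiscreteTimeCIFNLLLoss.

    Assumes `batch` still carries the tensors the old forward() signature used,
    including the b_prev/t_prev gather indices plus discrete-time targets.
    """
    x = model(batch["event_seq"], batch["time_seq"], batch["sex"],
              batch["cont_seq"], batch["cate_seq"])           # (B, L, n_embd)
    c = x[batch["b_prev"], batch["t_prev"]]                   # (M, n_embd)
    theta = head(c)                                           # (M, n_disease, n_bins) if out_dims=[n_disease, n_bins]
    return loss_fn(theta, batch["event"], batch["bin_idx"])   # scalar loss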