Refactor loss functions and model architecture: replace PiecewiseExponentialLoss with DiscreteTimeCIFNLLLoss, update Trainer to use SimpleHead, and modify argument parsing for new loss type.
This commit is contained in:
145
model.py
145
model.py
@@ -259,64 +259,26 @@ class AutoDiscretization(nn.Module):
|
||||
return emb
|
||||
|
||||
|
||||
class FactorizedHead(nn.Module):
|
||||
class SimpleHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_embd: int,
|
||||
n_disease: int,
|
||||
n_dim: int,
|
||||
rank: int = 16,
|
||||
out_dims: List[int],
|
||||
):
|
||||
super().__init__()
|
||||
self.n_embd = n_embd
|
||||
self.n_disease = n_disease
|
||||
self.n_dim = n_dim
|
||||
self.rank = rank
|
||||
|
||||
self.disease_base_proj = nn.Sequential(
|
||||
nn.LayerNorm(n_embd),
|
||||
nn.Linear(n_embd, n_dim),
|
||||
self.out_dims = out_dims
|
||||
total_out_dims = np.prod(out_dims)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(n_embd, n_embd),
|
||||
nn.GELU(),
|
||||
nn.Linear(n_embd, total_out_dims),
|
||||
nn.LayerNorm(total_out_dims),
|
||||
)
|
||||
self.context_mod_proj = nn.Sequential(
|
||||
nn.LayerNorm(n_embd),
|
||||
nn.Linear(n_embd, rank, bias=False),
|
||||
)
|
||||
self.disease_mod_proj = nn.Sequential(
|
||||
nn.LayerNorm(n_embd),
|
||||
nn.Linear(n_embd, rank * n_dim, bias=False),
|
||||
)
|
||||
self.delta_scale = nn.Parameter(torch.tensor(1e-3))
|
||||
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
# init disease_base_proj: [LayerNorm, Linear]
|
||||
nn.init.normal_(self.disease_base_proj[1].weight, std=0.02)
|
||||
nn.init.zeros_(self.disease_base_proj[1].bias)
|
||||
|
||||
# init context_mod_proj: [LayerNorm, Linear(bias=False)]
|
||||
nn.init.zeros_(self.context_mod_proj[1].weight)
|
||||
|
||||
# init disease_mod_proj: [LayerNorm, Linear(bias=False)]
|
||||
nn.init.normal_(self.disease_mod_proj[1].weight, std=0.02)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
c: torch.Tensor, # (M, n_embd)
|
||||
disease_embedding, # (n_disease, n_embd)
|
||||
) -> torch.Tensor:
|
||||
M = c.shape[0]
|
||||
K = disease_embedding.shape[0]
|
||||
assert K == self.n_disease
|
||||
base_logits = self.disease_base_proj(disease_embedding) # (K, n_dim)
|
||||
base_logits = base_logits.unsqueeze(
|
||||
0).expand(M, -1, -1) # (M, K, n_dim)
|
||||
u = self.context_mod_proj(c)
|
||||
v = self.disease_mod_proj(disease_embedding)
|
||||
v = v.view(K, self.rank, self.n_dim)
|
||||
delta_logits = torch.einsum('mr, krd -> mkd', u, v)
|
||||
|
||||
return base_logits + self.delta_scale * delta_logits
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.net(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
return x.view(-1, *self.out_dims)
|
||||
|
||||
|
||||
def _build_time_padding_mask(
|
||||
@@ -363,9 +325,6 @@ class DelphiFork(nn.Module):
|
||||
cate_dims: List[int],
|
||||
age_encoder_type: str = "sinusoidal",
|
||||
pdrop: float = 0.0,
|
||||
token_pdrop: float = 0.0,
|
||||
n_dim: int = 1,
|
||||
rank: int = 16,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = n_disease + n_tech_tokens
|
||||
@@ -373,7 +332,6 @@ class DelphiFork(nn.Module):
|
||||
self.n_disease = n_disease
|
||||
self.n_embd = n_embd
|
||||
self.n_head = n_head
|
||||
self.n_dim = n_dim
|
||||
|
||||
self.token_embedding = nn.Embedding(
|
||||
self.vocab_size, n_embd, padding_idx=0)
|
||||
@@ -397,15 +355,21 @@ class DelphiFork(nn.Module):
|
||||
])
|
||||
|
||||
self.ln_f = nn.LayerNorm(n_embd)
|
||||
self.token_dropout = nn.Dropout(token_pdrop)
|
||||
|
||||
# Head layers
|
||||
self.theta_proj = FactorizedHead(
|
||||
n_embd=n_embd,
|
||||
n_disease=n_disease,
|
||||
n_dim=n_dim,
|
||||
rank=rank,
|
||||
def get_disease_embedding(self) -> torch.Tensor:
|
||||
"""Get disease token embeddings for head computation.
|
||||
|
||||
Returns:
|
||||
(n_disease, n_embd) tensor of disease token embeddings.
|
||||
"""
|
||||
device = self.token_embedding.weight.device
|
||||
disease_ids = torch.arange(
|
||||
self.n_tech_tokens,
|
||||
self.n_tech_tokens + self.n_disease,
|
||||
device=device,
|
||||
)
|
||||
disease_embs = self.token_embedding(disease_ids)
|
||||
return disease_embs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -414,8 +378,6 @@ class DelphiFork(nn.Module):
|
||||
sex: torch.Tensor, # (B,)
|
||||
cont_seq: torch.Tensor, # (B, Lc, n_cont)
|
||||
cate_seq: torch.Tensor, # (B, Lc, n_cate)
|
||||
b_prev: Optional[torch.Tensor] = None, # (M,)
|
||||
t_prev: Optional[torch.Tensor] = None, # (M,)
|
||||
) -> torch.Tensor:
|
||||
token_embds = self.token_embedding(event_seq) # (B, L, D)
|
||||
age_embds = self.age_encoder(time_seq) # (B, L, D)
|
||||
@@ -443,24 +405,13 @@ class DelphiFork(nn.Module):
|
||||
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
|
||||
|
||||
x = final_embds + age_embds + sex_embds # (B, L, D)
|
||||
x = self.token_dropout(x)
|
||||
attn_mask = _build_time_padding_mask(
|
||||
event_seq, time_seq)
|
||||
for block in self.blocks:
|
||||
x = block(x, attn_mask=attn_mask)
|
||||
x = self.ln_f(x)
|
||||
|
||||
if b_prev is not None and t_prev is not None:
|
||||
M = b_prev.numel()
|
||||
c = x[b_prev, t_prev] # (M, D)
|
||||
disease_embeddings = self.token_embedding.weight[
|
||||
self.n_tech_tokens: self.n_tech_tokens + self.n_disease
|
||||
]
|
||||
|
||||
theta = self.theta_proj(c, disease_embeddings)
|
||||
return theta
|
||||
else:
|
||||
return x
|
||||
return x
|
||||
|
||||
|
||||
class SapDelphi(nn.Module):
|
||||
@@ -477,9 +428,6 @@ class SapDelphi(nn.Module):
|
||||
cate_dims: List[int],
|
||||
age_encoder_type: str = "sinusoidal",
|
||||
pdrop: float = 0.0,
|
||||
token_pdrop: float = 0.0,
|
||||
n_dim: int = 1,
|
||||
rank: int = 16,
|
||||
pretrained_weights_path: Optional[str] = None, # 新增参数
|
||||
freeze_embeddings: bool = False, # 新增参数,默认为 False 表示微调
|
||||
):
|
||||
@@ -489,8 +437,6 @@ class SapDelphi(nn.Module):
|
||||
self.n_disease = n_disease
|
||||
self.n_embd = n_embd
|
||||
self.n_head = n_head
|
||||
self.n_dim = n_dim
|
||||
self.rank = rank
|
||||
|
||||
if pretrained_weights_path is not None:
|
||||
print(
|
||||
@@ -540,15 +486,22 @@ class SapDelphi(nn.Module):
|
||||
])
|
||||
|
||||
self.ln_f = nn.LayerNorm(n_embd)
|
||||
self.token_dropout = nn.Dropout(token_pdrop)
|
||||
|
||||
# Head layers
|
||||
self.theta_proj = FactorizedHead(
|
||||
n_embd=n_embd,
|
||||
n_disease=n_disease,
|
||||
n_dim=n_dim,
|
||||
rank=rank,
|
||||
def get_disease_embedding(self) -> torch.Tensor:
|
||||
"""Get disease token embeddings for head computation.
|
||||
|
||||
Returns:
|
||||
(n_disease, n_embd) tensor of disease token embeddings.
|
||||
"""
|
||||
device = self.token_embedding.weight.device
|
||||
disease_ids = torch.arange(
|
||||
self.n_tech_tokens,
|
||||
self.n_tech_tokens + self.n_disease,
|
||||
device=device,
|
||||
)
|
||||
disease_embs = self.token_embedding(disease_ids)
|
||||
disease_embs = self.emb_proj(disease_embs)
|
||||
return disease_embs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -557,8 +510,6 @@ class SapDelphi(nn.Module):
|
||||
sex: torch.Tensor, # (B,)
|
||||
cont_seq: torch.Tensor, # (B, Lc, n_cont)
|
||||
cate_seq: torch.Tensor, # (B, Lc, n_cate)
|
||||
b_prev: Optional[torch.Tensor] = None, # (M,)
|
||||
t_prev: Optional[torch.Tensor] = None, # (M,)
|
||||
) -> torch.Tensor:
|
||||
token_embds = self.token_embedding(event_seq) # (B, L, Vocab_dim)
|
||||
token_embds = self.emb_proj(token_embds) # (B, L, D)
|
||||
@@ -587,22 +538,10 @@ class SapDelphi(nn.Module):
|
||||
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
|
||||
|
||||
x = final_embds + age_embds + sex_embds # (B, L, D)
|
||||
x = self.token_dropout(x)
|
||||
attn_mask = _build_time_padding_mask(
|
||||
event_seq, time_seq)
|
||||
for block in self.blocks:
|
||||
x = block(x, attn_mask=attn_mask)
|
||||
x = self.ln_f(x)
|
||||
|
||||
if b_prev is not None and t_prev is not None:
|
||||
M = b_prev.numel()
|
||||
c = x[b_prev, t_prev] # (M, D)
|
||||
disease_embeddings_raw = self.token_embedding.weight[
|
||||
self.n_tech_tokens: self.n_tech_tokens + self.n_disease
|
||||
] # (K, vocab_dim)
|
||||
|
||||
disease_embeddings = self.emb_proj(disease_embeddings_raw)
|
||||
theta = self.theta_proj(c, disease_embeddings)
|
||||
return theta
|
||||
else:
|
||||
return x
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user