Refactor data preparation and add loss functions for model training
- Removed `prepare_data.py` as it is no longer needed.
- Introduced `losses.py`, containing ExponentialNLLLoss and WeibullLosses classes for computing negative log-likelihood losses with regularization.
- Added `model.py`, which defines the DelphiFork model architecture, including a tabular encoder for continuous and categorical features and a helper that merges two sequences in time order. A minimal usage sketch follows the diff below.
model.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn

from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block, ModernBlock, RMSNorm


class TabularEncoder(nn.Module):
    """Embeds continuous and categorical covariates into a shared n_embd space."""

    def __init__(
        self,
        n_embd: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
    ):
        super().__init__()
        self.continuous_proj = nn.Linear(n_continuous, n_embd) if n_continuous > 0 else None
        self.categorical_embeddings = nn.ModuleList([
            nn.Embedding(cardinality, n_embd) for cardinality in categorical_cardinalities
        ]) if n_categorical > 0 else None

    def forward(
        self,
        continuous_features: torch.Tensor | None,
        categorical_features: list[torch.Tensor] | None,
    ) -> torch.Tensor:
        embeddings = []
        if self.continuous_proj is not None and continuous_features is not None:
            cont_emb = self.continuous_proj(continuous_features)
            embeddings.append(cont_emb)
        if self.categorical_embeddings is not None and categorical_features is not None:
            for emb_layer, cat_feat in zip(self.categorical_embeddings, categorical_features):
                cat_emb = emb_layer(cat_feat)
                embeddings.append(cat_emb)
        if embeddings:
            # Sum the per-feature embeddings into a single representation.
            return torch.sum(torch.stack(embeddings, dim=0), dim=0)
        else:
            raise ValueError("No features provided for TabularEncoder.")


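# Shape sketch (illustrative values, not part of the original commit):
#   enc = TabularEncoder(n_embd=64, n_continuous=3, n_categorical=2,
#                        categorical_cardinalities=[5, 10])
#   out = enc(torch.randn(4, 3), [torch.zeros(4, dtype=torch.long)] * 2)
# out has shape (4, 64): the continuous projection and the two categorical
# embeddings are summed into one vector per row.
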
def merge_two_sequences(
    time_seq1: torch.Tensor,  # (B, L1)
    time_seq2: torch.Tensor,  # (B, L2)
    seq1_embd: torch.Tensor,  # (B, L1, D)
    seq2_embd: torch.Tensor,  # (B, L2, D)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Merge two time sequences and their embeddings based on time order."""
    B, L1 = time_seq1.shape
    L2 = time_seq2.shape[1]
    merged_times = torch.cat([time_seq1, time_seq2], dim=1)  # (B, L1 + L2)
    merged_embd = torch.cat([seq1_embd, seq2_embd], dim=1)  # (B, L1 + L2, D)

    sorted_times, indices = torch.sort(merged_times, dim=1)  # (B, L1 + L2)
    # Keep batch indices on the same device as the inputs so the advanced indexing below works on GPU.
    batch_indices = torch.arange(B, device=time_seq1.device).unsqueeze(-1).expand(-1, L1 + L2)  # (B, L1 + L2)
    sorted_embd = merged_embd[batch_indices, indices]  # (B, L1 + L2, D)

    return sorted_times, sorted_embd


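# Worked example (illustrative numbers only): with time_seq1 = [[0., 2.]] and
# time_seq2 = [[1., 3.]], sorted_times is [[0., 1., 2., 3.]] and sorted_embd
# holds the concatenated embeddings reordered so that position i carries the
# embedding of the i-th earliest timestamp.
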
class DelphiFork(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
        pdrop: float = 0.1,
        token_pdrop: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
        self.sex_encoder = nn.Embedding(2, n_embd)
        self.token_dropout = nn.Dropout(token_pdrop)
        self.covariate_encoder = TabularEncoder(
            n_embd=n_embd,
            n_continuous=n_continuous,
            n_categorical=n_categorical,
            categorical_cardinalities=categorical_cardinalities,
        )

        self.blocks = nn.ModuleList([
            Block(
                n_embd=n_embd,
                n_head=n_head,
                pdrop=pdrop,
            ) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output head shares its weight matrix with the token embedding.
        self.head.weight = self.token_embedding.weight

    def forward(
        self,
        sex: torch.Tensor,
        event_seq: torch.Tensor,
        age_seq: torch.Tensor,
        cov_seq_time: torch.Tensor | None = None,
        cont_cov_seq: torch.Tensor | None = None,
        cat_cov_seq: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        event_emb = self.token_embedding(event_seq)
        age_emb = self.age_encoder(age_seq)
        sex_emb = self.sex_encoder(sex.unsqueeze(-1))  # (B, 1) -> (B, 1, n_embd)

        x = event_emb + age_emb + sex_emb
        if cov_seq_time is not None:
            covariate_emb = self.covariate_encoder(
                continuous_features=cont_cov_seq,
                categorical_features=cat_cov_seq,
            )
            covariate_emb = covariate_emb + self.age_encoder(cov_seq_time) + sex_emb
            # merge_two_sequences returns (sorted_times, sorted_embd);
            # only the time-ordered embeddings are needed downstream.
            _, x = merge_two_sequences(age_seq, cov_seq_time, x, covariate_emb)

        x = self.token_dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits
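For reference, a minimal forward-pass sketch of the new model. The hyperparameters and shapes below are illustrative assumptions rather than values from this commit, and it presumes AgeSinusoidalEncoder maps a (B, L) age tensor to (B, L, n_embd) and that Block preserves the sequence shape, as the forward pass above implies:

import torch
from model import DelphiFork

model = DelphiFork(
    vocab_size=100, n_embd=64, n_head=4, n_layer=2,
    n_continuous=3, n_categorical=2, categorical_cardinalities=[5, 10],
)
B, L = 8, 16  # illustrative batch size and event-sequence length
logits = model(
    sex=torch.randint(0, 2, (B,)),            # binary sex indicator per patient
    event_seq=torch.randint(0, 100, (B, L)),  # event token ids
    age_seq=torch.rand(B, L) * 80.0,          # age/time of each event
)
print(logits.shape)  # expected: torch.Size([8, 16, 100]), one score per vocabulary token and position

Without covariate inputs the merge path is skipped, so the output length matches the event sequence; passing cov_seq_time together with cont_cov_seq/cat_cov_seq would interleave the covariate embeddings by time and lengthen the sequence accordingly.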