- Removed `prepare_data.py` as it is no longer needed.
- Introduced `losses.py` containing `ExponentialNLLLoss` and `WeibullLosses` classes for computing negative log-likelihood losses with regularization; a rough sketch of the exponential loss follows below.
- Added `model.py`, which defines the DelphiFork model architecture, including a tabular encoder for continuous and categorical features and a merge of event and covariate sequences by time order.
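
`losses.py` itself is not shown in this view. As orientation, a minimal sketch of what `ExponentialNLLLoss` might look like, assuming a log-rate parameterization and an L2 penalty as the regularizer; only the class name comes from the change notes, everything else is an assumption:

    import torch
    import torch.nn as nn

    class ExponentialNLLLoss(nn.Module):
        """NLL of an exponential time-to-event model driven by a predicted log-rate.

        Hypothetical sketch; the parameterization, reduction, and regularizer
        are assumptions, not the repository's actual implementation.
        """

        def __init__(self, reg_weight: float = 0.0):
            super().__init__()
            self.reg_weight = reg_weight  # strength of the (assumed) L2 penalty

        def forward(self, log_rate: torch.Tensor, time_to_event: torch.Tensor) -> torch.Tensor:
            # For p(t) = rate * exp(-rate * t): -log p(t) = rate * t - log(rate)
            nll = torch.exp(log_rate) * time_to_event - log_rate
            # Assumed regularizer: penalize large |log_rate| to keep rates near 1
            reg = self.reg_weight * log_rate.pow(2).mean()
            return nll.mean() + reg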
model.py (Python):
import torch
import torch.nn as nn

from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block, ModernBlock, RMSNorm


class TabularEncoder(nn.Module):
    """Embeds continuous and categorical covariates into a shared space.

    Continuous features are linearly projected; each categorical feature
    gets its own embedding table. All resulting embeddings are summed.
    """

    def __init__(
        self,
        n_embd: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
    ):
        super().__init__()
        self.continuous_proj = nn.Linear(n_continuous, n_embd) if n_continuous > 0 else None
        self.categorical_embeddings = nn.ModuleList([
            nn.Embedding(cardinality, n_embd) for cardinality in categorical_cardinalities
        ]) if n_categorical > 0 else None

    def forward(
        self,
        continuous_features: torch.Tensor | None,
        categorical_features: list[torch.Tensor] | None,
    ) -> torch.Tensor:
        embeddings = []
        if self.continuous_proj is not None and continuous_features is not None:
            cont_emb = self.continuous_proj(continuous_features)
            embeddings.append(cont_emb)
        if self.categorical_embeddings is not None and categorical_features is not None:
            for emb_layer, cat_feat in zip(self.categorical_embeddings, categorical_features):
                cat_emb = emb_layer(cat_feat)
                embeddings.append(cat_emb)
        if embeddings:
            return torch.sum(torch.stack(embeddings, dim=0), dim=0)
        raise ValueError("No features provided for TabularEncoder.")
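
# Illustrative usage (hypothetical shapes, not from the original file):
#
#     enc = TabularEncoder(n_embd=64, n_continuous=2, n_categorical=1,
#                          categorical_cardinalities=[5])
#     cont = torch.randn(8, 10, 2)          # (B, L, n_continuous)
#     cat = [torch.randint(0, 5, (8, 10))]  # one (B, L) index tensor per table
#     out = enc(cont, cat)                  # (8, 10, 64)
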

def merge_two_sequences(
    time_seq1: torch.Tensor,  # (B, L1)
    time_seq2: torch.Tensor,  # (B, L2)
    seq1_embd: torch.Tensor,  # (B, L1, D)
    seq2_embd: torch.Tensor,  # (B, L2, D)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Merge two time sequences and their embeddings based on time order."""
    B, L1 = time_seq1.shape
    L2 = time_seq2.shape[1]
    merged_times = torch.cat([time_seq1, time_seq2], dim=1)  # (B, L1 + L2)
    merged_embd = torch.cat([seq1_embd, seq2_embd], dim=1)  # (B, L1 + L2, D)

    sorted_times, indices = torch.sort(merged_times, dim=1)  # (B, L1 + L2)
    # Build the batch index on the same device as the inputs so the gather works on GPU.
    batch_indices = torch.arange(B, device=merged_embd.device).unsqueeze(-1).expand(-1, L1 + L2)
    sorted_embd = merged_embd[batch_indices, indices]  # (B, L1 + L2, D)

    return sorted_times, sorted_embd
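
# Illustrative example (hypothetical values): interleave two length-2
# sequences for a single batch element by ascending time.
#
#     t1 = torch.tensor([[1.0, 4.0]])
#     t2 = torch.tensor([[2.0, 3.0]])
#     e1 = torch.zeros(1, 2, 8)
#     e2 = torch.ones(1, 2, 8)
#     times, embd = merge_two_sequences(t1, t2, e1, e2)
#     # times -> [[1., 2., 3., 4.]]; embd rows are reordered the same way:
#     # zeros, ones, ones, zeros.
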

class DelphiFork(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
        pdrop: float = 0.1,
        token_pdrop: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
        self.sex_encoder = nn.Embedding(2, n_embd)
        self.token_dropout = nn.Dropout(token_pdrop)
        self.covariate_encoder = TabularEncoder(
            n_embd=n_embd,
            n_continuous=n_continuous,
            n_categorical=n_categorical,
            categorical_cardinalities=categorical_cardinalities,
        )

        self.blocks = nn.ModuleList([
            Block(
                n_embd=n_embd,
                n_head=n_head,
                pdrop=pdrop,
            ) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output head shares its weight with the token embedding.
        self.head.weight = self.token_embedding.weight

    def forward(
        self,
        sex: torch.Tensor,
        event_seq: torch.Tensor,
        age_seq: torch.Tensor,
        cov_seq_time: torch.Tensor | None = None,
        cont_cov_seq: torch.Tensor | None = None,
        cat_cov_seq: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        event_emb = self.token_embedding(event_seq)
        age_emb = self.age_encoder(age_seq)
        sex_emb = self.sex_encoder(sex.unsqueeze(-1))  # (B,) -> (B, 1, n_embd)

        x = event_emb + age_emb + sex_emb
        if cov_seq_time is not None:
            covariate_emb = self.covariate_encoder(
                continuous_features=cont_cov_seq,
                categorical_features=cat_cov_seq,
            )
            covariate_emb = covariate_emb + self.age_encoder(cov_seq_time) + sex_emb
            # merge_two_sequences returns (times, embeddings); keep only the embeddings.
            _, x = merge_two_sequences(age_seq, cov_seq_time, x, covariate_emb)

        x = self.token_dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits
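
# Illustrative forward pass (all hyperparameters are placeholders, not from
# the original file; output shapes assume AgeSinusoidalEncoder and Block
# preserve (B, L, n_embd)):
#
#     model = DelphiFork(vocab_size=100, n_embd=64, n_head=4, n_layer=2,
#                        n_continuous=2, n_categorical=1,
#                        categorical_cardinalities=[5])
#     sex = torch.randint(0, 2, (4,))            # (B,)
#     events = torch.randint(0, 100, (4, 16))    # (B, L) event tokens
#     ages = torch.rand(4, 16)                   # (B, L) event ages
#     logits = model(sex, events, ages)          # (4, 16, 100)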