Refactor data preparation and add loss functions for model training

- Removed `prepare_data.py`, which is no longer needed.
- Introduced `losses.py` with ExponentialNLLLoss and WeibullLosses classes for computing negative log-likelihood losses with regularization (a rough sketch follows below).
- Added `model.py`, which defines the DelphiFork model architecture: a tabular encoder for continuous and categorical covariates, plus a helper that merges two sequences in time order.
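
losses.py itself is not shown in this diff. As a rough, non-authoritative sketch of what an exponential negative log-likelihood with regularization might look like (the signature and the l2_reg knob are assumptions, not the committed code):

import torch
import torch.nn as nn

class ExponentialNLLLoss(nn.Module):
    """Sketch: NLL of an exponential time-to-event model with hazard = exp(log_rate)."""

    def __init__(self, l2_reg: float = 0.0):
        super().__init__()
        self.l2_reg = l2_reg

    def forward(self, log_rate: torch.Tensor, time: torch.Tensor, event: torch.Tensor) -> torch.Tensor:
        rate = torch.exp(log_rate)
        # Observed events (event == 1) contribute -log f(t) = -log_rate + rate * t;
        # censored samples (event == 0) contribute -log S(t) = rate * t.
        nll = rate * time - event * log_rate
        # Simple L2 penalty on the log-rates as a stand-in for the "regularization"
        # mentioned in the commit message.
        return nll.mean() + self.l2_reg * log_rate.pow(2).mean()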
commit cb7adb70d9
parent 9ca8909e3a
2025-12-05 00:54:56 +08:00
6 changed files with 445 additions and 1486 deletions

model.py (new file, 129 lines)

@@ -0,0 +1,129 @@
import torch
import torch.nn as nn

from age_encoder import AgeSinusoidalEncoder
from backbones import Block

class TabularEncoder(nn.Module):
    """Encodes continuous and categorical covariates into a shared embedding space."""

    def __init__(
        self,
        n_embd: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
    ):
        super().__init__()
        # Single linear projection shared across all continuous features.
        self.continuous_proj = nn.Linear(n_continuous, n_embd) if n_continuous > 0 else None
        # One embedding table per categorical feature.
        self.categorical_embeddings = nn.ModuleList([
            nn.Embedding(cardinality, n_embd) for cardinality in categorical_cardinalities
        ]) if n_categorical > 0 else None

    def forward(
        self,
        continuous_features: torch.Tensor | None,
        categorical_features: list[torch.Tensor] | None,
    ) -> torch.Tensor:
        embeddings = []
        if self.continuous_proj is not None and continuous_features is not None:
            embeddings.append(self.continuous_proj(continuous_features))
        if self.categorical_embeddings is not None and categorical_features is not None:
            for emb_layer, cat_feat in zip(self.categorical_embeddings, categorical_features):
                embeddings.append(emb_layer(cat_feat))
        if not embeddings:
            raise ValueError("No features provided for TabularEncoder.")
        # Sum all feature embeddings into a single representation.
        return torch.sum(torch.stack(embeddings, dim=0), dim=0)
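
# Shape note (illustration, not part of the committed file): with inputs of
# shape (B, L), n_continuous = 3, and cardinalities [5, 7], the encoder maps
# the (B, L, 3) continuous block and each (B, L) categorical column to
# (B, L, n_embd) tensors and sums them into a single (B, L, n_embd) output.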

def merge_two_sequences(
    time_seq1: torch.Tensor,  # (B, L1)
    time_seq2: torch.Tensor,  # (B, L2)
    seq1_embd: torch.Tensor,  # (B, L1, D)
    seq2_embd: torch.Tensor,  # (B, L2, D)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Merge two time sequences and their embeddings into one time-ordered sequence."""
    B, L1 = time_seq1.shape
    L2 = time_seq2.shape[1]
    merged_times = torch.cat([time_seq1, time_seq2], dim=1)  # (B, L1 + L2)
    merged_embd = torch.cat([seq1_embd, seq2_embd], dim=1)  # (B, L1 + L2, D)
    sorted_times, indices = torch.sort(merged_times, dim=1)  # (B, L1 + L2)
    # Gather embeddings into sorted time order; create indices on the input device.
    batch_indices = torch.arange(B, device=time_seq1.device).unsqueeze(-1).expand(-1, L1 + L2)
    sorted_embd = merged_embd[batch_indices, indices]  # (B, L1 + L2, D)
    return sorted_times, sorted_embd
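
# Worked example (illustration, not in the committed file): for
#   time_seq1 = [[0., 2.]] and time_seq2 = [[1.]],
# sorted_times is [[0., 1., 2.]] and sorted_embd interleaves the rows as
# [seq1_embd[:, 0], seq2_embd[:, 0], seq1_embd[:, 1]].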

class DelphiFork(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
        pdrop: float = 0.1,
        token_pdrop: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
        self.sex_encoder = nn.Embedding(2, n_embd)
        self.token_dropout = nn.Dropout(token_pdrop)
        self.covariate_encoder = TabularEncoder(
            n_embd=n_embd,
            n_continuous=n_continuous,
            n_categorical=n_categorical,
            categorical_cardinalities=categorical_cardinalities,
        )
        self.blocks = nn.ModuleList([
            Block(
                n_embd=n_embd,
                n_head=n_head,
                pdrop=pdrop,
            ) for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output head shares its weights with the token embedding.
        self.head.weight = self.token_embedding.weight
    def forward(
        self,
        sex: torch.Tensor,
        event_seq: torch.Tensor,
        age_seq: torch.Tensor,
        cov_seq_time: torch.Tensor | None = None,
        cont_cov_seq: torch.Tensor | None = None,
        cat_cov_seq: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        event_emb = self.token_embedding(event_seq)
        age_emb = self.age_encoder(age_seq)
        sex_emb = self.sex_encoder(sex.unsqueeze(-1))  # (B, 1) -> (B, 1, n_embd), broadcasts over length
        x = event_emb + age_emb + sex_emb
        if cov_seq_time is not None:
            covariate_emb = self.covariate_encoder(
                continuous_features=cont_cov_seq,
                categorical_features=cat_cov_seq,
            )
            covariate_emb = covariate_emb + self.age_encoder(cov_seq_time) + sex_emb
            # merge_two_sequences returns (sorted_times, sorted_embd); keep only the embeddings.
            _, x = merge_two_sequences(age_seq, cov_seq_time, x, covariate_emb)
        x = self.token_dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
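
For reference, a minimal smoke test of the forward pass. The hyperparameters and shapes below are illustrative assumptions, not values from the commit, and `age_encoder` and `backbones` must be importable as in the file above:

import torch
from model import DelphiFork

model = DelphiFork(
    vocab_size=100, n_embd=64, n_head=4, n_layer=2,
    n_continuous=3, n_categorical=2, categorical_cardinalities=[5, 7],
)

B, L1, L2 = 2, 10, 4
sex = torch.randint(0, 2, (B,))
event_seq = torch.randint(0, 100, (B, L1))
age_seq = torch.sort(torch.rand(B, L1) * 80.0, dim=1).values  # ascending ages
cov_seq_time = torch.rand(B, L2) * 80.0
cont_cov_seq = torch.randn(B, L2, 3)
cat_cov_seq = [torch.randint(0, 5, (B, L2)), torch.randint(0, 7, (B, L2))]

logits = model(sex, event_seq, age_seq, cov_seq_time, cont_cov_seq, cat_cov_seq)
print(logits.shape)  # expected: (B, L1 + L2, vocab_size) = (2, 14, 100)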