import torch
import torch.nn as nn

from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block, ModernBlock, RMSNorm


class TabularEncoder(nn.Module):
    """Embeds continuous and categorical covariates into a shared space.

    Continuous features are projected jointly by a single linear layer;
    each categorical feature gets its own embedding table. All resulting
    embeddings are summed.
    """

    def __init__(
        self,
        n_embd: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
    ):
        super().__init__()
        self.continuous_proj = (
            nn.Linear(n_continuous, n_embd) if n_continuous > 0 else None
        )
        self.categorical_embeddings = (
            nn.ModuleList(
                [
                    nn.Embedding(cardinality, n_embd)
                    for cardinality in categorical_cardinalities
                ]
            )
            if n_categorical > 0
            else None
        )

    def forward(
        self,
        continuous_features: torch.Tensor | None,
        categorical_features: list[torch.Tensor] | None,
    ) -> torch.Tensor:
        embeddings = []
        if self.continuous_proj is not None and continuous_features is not None:
            cont_emb = self.continuous_proj(continuous_features)
            embeddings.append(cont_emb)
        if self.categorical_embeddings is not None and categorical_features is not None:
            for emb_layer, cat_feat in zip(self.categorical_embeddings, categorical_features):
                cat_emb = emb_layer(cat_feat)
                embeddings.append(cat_emb)
        if embeddings:
            return torch.sum(torch.stack(embeddings, dim=0), dim=0)
        raise ValueError("No features provided for TabularEncoder.")


def merge_two_sequences(
    time_seq1: torch.Tensor,  # (B, L1)
    time_seq2: torch.Tensor,  # (B, L2)
    seq1_embd: torch.Tensor,  # (B, L1, D)
    seq2_embd: torch.Tensor,  # (B, L2, D)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Merge two time sequences and their embeddings based on time order."""
    B, L1 = time_seq1.shape
    L2 = time_seq2.shape[1]
    merged_times = torch.cat([time_seq1, time_seq2], dim=1)  # (B, L1 + L2)
    merged_embd = torch.cat([seq1_embd, seq2_embd], dim=1)  # (B, L1 + L2, D)
    # Stable sort keeps seq1 entries ahead of seq2 entries on time ties.
    sorted_times, indices = torch.sort(merged_times, dim=1, stable=True)  # (B, L1 + L2)
    batch_indices = (
        torch.arange(B, device=merged_times.device)
        .unsqueeze(-1)
        .expand(-1, L1 + L2)
    )  # (B, L1 + L2)
    sorted_embd = merged_embd[batch_indices, indices]  # (B, L1 + L2, D)
    return sorted_times, sorted_embd
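
# Worked example for merge_two_sequences (a sketch; the tensor values are
# illustrative only):
#
#     times1 = torch.tensor([[1.0, 3.0]])   # (1, 2) event ages
#     times2 = torch.tensor([[2.0]])        # (1, 1) covariate age
#     emb1 = torch.zeros(1, 2, 4)           # (1, 2, 4) event embeddings
#     emb2 = torch.ones(1, 1, 4)            # (1, 1, 4) covariate embedding
#     t, e = merge_two_sequences(times1, times2, emb1, emb2)
#     # t == [[1.0, 2.0, 3.0]]; e is [zeros, ones, zeros] row-wise,
#     # i.e. the covariate embedding is interleaved at its time position.
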
class DelphiFork(nn.Module):
    """GPT-style event-sequence model conditioned on age, sex, and covariates."""

    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
        pdrop: float = 0.1,
        token_pdrop: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
        self.sex_encoder = nn.Embedding(2, n_embd)
        self.token_dropout = nn.Dropout(token_pdrop)
        self.covariate_encoder = TabularEncoder(
            n_embd=n_embd,
            n_continuous=n_continuous,
            n_categorical=n_categorical,
            categorical_cardinalities=categorical_cardinalities,
        )
        self.blocks = nn.ModuleList(
            [
                Block(n_embd=n_embd, n_head=n_head, pdrop=pdrop)
                for _ in range(n_layer)
            ]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output head shares the token embedding matrix.
        self.head.weight = self.token_embedding.weight

    def forward(
        self,
        sex: torch.Tensor,
        event_seq: torch.Tensor,
        age_seq: torch.Tensor,
        cov_seq_time: torch.Tensor | None = None,
        cont_cov_seq: torch.Tensor | None = None,
        cat_cov_seq: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        event_emb = self.token_embedding(event_seq)
        age_emb = self.age_encoder(age_seq)
        sex_emb = self.sex_encoder(sex.unsqueeze(-1))  # (B, 1) -> (B, 1, n_embd)
        x = event_emb + age_emb + sex_emb
        if cov_seq_time is not None:
            covariate_emb = self.covariate_encoder(
                continuous_features=cont_cov_seq,
                categorical_features=cat_cov_seq,
            )
            covariate_emb = covariate_emb + self.age_encoder(cov_seq_time) + sex_emb
            # Interleave covariate embeddings with event embeddings by time;
            # only the merged embeddings are needed downstream.
            _, x = merge_two_sequences(age_seq, cov_seq_time, x, covariate_emb)
        x = self.token_dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
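
# Minimal smoke test, a sketch only. It assumes AgeSinusoidalEncoder(n_embd=...)
# maps a (B, L) float age tensor to (B, L, n_embd), and that Block(n_embd,
# n_head, pdrop) maps (B, L, n_embd) to (B, L, n_embd); both come from this
# repo's age_encoder / backbones modules, whose exact interfaces are not shown
# here. All shapes and hyperparameters below are arbitrary.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = DelphiFork(
        vocab_size=32,
        n_embd=16,
        n_head=4,
        n_layer=2,
        n_continuous=3,
        n_categorical=2,
        categorical_cardinalities=[5, 7],
    )
    B, L, Lc = 2, 10, 4
    sex = torch.randint(0, 2, (B,))                      # (B,)
    event_seq = torch.randint(0, 32, (B, L))             # (B, L) event tokens
    age_seq = torch.rand(B, L).sort(dim=1).values        # (B, L) ascending ages
    cov_seq_time = torch.rand(B, Lc).sort(dim=1).values  # (B, Lc) covariate times
    cont_cov_seq = torch.randn(B, Lc, 3)                 # (B, Lc, n_continuous)
    cat_cov_seq = [
        torch.randint(0, 5, (B, Lc)),                    # cardinality-5 feature
        torch.randint(0, 7, (B, Lc)),                    # cardinality-7 feature
    ]
    logits = model(sex, event_seq, age_seq, cov_seq_time, cont_cov_seq, cat_cov_seq)
    # Covariates are interleaved into the event sequence, so the output length
    # is L + Lc: expected torch.Size([2, 14, 32]).
    print(logits.shape)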