import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple

# =============================================================================
# 1. Component Modules (Building Blocks)
# =============================================================================

class Block(nn.Module):
    """ An unassuming Transformer block. """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        self.n_head = n_head
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=pdrop, batch_first=True)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))  # MLP forward
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        # custom_mask is (B, L, L) with True meaning "query i may attend to key j".
        # nn.MultiheadAttention uses the opposite convention (True = masked out),
        # so invert it and expand to one copy per attention head.
        normed_x = self.ln_1(x)
        attn_mask = ~custom_mask
        attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0)
        attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
        x = x + self.resid_dropout(attn_output)
        x = x + self.mlpf(self.ln_2(x))
        return x


class AgeSinusoidalEncoding(nn.Module):
    """
    Encodes age using sinusoidal functions, similar to positional encodings in
    Transformers. This module creates a fixed-size embedding for an age value
    given in days.
    """

    def __init__(self, embedding_dim: int):
        """
        Initializes the AgeSinusoidalEncoding module.

        Args:
            embedding_dim (int): The dimensionality of the output embedding.
                Must be an even number.

        Raises:
            ValueError: If embedding_dim is not an even number.
        """
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(f"Embedding dimension must be an even number, but got {embedding_dim}")
        self.embedding_dim = embedding_dim
        # Pre-calculate the divisor term of the sinusoidal formula, 10000^(i / d),
        # as in the standard Transformer positional encoding.
        i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.embedding_dim)
        self.register_buffer('divisor', divisor)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the AgeSinusoidalEncoding.

        Args:
            t (torch.Tensor): A tensor of shape (batch_size, sequence_length) with
                dtype=torch.float32, representing age in days.

        Returns:
            torch.Tensor: The encoded age tensor of shape
                (batch_size, sequence_length, embedding_dim).
        """
        # Convert days to years and divide by the frequency divisor, giving
        # geometrically spaced wavelengths across the embedding dimensions.
        t_years = t / 365.25
        args = t_years.unsqueeze(-1) / self.divisor.view(1, 1, -1)
        output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)
        output[:, :, 0::2] = torch.cos(args)
        output[:, :, 1::2] = torch.sin(args)
        return output
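# Illustrative sketch (not part of the model): a quick shape check for
# AgeSinusoidalEncoding. The batch size, sequence length, and embedding size
# below are arbitrary example values, not values used elsewhere in this file.
def _example_age_encoding() -> None:
    encoder = AgeSinusoidalEncoding(embedding_dim=16)
    ages_in_days = torch.tensor([[0.0, 365.25, 7305.0]])  # (B=1, L=3): birth, 1 year, ~20 years
    out = encoder(ages_in_days)
    assert out.shape == (1, 3, 16)  # (batch, sequence, embedding_dim)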
""" super().__init__() if num_bins <= 0: raise ValueError("num_bins must be a positive integer.") self.num_bins = num_bins self.embedding_dim = embedding_dim if num_bins > 1: quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1) normal_dist = torch.distributions.normal.Normal(0, 1) boundaries = normal_dist.icdf(quantiles) else: boundaries = torch.tensor([]) boundaries = torch.cat([ torch.tensor([float('-inf')]), boundaries, torch.tensor([float('inf')]) ]) self.register_buffer('boundaries', boundaries) self.linear = nn.Linear(num_bins, embedding_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the piecewise linear encoding. Args: x (torch.Tensor): Input tensor of shape (*, N), where * is any number of batch dimensions and N is the number of continuous features. Assumed to be pre-scaled. Returns: torch.Tensor: Encoded tensor of shape (*, N, D). """ original_shape = x.shape x = x.reshape(-1, original_shape[-1]) bin_indices = torch.searchsorted(self.boundaries, x, right=True) - 1 bin_indices = bin_indices.clamp(0, self.num_bins - 1) lower_bounds = self.boundaries[bin_indices] upper_bounds = self.boundaries[bin_indices + 1] delta = upper_bounds - lower_bounds + 1e-8 weight_upper = (x - lower_bounds) / delta weight_lower = 1.0 - weight_upper is_first_bin = (bin_indices == 0) is_last_bin = (bin_indices == self.num_bins - 1) weight_lower[is_first_bin] = 1.0 weight_upper[is_first_bin] = 0.0 weight_lower[is_last_bin] = 0.0 weight_upper[is_last_bin] = 1.0 encoded = torch.zeros(*x.shape, self.num_bins, device=x.device, dtype=x.dtype) encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1)) upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1) encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1)) encoded = encoded.view(*original_shape, self.num_bins) output = self.linear(encoded) return output # ============================================================================= # 2. Main Model Architectures # ============================================================================= class TimeAwareGPT2(nn.Module): """ A time-aware GPT-2 model with custom temporal features. """ def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float): super().__init__() self.token_pdrop = token_pdrop self.wte = nn.Embedding(vocab_size, n_embd) self.age_encoder = AgeSinusoidalEncoding(n_embd) self.drop = nn.Dropout(pdrop) self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)]) self.ln_f = nn.LayerNorm(n_embd) self.head = nn.Linear(n_embd, vocab_size, bias=False) self.n_embd = n_embd def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor: """ Forward pass for the TimeAwareGPT2 model. Args: event_seq (torch.Tensor): Token indices of shape (B, L). time_seq (torch.Tensor): Timestamps for each event of shape (B, L). Returns: torch.Tensor: Logits of shape (B, L, vocab_size). 
""" B, L = event_seq.size() token_embeddings = self.wte(event_seq) if self.training and self.token_pdrop > 0: drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop token_embeddings[drop_mask] = 0.0 pos_embeddings = self.age_encoder(time_seq.float()) x = self.drop(token_embeddings + pos_embeddings) t_i = time_seq.unsqueeze(-1) t_j = time_seq.unsqueeze(1) time_mask = (t_j < t_i) padding_mask = (event_seq != 0).unsqueeze(1) combined_mask = time_mask & padding_mask is_row_all_zero = ~combined_mask.any(dim=-1) is_not_padding = (event_seq != 0) force_self_attention = is_row_all_zero & is_not_padding combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True for block in self.blocks: x = block(x, custom_mask=combined_mask) x = self.ln_f(x) logits = self.head(x) return logits def get_num_params(self) -> float: """ Returns the number of trainable parameters in the model in millions. """ return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 class CovariateAwareGPT2(nn.Module): """ Extends TimeAwareGPT2 to incorporate static and time-varying covariates. """ def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, num_bins: int): """ Initializes the CovariateAwareGPT2 model. Args: vocab_size (int): Size of the event vocabulary. n_embd (int): Embedding dimensionality. n_layer (int): Number of transformer layers. n_head (int): Number of attention heads. pdrop (float): Dropout probability for layers. token_pdrop (float): Dropout probability for input token embeddings. num_bins (int): Number of bins for the PiecewiseLinearEncoder. """ super().__init__() self.token_pdrop = token_pdrop self.wte = nn.Embedding(vocab_size, n_embd) self.age_encoder = AgeSinusoidalEncoding(n_embd) self.drop = nn.Dropout(pdrop) self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)]) self.n_embd = n_embd self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd) self.ln_f = nn.LayerNorm(2 * n_embd) self.head = nn.Sequential( nn.Linear(2 * n_embd, n_embd), nn.GELU(), nn.Linear(n_embd, vocab_size) ) def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor: """ Forward pass for the CovariateAwareGPT2 model. Args: x (torch.Tensor): Event sequence tensor of shape (B, L). t (torch.Tensor): Time sequence tensor of shape (B, L). cov (torch.Tensor): Covariate tensor of shape (B, N). cov_t (torch.Tensor): Covariate time tensor of shape (B). Returns: torch.Tensor: Logits of shape (B, L, vocab_size). 
""" B, L = x.size() cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1) cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1)) cov_embed = cov_encoded + cov_t_encoded token_embeddings = self.wte(x) if self.training and self.token_pdrop > 0: drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop token_embeddings[drop_mask] = 0.0 pos_embeddings = self.age_encoder(t.float()) seq_embed = self.drop(token_embeddings + pos_embeddings) t_i = t.unsqueeze(-1) t_j = t.unsqueeze(1) time_mask = (t_j < t_i) padding_mask = (x != 0).unsqueeze(1) combined_mask = time_mask & padding_mask is_row_all_zero = ~combined_mask.any(dim=-1) is_not_padding = (x != 0) force_self_attention = is_row_all_zero & is_not_padding combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True block_output = seq_embed for block in self.blocks: block_output = block(block_output, custom_mask=combined_mask) integrated_embed = torch.cat([block_output, cov_embed], dim=-1) final_output = self.ln_f(integrated_embed) logits = self.head(final_output) return logits def get_num_params(self) -> float: """ Returns the number of trainable parameters in the model in millions. """ return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 # ============================================================================= # 3. Loss Function # ============================================================================= class CombinedLoss(nn.Module): """ Computes a two-part loss: a standard cross-entropy loss for event type prediction and a survival analysis loss for event timing. """ def __init__(self, ignored_token_ids: list[int]): """ Initializes the CombinedLoss module. Args: ignored_token_ids (list[int]): A list of event type IDs to be excluded from all loss calculations. """ super().__init__() self.ignored_token_ids = ignored_token_ids def forward(self, logits: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Calculates the combined cross-entropy and survival loss. Args: logits (torch.Tensor): Raw model outputs of shape (B, L, N). x (torch.Tensor): Ground-truth event labels of shape (B, L). t (torch.Tensor): True time duration for each event, shape (B, L). Returns: A tuple containing the two scalar loss tensors: (loss_ce, loss_survival). """ mask = torch.ones_like(x, dtype=torch.bool) for token_id in self.ignored_token_ids: mask = mask & (x != token_id) if not mask.any(): return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device) logits_for_ce = logits.permute(0, 2, 1) per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') loss_ce = per_element_ce[mask].mean() intensity = torch.sum(torch.exp(logits), dim=2) per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t) loss_survival = per_element_survival[mask].mean() return loss_ce, loss_survival