diff --git a/models.py b/models.py index 3309fc0..3fa8ffe 100644 --- a/models.py +++ b/models.py @@ -3,6 +3,10 @@ import torch.nn as nn from torch.nn import functional as F from typing import Tuple +# ============================================================================= +# 1. Component Modules (Building Blocks) +# ============================================================================= + class Block(nn.Module): """ an unassuming Transformer block """ @@ -58,14 +62,8 @@ class AgeSinusoidalEncoding(nn.Module): self.embedding_dim = embedding_dim # Pre-calculate the divisor term for the sinusoidal formula. - # The formula for the divisor is 10000^(2i/D), where D is the - # embedding_dim and i is the index for each pair of dimensions. - # i ranges from 0 to D/2 - 1. i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) divisor = torch.pow(10000, i / self.embedding_dim) - - # Register the divisor as a non-trainable buffer. This ensures it is - # moved to the correct device (e.g., GPU) along with the model. self.register_buffer('divisor', divisor) def forward(self, t: torch.Tensor) -> torch.Tensor: @@ -80,28 +78,100 @@ class AgeSinusoidalEncoding(nn.Module): torch.Tensor: The encoded age tensor of shape (batch_size, sequence_length, embedding_dim). """ - # 1. Unit Conversion: Convert age from days to years. - # We use 365.25 to account for leap years. t_years = t / 365.25 - - # 2. Argument Calculation: Calculate the arguments for the sin/cos functions. - # The shapes are broadcast to (B, L, D/2). - # Input t_years: (B, L) -> unsqueezed to (B, L, 1) - # Divisor: (D/2) -> viewed as (1, 1, D/2) args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1) - - # 3. Sinusoidal Application: Create the final output tensor. - # Initialize an empty tensor to store the embeddings. output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device) - - # Assign cosine of the arguments to the even indices. 
class PiecewiseLinearEncoder(nn.Module):
    """
    Encodes continuous variables using piecewise linear encoding.

    Bin boundaries are the quantiles of a standard normal distribution,
    padded with -inf and +inf so that every real input falls into exactly one
    of ``num_bins`` bins. An input is represented by a ``num_bins``-wide
    vector holding the interpolation weights of its bin and the next bin;
    the first and last bins are open-ended and are therefore encoded as a
    plain one-hot. The weight vector is projected to ``embedding_dim`` by a
    linear layer shared across all features.
    """

    def __init__(self, num_bins: int, embedding_dim: int):
        """
        Initializes the PiecewiseLinearEncoder module.

        Args:
            num_bins (int): The number of bins for the encoding.
            embedding_dim (int): The dimensionality of the output embedding (D).

        Raises:
            ValueError: If ``num_bins`` is not positive.
        """
        super().__init__()
        if num_bins <= 0:
            raise ValueError("num_bins must be a positive integer.")
        self.num_bins = num_bins
        self.embedding_dim = embedding_dim

        if num_bins > 1:
            # Interior boundaries sit at the k/num_bins quantiles of N(0, 1).
            quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1)
            normal_dist = torch.distributions.normal.Normal(0, 1)
            boundaries = normal_dist.icdf(quantiles)
        else:
            boundaries = torch.tensor([])

        # Pad with infinities: num_bins + 1 boundaries delimit num_bins bins.
        boundaries = torch.cat([
            torch.tensor([float('-inf')]),
            boundaries,
            torch.tensor([float('inf')])
        ])
        self.register_buffer('boundaries', boundaries)

        self.linear = nn.Linear(num_bins, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the piecewise linear encoding.

        Args:
            x (torch.Tensor): Input tensor of shape (*, N), where * is any
                number of batch dimensions and N is the number of continuous
                features. Assumed to be pre-scaled. Any real dtype is
                accepted; it is cast to the dtype of the boundary buffer.

        Returns:
            torch.Tensor: Encoded tensor of shape (*, N, D).
        """
        # torch.searchsorted requires both operands to share a dtype, and the
        # weights built below are fed to self.linear, so align the input with
        # the module's own floating-point dtype up front.
        x = x.to(self.boundaries.dtype)

        original_shape = x.shape
        x = x.reshape(-1, original_shape[-1])

        # boundaries[k] <= x < boundaries[k + 1]  ->  bin k
        bin_indices = torch.searchsorted(self.boundaries, x, right=True) - 1
        bin_indices = bin_indices.clamp(0, self.num_bins - 1)

        lower_bounds = self.boundaries[bin_indices]
        upper_bounds = self.boundaries[bin_indices + 1]
        delta = upper_bounds - lower_bounds + 1e-8

        is_first_bin = bin_indices == 0
        is_last_bin = bin_indices == self.num_bins - 1

        # The edge bins have an infinite bound, so the raw ratio is
        # meaningless there (inf / inf). Replace it out-of-place: the first
        # bin puts all weight on itself, the last bin likewise. The last-bin
        # rule is applied second so it wins when num_bins == 1.
        weight_upper = (x - lower_bounds) / delta
        weight_upper = weight_upper.masked_fill(is_first_bin, 0.0)
        weight_upper = weight_upper.masked_fill(is_last_bin, 1.0)
        weight_lower = 1.0 - weight_upper

        encoded = torch.zeros(*x.shape, self.num_bins, device=x.device, dtype=x.dtype)
        encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1))

        # For the last bin the "upper" slot is clamped back onto the bin
        # itself, so its weight of 1 is added to the 0 written above.
        upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1)
        encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1))

        encoded = encoded.view(*original_shape, self.num_bins)
        return self.linear(encoded)
@@ -111,18 +181,12 @@ class TimeAwareGPT2(nn.Module): super().__init__() self.token_pdrop = token_pdrop - # Token and positional embeddings self.wte = nn.Embedding(vocab_size, n_embd) self.age_encoder = AgeSinusoidalEncoding(n_embd) self.drop = nn.Dropout(pdrop) - - # Transformer blocks self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)]) - - # Final layer norm and linear head self.ln_f = nn.LayerNorm(n_embd) self.head = nn.Linear(n_embd, vocab_size, bias=False) - self.n_embd = n_embd def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor: @@ -138,53 +202,30 @@ class TimeAwareGPT2(nn.Module): """ B, L = event_seq.size() - # 1. Get token embeddings token_embeddings = self.wte(event_seq) - - # 2. Apply token dropout (only during training) if self.training and self.token_pdrop > 0: - # Create a mask to randomly zero out entire token embedding vectors drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop token_embeddings[drop_mask] = 0.0 - # 3. Get positional embeddings from time sequence pos_embeddings = self.age_encoder(time_seq.float()) - - # 4. Combine embeddings and apply dropout x = self.drop(token_embeddings + pos_embeddings) - # 5. Generate attention mask - # The attention mask combines two conditions: - # a) Time-based causality: A token i can attend to a token j only if time_seq[j] < time_seq[i]. - # b) Padding mask: Do not attend to positions where the event token is 0. - - # a) Time-based causal mask - t_i = time_seq.unsqueeze(-1) # (B, L, 1) - t_j = time_seq.unsqueeze(1) # (B, 1, L) + t_i = time_seq.unsqueeze(-1) + t_j = time_seq.unsqueeze(1) time_mask = (t_j < t_i) - - # b) Padding mask (prevents attending to key positions that are padding) - padding_mask = (event_seq != 0).unsqueeze(1) # Shape: (B, 1, L) - - # Combine the masks. 
class CovariateAwareGPT2(nn.Module):
    """
    Time-aware GPT-2 variant that also consumes a per-sample covariate vector.

    The event sequence is embedded and run through the transformer blocks
    exactly as a token + age-encoding stream. Separately, the covariates are
    encoded, summed over features and offset by an age encoding of the time
    elapsed since the covariate measurement. The two streams are concatenated
    along the feature axis before the final layer norm and MLP head.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
                 pdrop: float, token_pdrop: float, num_bins: int):
        """
        Builds the embedding layers, transformer stack, covariate encoder and head.

        Args:
            vocab_size (int): Size of the event vocabulary.
            n_embd (int): Embedding dimensionality.
            n_layer (int): Number of transformer layers.
            n_head (int): Number of attention heads.
            pdrop (float): Dropout probability for layers.
            token_pdrop (float): Dropout probability for input token embeddings.
            num_bins (int): Number of bins for the PiecewiseLinearEncoder.
        """
        super().__init__()
        self.n_embd = n_embd
        self.token_pdrop = token_pdrop

        # Sequence stream: token embedding + age encoding -> transformer stack.
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)
        self.blocks = nn.ModuleList(
            Block(n_embd, n_head, pdrop) for _ in range(n_layer)
        )

        # Covariate stream.
        self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)

        # The head sees both streams concatenated, hence the doubled width.
        joint_width = 2 * n_embd
        self.ln_f = nn.LayerNorm(joint_width)
        self.head = nn.Sequential(
            nn.Linear(joint_width, n_embd),
            nn.GELU(),
            nn.Linear(n_embd, vocab_size),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
        """
        Computes next-event logits from events, times and covariates.

        Args:
            x (torch.Tensor): Event sequence tensor of shape (B, L).
            t (torch.Tensor): Time sequence tensor of shape (B, L).
            cov (torch.Tensor): Covariate tensor of shape (B, N).
            cov_t (torch.Tensor): Covariate time tensor of shape (B).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        _, seq_len = x.size()
        not_pad = x != 0  # token id 0 is treated as padding

        # --- sequence stream -------------------------------------------------
        tok = self.wte(x)
        if self.training and self.token_pdrop > 0:
            # Zero whole embedding vectors at randomly chosen positions.
            dropped = torch.rand(tok.shape[:2], device=tok.device) < self.token_pdrop
            tok = tok.masked_fill(dropped.unsqueeze(-1), 0.0)
        hidden = self.drop(tok + self.age_encoder(t.float()))

        # --- attention mask --------------------------------------------------
        # Query i may attend to key j only if j is strictly earlier in time
        # and j is not padding.
        strictly_earlier = t.unsqueeze(1) < t.unsqueeze(-1)
        attn_mask = strictly_earlier & not_pad.unsqueeze(1)

        # A non-padding query with no admissible key is allowed to attend to
        # itself, so its mask row is never entirely False.
        needs_self = not_pad & ~attn_mask.any(dim=-1)
        identity = torch.eye(seq_len, dtype=torch.bool, device=x.device)
        attn_mask = attn_mask | (identity & needs_self.unsqueeze(-1))

        for layer in self.blocks:
            hidden = layer(hidden, custom_mask=attn_mask)

        # --- covariate stream ------------------------------------------------
        # Sum the per-feature embeddings, then broadcast over the sequence
        # while adding an encoding of the time since the covariate reading.
        cov_summary = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
        elapsed = self.age_encoder(t - cov_t.unsqueeze(1))
        cov_stream = cov_summary + elapsed

        # --- fuse and project ------------------------------------------------
        fused = torch.cat([hidden, cov_stream], dim=-1)
        return self.head(self.ln_f(fused))

    def get_num_params(self) -> float:
        """
        Returns the number of trainable parameters in the model in millions.
        """
        trainable = 0
        for param in self.parameters():
            if param.requires_grad:
                trainable += param.numel()
        return trainable / 1e6
mask = torch.ones_like(x, dtype=torch.bool) for token_id in self.ignored_token_ids: mask = mask & (x != token_id) - # If the mask is all False (all tokens are ignored), return zero for both losses. if not mask.any(): return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device) - # 2. Part 1: Cross-Entropy Loss (loss_ce) - # Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy. logits_for_ce = logits.permute(0, 2, 1) - - # Calculate per-element loss without reduction. per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') - - # Apply the mask and compute the mean of valid elements. loss_ce = per_element_ce[mask].mean() - # 3. Part 2: Survival Loss (loss_survival) - # Calculate event intensity (lambda) as the sum of exponentiated logits. intensity = torch.sum(torch.exp(logits), dim=2) - - # Calculate per-element survival loss (negative log-likelihood of exponential dist). - # We add a small epsilon for numerical stability with the log. per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t) - - # Apply the mask and compute the mean of valid elements. loss_survival = per_element_survival[mask].mean() return loss_ce, loss_survival