feat: Add covariate-aware model and piecewise encoder
Introduce PiecewiseLinearEncoder for encoding continuous variables. Add CovariateAwareGPT2, which extends the TimeAwareGPT2 design with static and time-varying covariate processing: covariates are embedded with a combination of piecewise-linear and sinusoidal encodings, concatenated with the transformer output, and passed through a final MLP head. Reorganize models.py into labeled sections for better logical structure.
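For reviewers, a minimal usage sketch of the two new modules. Shapes and argument order follow the signatures added in the diff below; the import line, hyperparameter values, and tensor contents are illustrative assumptions, not part of this commit.

import torch
from models import CovariateAwareGPT2, PiecewiseLinearEncoder  # assumes models.py is importable

# Encode 4 samples with 5 pre-scaled continuous features: output is (4, 5, 64),
# one embedding per feature.
enc = PiecewiseLinearEncoder(num_bins=8, embedding_dim=64)
print(enc(torch.randn(4, 5)).shape)

# Illustrative hyperparameters only (n_embd is assumed even for the sinusoidal age encoder).
model = CovariateAwareGPT2(vocab_size=100, n_embd=64, n_layer=2, n_head=4,
                           pdrop=0.1, token_pdrop=0.1, num_bins=8)

B, L, N = 4, 16, 5
x = torch.randint(1, 100, (B, L))                         # event tokens (0 is treated as padding)
t = torch.sort(torch.rand(B, L) * 365.25 * 80, dim=1).values  # event ages in days
cov = torch.randn(B, N)                                   # static continuous covariates, pre-scaled
cov_t = torch.rand(B) * 365.25 * 80                       # age at covariate measurement, in days

logits = model(x, t, cov, cov_t)                          # expected shape: (B, L, vocab_size)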
models.py
@@ -3,6 +3,10 @@ import torch.nn as nn
 from torch.nn import functional as F
 from typing import Tuple
 
+# =============================================================================
+# 1. Component Modules (Building Blocks)
+# =============================================================================
+
 class Block(nn.Module):
     """ an unassuming Transformer block """
 
@@ -58,14 +62,8 @@ class AgeSinusoidalEncoding(nn.Module):
         self.embedding_dim = embedding_dim
 
         # Pre-calculate the divisor term for the sinusoidal formula.
-        # The formula for the divisor is 10000^(2i/D), where D is the
-        # embedding_dim and i is the index for each pair of dimensions.
-        # i ranges from 0 to D/2 - 1.
         i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
         divisor = torch.pow(10000, i / self.embedding_dim)
-
-        # Register the divisor as a non-trainable buffer. This ensures it is
-        # moved to the correct device (e.g., GPU) along with the model.
         self.register_buffer('divisor', divisor)
 
     def forward(self, t: torch.Tensor) -> torch.Tensor:
@@ -80,28 +78,100 @@ class AgeSinusoidalEncoding(nn.Module):
             torch.Tensor: The encoded age tensor of shape
             (batch_size, sequence_length, embedding_dim).
         """
-        # 1. Unit Conversion: Convert age from days to years.
-        # We use 365.25 to account for leap years.
         t_years = t / 365.25
 
-        # 2. Argument Calculation: Calculate the arguments for the sin/cos functions.
-        # The shapes are broadcast to (B, L, D/2).
-        # Input t_years: (B, L) -> unsqueezed to (B, L, 1)
-        # Divisor: (D/2) -> viewed as (1, 1, D/2)
         args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1)
 
-        # 3. Sinusoidal Application: Create the final output tensor.
-        # Initialize an empty tensor to store the embeddings.
         output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)
 
-        # Assign cosine of the arguments to the even indices.
         output[:, :, 0::2] = torch.cos(args)
 
-        # Assign sine of the arguments to the odd indices.
         output[:, :, 1::2] = torch.sin(args)
 
         return output
 
+class PiecewiseLinearEncoder(nn.Module):
+    """
+    Encodes continuous variables using piecewise linear encoding.
+
+    This module defines bins based on standard normal distribution quantiles,
+    encodes an input by finding its bin, and calculates its position as a
+    linear interpolation between boundaries. The result is projected to the
+    final embedding dimension by a shared linear layer.
+    """
+
+    def __init__(self, num_bins: int, embedding_dim: int):
+        """
+        Initializes the PiecewiseLinearEncoder module.
+
+        Args:
+            num_bins (int): The number of bins for the encoding.
+            embedding_dim (int): The dimensionality of the output embedding (D).
+        """
+        super().__init__()
+        if num_bins <= 0:
+            raise ValueError("num_bins must be a positive integer.")
+        self.num_bins = num_bins
+        self.embedding_dim = embedding_dim
+
+        if num_bins > 1:
+            quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1)
+            normal_dist = torch.distributions.normal.Normal(0, 1)
+            boundaries = normal_dist.icdf(quantiles)
+        else:
+            boundaries = torch.tensor([])
+
+        boundaries = torch.cat([
+            torch.tensor([float('-inf')]),
+            boundaries,
+            torch.tensor([float('inf')])
+        ])
+        self.register_buffer('boundaries', boundaries)
+
+        self.linear = nn.Linear(num_bins, embedding_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the piecewise linear encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (*, N), where * is any
+                number of batch dimensions and N is the number of continuous
+                features. Assumed to be pre-scaled.
+
+        Returns:
+            torch.Tensor: Encoded tensor of shape (*, N, D).
+        """
+        original_shape = x.shape
+        x = x.reshape(-1, original_shape[-1])
+
+        bin_indices = torch.searchsorted(self.boundaries, x, right=True) - 1
+        bin_indices = bin_indices.clamp(0, self.num_bins - 1)
+
+        lower_bounds = self.boundaries[bin_indices]
+        upper_bounds = self.boundaries[bin_indices + 1]
+        delta = upper_bounds - lower_bounds + 1e-8
+
+        weight_upper = (x - lower_bounds) / delta
+        weight_lower = 1.0 - weight_upper
+
+        is_first_bin = (bin_indices == 0)
+        is_last_bin = (bin_indices == self.num_bins - 1)
+
+        weight_lower[is_first_bin] = 1.0
+        weight_upper[is_first_bin] = 0.0
+        weight_lower[is_last_bin] = 0.0
+        weight_upper[is_last_bin] = 1.0
+
+        encoded = torch.zeros(*x.shape, self.num_bins, device=x.device, dtype=x.dtype)
+        encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1))
+
+        upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1)
+        encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1))
+
+        encoded = encoded.view(*original_shape, self.num_bins)
+        output = self.linear(encoded)
+        return output
+
+# =============================================================================
+# 2. Main Model Architectures
+# =============================================================================
+
 class TimeAwareGPT2(nn.Module):
     """
     A time-aware GPT-2 model with custom temporal features.
@@ -111,18 +181,12 @@ class TimeAwareGPT2(nn.Module):
         super().__init__()
         self.token_pdrop = token_pdrop
-
-        # Token and positional embeddings
         self.wte = nn.Embedding(vocab_size, n_embd)
         self.age_encoder = AgeSinusoidalEncoding(n_embd)
         self.drop = nn.Dropout(pdrop)
-
-        # Transformer blocks
         self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
-
-        # Final layer norm and linear head
         self.ln_f = nn.LayerNorm(n_embd)
         self.head = nn.Linear(n_embd, vocab_size, bias=False)
 
         self.n_embd = n_embd
 
     def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
@@ -138,53 +202,30 @@ class TimeAwareGPT2(nn.Module):
         """
         B, L = event_seq.size()
 
-        # 1. Get token embeddings
         token_embeddings = self.wte(event_seq)
-
-        # 2. Apply token dropout (only during training)
         if self.training and self.token_pdrop > 0:
-            # Create a mask to randomly zero out entire token embedding vectors
             drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
             token_embeddings[drop_mask] = 0.0
 
-        # 3. Get positional embeddings from time sequence
         pos_embeddings = self.age_encoder(time_seq.float())
-
-        # 4. Combine embeddings and apply dropout
         x = self.drop(token_embeddings + pos_embeddings)
 
-        # 5. Generate attention mask
-        # The attention mask combines two conditions:
-        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] < time_seq[i].
-        # b) Padding mask: Do not attend to positions where the event token is 0.
-
-        # a) Time-based causal mask
-        t_i = time_seq.unsqueeze(-1) # (B, L, 1)
-        t_j = time_seq.unsqueeze(1) # (B, 1, L)
+        t_i = time_seq.unsqueeze(-1)
+        t_j = time_seq.unsqueeze(1)
         time_mask = (t_j < t_i)
-
-        # b) Padding mask (prevents attending to key positions that are padding)
-        padding_mask = (event_seq != 0).unsqueeze(1) # Shape: (B, 1, L)
-
-        # Combine the masks. A position (j) can be attended to by a query (i) only if
-        # it's in the past (time_mask) AND it's not a padding token (padding_mask).
+        padding_mask = (event_seq != 0).unsqueeze(1)
         combined_mask = time_mask & padding_mask
-
-        # Forcibly allow a non-padding token to attend to itself if it cannot attend to any other token.
-        # This prevents NaN issues in the attention mechanism for the first token in a sequence.
         is_row_all_zero = ~combined_mask.any(dim=-1)
         is_not_padding = (event_seq != 0)
         force_self_attention = is_row_all_zero & is_not_padding
         combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
 
-        # 6. Pass through transformer blocks
         for block in self.blocks:
             x = block(x, custom_mask=combined_mask)
 
-        # 7. Final layer norm and projection to vocab size
         x = self.ln_f(x)
         logits = self.head(x)
 
         return logits
 
     def get_num_params(self) -> float:
@@ -193,6 +234,99 @@ class TimeAwareGPT2(nn.Module):
         """
         return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
 
+class CovariateAwareGPT2(nn.Module):
+    """
+    Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
+    """
+
+    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
+                 pdrop: float, token_pdrop: float, num_bins: int):
+        """
+        Initializes the CovariateAwareGPT2 model.
+
+        Args:
+            vocab_size (int): Size of the event vocabulary.
+            n_embd (int): Embedding dimensionality.
+            n_layer (int): Number of transformer layers.
+            n_head (int): Number of attention heads.
+            pdrop (float): Dropout probability for layers.
+            token_pdrop (float): Dropout probability for input token embeddings.
+            num_bins (int): Number of bins for the PiecewiseLinearEncoder.
+        """
+        super().__init__()
+        self.token_pdrop = token_pdrop
+
+        self.wte = nn.Embedding(vocab_size, n_embd)
+        self.age_encoder = AgeSinusoidalEncoding(n_embd)
+        self.drop = nn.Dropout(pdrop)
+        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
+        self.n_embd = n_embd
+        self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)
+
+        self.ln_f = nn.LayerNorm(2 * n_embd)
+        self.head = nn.Sequential(
+            nn.Linear(2 * n_embd, n_embd),
+            nn.GELU(),
+            nn.Linear(n_embd, vocab_size)
+        )
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the CovariateAwareGPT2 model.
+
+        Args:
+            x (torch.Tensor): Event sequence tensor of shape (B, L).
+            t (torch.Tensor): Time sequence tensor of shape (B, L).
+            cov (torch.Tensor): Covariate tensor of shape (B, N).
+            cov_t (torch.Tensor): Covariate time tensor of shape (B).
+
+        Returns:
+            torch.Tensor: Logits of shape (B, L, vocab_size).
+        """
+        B, L = x.size()
+
+        cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
+        cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1))
+        cov_embed = cov_encoded + cov_t_encoded
+
+        token_embeddings = self.wte(x)
+        if self.training and self.token_pdrop > 0:
+            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
+            token_embeddings[drop_mask] = 0.0
+
+        pos_embeddings = self.age_encoder(t.float())
+        seq_embed = self.drop(token_embeddings + pos_embeddings)
+
+        t_i = t.unsqueeze(-1)
+        t_j = t.unsqueeze(1)
+        time_mask = (t_j < t_i)
+        padding_mask = (x != 0).unsqueeze(1)
+        combined_mask = time_mask & padding_mask
+        is_row_all_zero = ~combined_mask.any(dim=-1)
+        is_not_padding = (x != 0)
+        force_self_attention = is_row_all_zero & is_not_padding
+        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
+
+        block_output = seq_embed
+        for block in self.blocks:
+            block_output = block(block_output, custom_mask=combined_mask)
+
+        integrated_embed = torch.cat([block_output, cov_embed], dim=-1)
+
+        final_output = self.ln_f(integrated_embed)
+        logits = self.head(final_output)
+        return logits
+
+    def get_num_params(self) -> float:
+        """
+        Returns the number of trainable parameters in the model in millions.
+        """
+        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
+
+# =============================================================================
+# 3. Loss Function
+# =============================================================================
+
 class CombinedLoss(nn.Module):
     """
     Computes a two-part loss: a standard cross-entropy loss for event type
@@ -222,35 +356,19 @@ class CombinedLoss(nn.Module):
         Returns:
             A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
         """
-        # 1. Create a mask to filter out ignored token IDs from loss calculation.
-        # An element is True if the corresponding label in x is NOT in the ignored list.
         mask = torch.ones_like(x, dtype=torch.bool)
         for token_id in self.ignored_token_ids:
             mask = mask & (x != token_id)
 
-        # If the mask is all False (all tokens are ignored), return zero for both losses.
         if not mask.any():
             return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)
 
-        # 2. Part 1: Cross-Entropy Loss (loss_ce)
-        # Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy.
         logits_for_ce = logits.permute(0, 2, 1)
-
-        # Calculate per-element loss without reduction.
         per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
-
-        # Apply the mask and compute the mean of valid elements.
         loss_ce = per_element_ce[mask].mean()
 
-        # 3. Part 2: Survival Loss (loss_survival)
-        # Calculate event intensity (lambda) as the sum of exponentiated logits.
         intensity = torch.sum(torch.exp(logits), dim=2)
-
-        # Calculate per-element survival loss (negative log-likelihood of exponential dist).
-        # We add a small epsilon for numerical stability with the log.
         per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)
-
-        # Apply the mask and compute the mean of valid elements.
         loss_survival = per_element_survival[mask].mean()
 
         return loss_ce, loss_survival