# DeepHealth/models.py
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
# =============================================================================
# 1. Component Modules (Building Blocks)
# =============================================================================
class Block(nn.Module):
    """An unassuming Transformer block with pre-LayerNorm attention and MLP."""

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        self.n_head = n_head
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=pdrop, batch_first=True)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))  # MLP forward
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        normed_x = self.ln_1(x)
        # custom_mask uses True = "query may attend to key"; nn.MultiheadAttention
        # expects the opposite convention (True = "masked"), hence the inversion.
        # The mask is repeated per head to the expected (B * n_head, L, L) shape.
        attn_mask = ~custom_mask
        attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0)
        attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
        x = x + self.resid_dropout(attn_output)
        x = x + self.mlpf(self.ln_2(x))
        return x
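

# Illustrative usage sketch for Block (not referenced elsewhere in this file); the
# shapes, timestamps and hyperparameters below are invented for demonstration.
# The boolean mask follows the convention used throughout this module:
# mask[b, i, j] is True when query position i may attend to key position j.
def _example_block_usage():
    B, L, n_embd, n_head = 2, 5, 32, 4
    block = Block(n_embd=n_embd, n_head=n_head, pdrop=0.1)
    x = torch.randn(B, L, n_embd)
    t = torch.arange(L, dtype=torch.float32).expand(B, L)
    # Allow attention only to strictly earlier timestamps, and add the diagonal
    # so that no row of the mask is completely empty.
    mask = t.unsqueeze(1) < t.unsqueeze(-1)                    # (B, L, L): t_j < t_i
    mask = mask | torch.eye(L, dtype=torch.bool).unsqueeze(0)
    return block(x, custom_mask=mask)                          # (B, L, n_embd)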


class AgeSinusoidalEncoding(nn.Module):
    """
    Encodes age using sinusoidal functions, similar to positional encodings
    in Transformers. This module creates a fixed-size embedding for an age
    value given in days.
    """

    def __init__(self, embedding_dim: int):
        """
        Initializes the AgeSinusoidalEncoding module.

        Args:
            embedding_dim (int): The dimensionality of the output embedding.
                Must be an even number.

        Raises:
            ValueError: If embedding_dim is not an even number.
        """
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(f"Embedding dimension must be an even number, but got {embedding_dim}")
        self.embedding_dim = embedding_dim
        # Pre-calculate the divisor term for the sinusoidal formula.
        i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.embedding_dim)
        self.register_buffer('divisor', divisor)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the AgeSinusoidalEncoding.

        Args:
            t (torch.Tensor): A tensor of shape (batch_size, sequence_length)
                with dtype=torch.float32, representing age in days.

        Returns:
            torch.Tensor: The encoded age tensor of shape
                (batch_size, sequence_length, embedding_dim).
        """
        t_years = t / 365.25
        args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1)
        output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)
        output[:, :, 0::2] = torch.cos(args)
        output[:, :, 1::2] = torch.sin(args)
        return output
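

# Minimal usage sketch for AgeSinusoidalEncoding (illustrative; the ages below
# are invented): ages in days map to a (batch, sequence, embedding_dim) tensor.
def _example_age_encoding():
    enc = AgeSinusoidalEncoding(embedding_dim=16)
    ages_days = torch.tensor([[0.0, 365.25, 10 * 365.25],
                              [20 * 365.25, 40 * 365.25, 60 * 365.25]])
    return enc(ages_days)   # shape (2, 3, 16)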


class PiecewiseLinearEncoder(nn.Module):
    """
    Encodes continuous variables using piecewise linear encoding.

    This module defines bins based on standard normal distribution quantiles,
    encodes an input by finding its bin, and calculates its position as a
    linear interpolation between boundaries. The result is projected to the
    final embedding dimension by a shared linear layer.
    """

    def __init__(self, num_bins: int, embedding_dim: int):
        """
        Initializes the PiecewiseLinearEncoder module.

        Args:
            num_bins (int): The number of bins for the encoding.
            embedding_dim (int): The dimensionality of the output embedding (D).
        """
        super().__init__()
        if num_bins <= 0:
            raise ValueError("num_bins must be a positive integer.")
        self.num_bins = num_bins
        self.embedding_dim = embedding_dim
        if num_bins > 1:
            quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1)
            normal_dist = torch.distributions.normal.Normal(0, 1)
            boundaries = normal_dist.icdf(quantiles)
        else:
            boundaries = torch.tensor([])
        boundaries = torch.cat([
            torch.tensor([float('-inf')]),
            boundaries,
            torch.tensor([float('inf')])
        ])
        self.register_buffer('boundaries', boundaries)
        self.linear = nn.Linear(num_bins, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the piecewise linear encoding.

        Args:
            x (torch.Tensor): Input tensor of shape (*, N), where * is any
                number of batch dimensions and N is the number of continuous
                features. Assumed to be pre-scaled.

        Returns:
            torch.Tensor: Encoded tensor of shape (*, N, D).
        """
        original_shape = x.shape
        x = x.reshape(-1, original_shape[-1])
        bin_indices = torch.searchsorted(self.boundaries, x, right=True) - 1
        bin_indices = bin_indices.clamp(0, self.num_bins - 1)
        lower_bounds = self.boundaries[bin_indices]
        upper_bounds = self.boundaries[bin_indices + 1]
        delta = upper_bounds - lower_bounds + 1e-8
        weight_upper = (x - lower_bounds) / delta
        weight_lower = 1.0 - weight_upper
        is_first_bin = (bin_indices == 0)
        is_last_bin = (bin_indices == self.num_bins - 1)
        weight_lower[is_first_bin] = 1.0
        weight_upper[is_first_bin] = 0.0
        weight_lower[is_last_bin] = 0.0
        weight_upper[is_last_bin] = 1.0
        encoded = torch.zeros(*x.shape, self.num_bins, device=x.device, dtype=x.dtype)
        encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1))
        upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1)
        encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1))
        encoded = encoded.view(*original_shape, self.num_bins)
        output = self.linear(encoded)
        return output
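

# Minimal usage sketch for PiecewiseLinearEncoder (illustrative). Inputs are
# assumed to be roughly standard-normal scaled, since the bin boundaries are
# standard-normal quantiles.
def _example_piecewise_linear_encoding():
    enc = PiecewiseLinearEncoder(num_bins=8, embedding_dim=16)
    cov = torch.randn(4, 3)   # (batch, num_continuous_features), pre-scaled
    return enc(cov)           # (4, 3, 16): one embedding per feature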
# =============================================================================
# 2. Main Model Architectures
# =============================================================================
class TimeAwareGPT2(nn.Module):
    """
    A time-aware GPT-2 model with custom temporal features.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
                 pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None):
        super().__init__()
        self.token_pdrop = token_pdrop
        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.n_embd = n_embd

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the TimeAwareGPT2 model.

        Args:
            event_seq (torch.Tensor): Token indices of shape (B, L).
            time_seq (torch.Tensor): Timestamps for each event of shape (B, L).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        B, L = event_seq.size()
        token_embeddings = self.wte(event_seq)
        # Randomly zero out whole token embeddings during training.
        if self.training and self.token_pdrop > 0:
            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
            token_embeddings[drop_mask] = 0.0
        pos_embeddings = self.age_encoder(time_seq.float())
        x = self.drop(token_embeddings + pos_embeddings)
        # Attention mask: query i may attend to key j only if event j occurred
        # strictly earlier (t_j < t_i) and j is not a padding position (token 0).
        t_i = time_seq.unsqueeze(-1)
        t_j = time_seq.unsqueeze(1)
        time_mask = (t_j < t_i)
        padding_mask = (event_seq != 0).unsqueeze(1)
        combined_mask = time_mask & padding_mask
        # Non-padding rows that would otherwise attend to nothing (e.g. the first
        # event of a sequence) are allowed to attend to themselves, so that the
        # attention softmax is never taken over an empty set.
        is_row_all_zero = ~combined_mask.any(dim=-1)
        is_not_padding = (event_seq != 0)
        force_self_attention = is_row_all_zero & is_not_padding
        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    def get_num_params(self) -> float:
        """
        Returns the number of trainable parameters in the model in millions.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
    @torch.no_grad()
    def generate(self, x, t, max_new_tokens=100, max_age=85 * 365.25, no_repeat=True, termination_tokens=None, top_k=None):
        """
        Take a conditioning sequence of event indices x (LongTensor of shape (B, L)) and the
        corresponding ages t, and complete the sequence with up to max_new_tokens events,
        feeding the predictions back into the model each time. The model is switched to
        eval() mode for generation.
        """
        self.eval()
        if termination_tokens is None:
            termination_tokens = [1269]
        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
        mask_time = -10000
        for _ in range(max_new_tokens):
            logits = self(x, t)
            logits = logits[:, -1, :]
            if self.ignore_tokens:
                logits[:, self.ignore_tokens] = -torch.inf
            if no_repeat:
                # Forbid tokens that have already occurred in the sequence.
                fill = x.clone()
                fill[fill == 1] = 0
                logits = logits.scatter(1, fill, -torch.inf)
            # Treat logits as log-intensities of competing exponential clocks:
            # sample a waiting time for every token (-log(U) / rate, capped at
            # 80 years) and keep the token whose clock fires first.
            t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365 * 80)
            t_next_val, idx_next = t_next_dist.min(1)
            idx_next = idx_next.unsqueeze(1)
            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
            x = torch.cat((x, idx_next), dim=1)
            t = torch.cat((t, age_next), dim=1)
            # Stop early once every sequence has hit a termination token or max_age.
            if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
                break
        # Flag everything generated after the first termination token or beyond
        # max_age, compute the final logits, then overwrite the flagged positions
        # with padding (token 0, time mask_time).
        pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) | (t > max_age)
        final_logits = self(x, t)
        x[pad] = 0
        t[pad] = mask_time
        if no_repeat:
            fill = x.clone()
            fill[fill == 1] = 0
            final_logits = torch.stack([final_logits[:, j].scatter(1, fill[:, :j + 1], -torch.inf) for j in range(fill.shape[1])]).transpose(0, 1)
        return x, t, final_logits
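

# Minimal end-to-end sketch for TimeAwareGPT2 (illustrative). The vocabulary
# size, token ids, ages and termination token below are invented; token id 0 is
# assumed to be padding, as in forward().
def _example_time_aware_gpt2():
    model = TimeAwareGPT2(vocab_size=32, n_embd=64, n_layer=2, n_head=4,
                          pdrop=0.1, token_pdrop=0.1)
    x = torch.randint(low=2, high=32, size=(2, 6))                  # events (B, L)
    t = torch.sort(torch.rand(2, 6) * 60 * 365.25, dim=1).values    # ages in days
    logits = model(x, t)                                            # (2, 6, 32)
    x_gen, t_gen, final_logits = model.generate(x, t, max_new_tokens=10,
                                                termination_tokens=[31])
    return logits.shape, x_gen.shape, t_gen.shape, final_logits.shape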


class CovariateAwareGPT2(nn.Module):
    """
    Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
                 pdrop: float, token_pdrop: float, num_bins: int):
        """
        Initializes the CovariateAwareGPT2 model.

        Args:
            vocab_size (int): Size of the event vocabulary.
            n_embd (int): Embedding dimensionality.
            n_layer (int): Number of transformer layers.
            n_head (int): Number of attention heads.
            pdrop (float): Dropout probability for layers.
            token_pdrop (float): Dropout probability for input token embeddings.
            num_bins (int): Number of bins for the PiecewiseLinearEncoder.
        """
        super().__init__()
        self.token_pdrop = token_pdrop
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
        self.n_embd = n_embd
        self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)
        self.ln_f = nn.LayerNorm(2 * n_embd)
        self.head = nn.Sequential(
            nn.Linear(2 * n_embd, n_embd),
            nn.GELU(),
            nn.Linear(n_embd, vocab_size)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the CovariateAwareGPT2 model.

        Args:
            x (torch.Tensor): Event sequence tensor of shape (B, L).
            t (torch.Tensor): Time sequence tensor of shape (B, L).
            cov (torch.Tensor): Covariate tensor of shape (B, N).
            cov_t (torch.Tensor): Covariate time tensor of shape (B).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        B, L = x.size()
        cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
        cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1))
        cov_embed = cov_encoded + cov_t_encoded
        token_embeddings = self.wte(x)
        if self.training and self.token_pdrop > 0:
            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
            token_embeddings[drop_mask] = 0.0
        pos_embeddings = self.age_encoder(t.float())
        seq_embed = self.drop(token_embeddings + pos_embeddings)
        t_i = t.unsqueeze(-1)
        t_j = t.unsqueeze(1)
        time_mask = (t_j < t_i)
        padding_mask = (x != 0).unsqueeze(1)
        combined_mask = time_mask & padding_mask
        is_row_all_zero = ~combined_mask.any(dim=-1)
        is_not_padding = (x != 0)
        force_self_attention = is_row_all_zero & is_not_padding
        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
        block_output = seq_embed
        for block in self.blocks:
            block_output = block(block_output, custom_mask=combined_mask)
        integrated_embed = torch.cat([block_output, cov_embed], dim=-1)
        final_output = self.ln_f(integrated_embed)
        logits = self.head(final_output)
        return logits

    def get_num_params(self) -> float:
        """
        Returns the number of trainable parameters in the model in millions.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
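

# Minimal usage sketch for CovariateAwareGPT2 (illustrative). The covariates are
# assumed to be pre-scaled continuous values and cov_t the age in days at which
# they were recorded; all numbers below are invented.
def _example_covariate_aware_gpt2():
    model = CovariateAwareGPT2(vocab_size=32, n_embd=64, n_layer=2, n_head=4,
                               pdrop=0.1, token_pdrop=0.1, num_bins=8)
    x = torch.randint(low=2, high=32, size=(2, 6))                  # events (B, L)
    t = torch.sort(torch.rand(2, 6) * 60 * 365.25, dim=1).values    # event ages in days
    cov = torch.randn(2, 3)                                         # (B, N) covariates
    cov_t = torch.full((2,), 30 * 365.25)                           # measurement age per subject
    return model(x, t, cov, cov_t)                                  # (2, 6, 32)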
# =============================================================================
# 3. Loss Function
# =============================================================================
class CombinedLoss(nn.Module):
    """
    Computes a two-part loss: a standard cross-entropy loss for event type
    prediction and a survival analysis loss for event timing.
    """

    def __init__(self, ignored_token_ids: list[int]):
        """
        Initializes the CombinedLoss module.

        Args:
            ignored_token_ids (list[int]): A list of event type IDs to be
                excluded from all loss calculations.
        """
        super().__init__()
        self.ignored_token_ids = ignored_token_ids

    def forward(self, logits: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates the combined cross-entropy and survival loss.

        Args:
            logits (torch.Tensor): Raw model outputs of shape (B, L, N).
            x (torch.Tensor): Ground-truth event labels of shape (B, L).
            t (torch.Tensor): True time duration for each event, shape (B, L).

        Returns:
            A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
        """
        # Positions whose label is an ignored token contribute to neither loss.
        mask = torch.ones_like(x, dtype=torch.bool)
        for token_id in self.ignored_token_ids:
            mask = mask & (x != token_id)
        if not mask.any():
            return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)
        # Cross-entropy over event types; F.cross_entropy expects (B, N, L).
        logits_for_ce = logits.permute(0, 2, 1)
        per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
        loss_ce = per_element_ce[mask].mean()
        # Survival term: logits are treated as log-intensities, so the total
        # intensity is sum(exp(logits)); the expression below is the negative
        # log-likelihood of an exponential waiting time t with that rate.
        intensity = torch.sum(torch.exp(logits), dim=2)
        per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)
        loss_survival = per_element_survival[mask].mean()
        return loss_ce, loss_survival
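

# Minimal training-step sketch combining TimeAwareGPT2 and CombinedLoss
# (illustrative). The shifted-target setup and the time unit of the waiting
# times are assumptions for demonstration, not requirements of this file;
# ignoring token ids 0 and 1 is likewise just an example.
def _example_combined_loss():
    model = TimeAwareGPT2(vocab_size=32, n_embd=64, n_layer=2, n_head=4,
                          pdrop=0.1, token_pdrop=0.1)
    criterion = CombinedLoss(ignored_token_ids=[0, 1])
    x = torch.randint(low=2, high=32, size=(2, 6))
    t = torch.sort(torch.rand(2, 6) * 60 * 365.25, dim=1).values
    logits = model(x[:, :-1], t[:, :-1])       # predict from the prefix
    x_next = x[:, 1:]                          # next-event labels
    t_delta = (t[:, 1:] - t[:, :-1]) / 365.25  # waiting times (assumed here to be in years)
    loss_ce, loss_surv = criterion(logits, x_next, t_delta)
    return loss_ce, loss_surv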