import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
import math


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(pdrop)
        self.resid_dropout = nn.Dropout(pdrop)
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        B, L, D = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)
        q = q.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)
        v = v.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)

        # causal self-attention; self-attend: (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # Apply the time-based causal mask
        att = att.masked_fill(custom_mask.unsqueeze(1) == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs)
        y = y.transpose(1, 2).contiguous().view(B, L, D)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
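

# --- Illustrative usage sketch (not part of the original module) -----------------
# A minimal shape check for CausalSelfAttention, assuming a boolean mask of shape
# (B, L, L) where True marks the key positions each query may attend to. The
# helper name below (_demo_causal_self_attention) and the toy sizes are
# illustrative assumptions, not part of the original file.
def _demo_causal_self_attention():
    torch.manual_seed(0)
    B, L, D, H = 2, 5, 16, 4
    attn = CausalSelfAttention(n_embd=D, n_head=H, pdrop=0.0)
    x = torch.randn(B, L, D)
    # A standard causal mask: each position may attend to itself and earlier ones.
    mask = torch.tril(torch.ones(L, L, dtype=torch.bool)).unsqueeze(0).expand(B, -1, -1)
    y = attn(x, custom_mask=mask)
    assert y.shape == (B, L, D)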


class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, pdrop)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc = nn.Linear(n_embd, 4 * n_embd),
            c_proj = nn.Linear(4 * n_embd, n_embd),
            act = nn.GELU(),
            dropout = nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))  # MLP forward

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x), custom_mask=custom_mask)
        x = x + self.mlpf(self.ln_2(x))
        return x


class AgeSinusoidalEncoding(nn.Module):
    """
    Encodes age using sinusoidal functions, similar to positional encodings
    in Transformers. This module creates a fixed-size embedding for an age
    value given in days.
    """

    def __init__(self, embedding_dim: int):
        """
        Initializes the AgeSinusoidalEncoding module.

        Args:
            embedding_dim (int): The dimensionality of the output embedding.
                Must be an even number.

        Raises:
            ValueError: If embedding_dim is not an even number.
        """
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(f"Embedding dimension must be an even number, but got {embedding_dim}")

        self.embedding_dim = embedding_dim

        # Pre-calculate the divisor term for the sinusoidal formula.
        # The formula for the divisor is 10000^(2i/D), where D is the
        # embedding_dim and i is the index for each pair of dimensions.
        # i ranges from 0 to D/2 - 1.
        i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.embedding_dim)

        # Register the divisor as a non-trainable buffer. This ensures it is
        # moved to the correct device (e.g., GPU) along with the model.
        self.register_buffer('divisor', divisor)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the AgeSinusoidalEncoding.

        Args:
            t (torch.Tensor): A tensor of shape (batch_size, sequence_length)
                with dtype=torch.float32, representing age in days.

        Returns:
            torch.Tensor: The encoded age tensor of shape
                (batch_size, sequence_length, embedding_dim).
        """
        # 1. Unit Conversion: Convert age from days to years.
        # We use 365.25 to account for leap years.
        t_years = t / 365.25

        # 2. Argument Calculation: Calculate the arguments for the sin/cos functions.
        # The shapes are broadcast to (B, L, D/2).
        # Input t_years: (B, L) -> unsqueezed to (B, L, 1)
        # Divisor: (D/2) -> viewed as (1, 1, D/2)
        args = t_years.unsqueeze(-1) / self.divisor.view(1, 1, -1)

        # 3. Sinusoidal Application: Create the final output tensor.
        # Initialize an empty tensor to store the embeddings.
        output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)

        # Assign cosine of the arguments to the even indices.
        output[:, :, 0::2] = torch.cos(args)

        # Assign sine of the arguments to the odd indices.
        output[:, :, 1::2] = torch.sin(args)

        return output
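

# --- Illustrative usage sketch (not part of the original module) -----------------
# A minimal sketch of AgeSinusoidalEncoding: ages given in days map to a
# (B, L, D) embedding whose even channels hold cosines and odd channels sines.
# The helper name below (_demo_age_encoding) is an illustrative assumption.
def _demo_age_encoding():
    enc = AgeSinusoidalEncoding(embedding_dim=8)
    ages_days = torch.tensor([[0.0, 365.25, 3652.5]])  # 0, 1 and 10 years; shape (1, 3)
    emb = enc(ages_days)
    assert emb.shape == (1, 3, 8)
    # At age 0 the sin/cos arguments are all zero, so the cosine channels are 1
    # and the sine channels are 0.
    assert torch.allclose(emb[0, 0, 0::2], torch.ones(4))
    assert torch.allclose(emb[0, 0, 1::2], torch.zeros(4))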


class TimeAwareGPT2(nn.Module):
    """
    A time-aware GPT-2 model with custom temporal features.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float):
        super().__init__()
        self.token_pdrop = token_pdrop

        # Token and positional embeddings
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)

        # Transformer blocks
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])

        # Final layer norm and linear head
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        self.n_embd = n_embd

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the TimeAwareGPT2 model.

        Args:
            event_seq (torch.Tensor): Token indices of shape (B, L).
            time_seq (torch.Tensor): Timestamps for each event of shape (B, L).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        B, L = event_seq.size()

        # 1. Get token embeddings
        token_embeddings = self.wte(event_seq)

        # 2. Apply token dropout (only during training)
        if self.training and self.token_pdrop > 0:
            # Create a mask to randomly zero out entire token embedding vectors
            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
            token_embeddings[drop_mask] = 0.0

        # 3. Get positional embeddings from time sequence
        pos_embeddings = self.age_encoder(time_seq.float())

        # 4. Combine embeddings and apply dropout
        x = self.drop(token_embeddings + pos_embeddings)

        # 5. Generate attention mask
        # The attention mask combines two conditions:
        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] < time_seq[i].
        # b) Padding mask: Do not attend to positions where the event token is 0.

        # a) Time-based causal mask
        t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
        t_j = time_seq.unsqueeze(1)   # (B, 1, L)
        time_mask = (t_j < t_i)

        # b) Padding mask (prevents attending to key positions that are padding)
        padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)

        # Combine the masks. A position (j) can be attended to by a query (i) only if
        # it's in the past (time_mask) AND it's not a padding token (padding_mask).
        combined_mask = time_mask & padding_mask

        # 6. Pass through transformer blocks
        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)

        # 7. Final layer norm and projection to vocab size
        x = self.ln_f(x)
        logits = self.head(x)

        return logits

    def get_num_params(self) -> float:
        """
        Returns the number of trainable parameters in the model in millions.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
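

# --- Illustrative usage sketch (not part of the original module) -----------------
# A minimal sketch of the time-aware masking and a forward pass: with strictly
# increasing times, each event may attend only to strictly earlier, non-padding
# events, and the model maps (B, L) event/time sequences to (B, L, vocab_size)
# logits. The helper name (_demo_time_aware_gpt2) and the toy configuration are
# illustrative assumptions.
def _demo_time_aware_gpt2():
    torch.manual_seed(0)
    event_seq = torch.tensor([[3, 5, 2, 0]])               # 0 marks padding
    time_seq = torch.tensor([[100.0, 400.0, 900.0, 0.0]])  # ages in days

    # Rebuild the mask exactly as in TimeAwareGPT2.forward to show its effect.
    time_mask = time_seq.unsqueeze(1) < time_seq.unsqueeze(-1)  # (B, L, L)
    padding_mask = (event_seq != 0).unsqueeze(1)                # (B, 1, L)
    combined_mask = time_mask & padding_mask
    # The third event (t=900) may attend to the first two, but not to itself
    # or to the padding position.
    assert combined_mask[0, 2].tolist() == [True, True, False, False]
    # The first event has no strictly earlier key, so its attention row is fully
    # masked; handling of that case is assumed to live in the surrounding
    # training code.
    assert combined_mask[0, 0].tolist() == [False, False, False, False]

    model = TimeAwareGPT2(vocab_size=10, n_embd=16, n_layer=2, n_head=4,
                          pdrop=0.0, token_pdrop=0.0)
    model.eval()
    logits = model(event_seq, time_seq)
    assert logits.shape == (1, 4, 10)
    print(f"{model.get_num_params():.3f}M trainable parameters")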


class CombinedLoss(nn.Module):
    """
    Computes a two-part loss: a standard cross-entropy loss for event type
    prediction and a survival analysis loss for event timing.
    """

    def __init__(self, ignored_token_ids: list[int]):
        """
        Initializes the CombinedLoss module.

        Args:
            ignored_token_ids (list[int]): A list of event type IDs to be
                excluded from all loss calculations.
        """
        super().__init__()
        self.ignored_token_ids = ignored_token_ids

    def forward(self, logits: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates the combined cross-entropy and survival loss.

        Args:
            logits (torch.Tensor): Raw model outputs of shape (B, L, N).
            x (torch.Tensor): Ground-truth event labels of shape (B, L).
            t (torch.Tensor): True time duration for each event, shape (B, L).

        Returns:
            A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
        """
        # 1. Create a mask to filter out ignored token IDs from loss calculation.
        # An element is True if the corresponding label in x is NOT in the ignored list.
        mask = torch.ones_like(x, dtype=torch.bool)
        for token_id in self.ignored_token_ids:
            mask = mask & (x != token_id)

        # If the mask is all False (all tokens are ignored), return zero for both losses.
        if not mask.any():
            return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)

        # 2. Part 1: Cross-Entropy Loss (loss_ce)
        # Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy.
        logits_for_ce = logits.permute(0, 2, 1)

        # Calculate per-element loss without reduction.
        per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')

        # Apply the mask and compute the mean of valid elements.
        loss_ce = per_element_ce[mask].mean()

        # 3. Part 2: Survival Loss (loss_survival)
        # Calculate event intensity (lambda) as the sum of exponentiated logits.
        intensity = torch.sum(torch.exp(logits), dim=2)

        # Calculate per-element survival loss (negative log-likelihood of exponential dist).
        # We add a small epsilon for numerical stability with the log.
        per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)

        # Apply the mask and compute the mean of valid elements.
        loss_survival = per_element_survival[mask].mean()

        return loss_ce, loss_survival
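

# --- Illustrative usage sketch (not part of the original module) -----------------
# CombinedLoss treats the summed exponentiated logits as the rate lambda of an
# exponential waiting-time distribution; the negative log-likelihood of a wait
# of length t is then -(log(lambda) - lambda * t), which is what
# per_element_survival computes above. A minimal sketch, assuming the caller has
# already aligned logits with next-event labels and waiting times; the helper
# name (_demo_combined_loss) and the ignored ID 0 (padding) are illustrative
# assumptions.
def _demo_combined_loss():
    torch.manual_seed(0)
    B, L, N = 2, 4, 10
    logits = torch.randn(B, L, N)
    labels = torch.randint(1, N, (B, L))
    labels[:, -1] = 0                 # pretend the last position is padding
    waits = torch.rand(B, L)          # waiting times until the next event
    criterion = CombinedLoss(ignored_token_ids=[0])
    loss_ce, loss_survival = criterion(logits, labels, waits)
    assert loss_ce.ndim == 0 and loss_survival.ndim == 0
    total = loss_ce + loss_survival   # e.g., weight and sum before calling backward()
    return total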