import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F


# =============================================================================
# 1. Component Modules (Building Blocks)
# =============================================================================

class CausalConv1d(nn.Module):
    """1-D convolution that is padded on the left only.

    The input is left-padded with ``kernel_size - 1`` zeros, so the output has
    the same length as the input and position ``i`` depends only on
    positions ``<= i``.
    """

    def __init__(self, channels, kernel_size, groups=1):
        super().__init__()
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            channels, channels, kernel_size, padding=0, groups=groups
        )

    def forward(self, x):
        # x: (B, C, L)
        x = F.pad(x, (self.pad, 0))  # pad only on the left to ensure causality
        x = x.contiguous()
        return self.conv(x)


class DepthwiseSeparableCausalConvBlock(nn.Module):
    """Depthwise causal conv -> pointwise conv -> GELU -> dropout, with a
    residual connection followed by LayerNorm."""

    def __init__(self, d_model, kernel_size=5, dropout=0.1):
        super().__init__()
        self.dw = CausalConv1d(d_model, kernel_size, groups=d_model)  # depthwise
        self.pw = nn.Conv1d(d_model, d_model, 1)  # pointwise
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, L, D)
        y = x.transpose(1, 2).contiguous()  # (B, D, L)
        y = self.dw(y)                      # (B, D, L)
        y = self.pw(y.contiguous())         # (B, D, L)
        y = y.transpose(1, 2).contiguous()  # (B, L, D)
        y = self.act(y)
        y = self.dropout(y)
        return self.ln(x + y)  # residual connection + layer norm (LN)


class TimeFeatureProjector(nn.Module):
    """
    Projects scalar time t and its increment Δt into d_model dimensions.
    Combines: linear-scale features + fixed-frequency sin/cos (Fourier time features).
    """

    def __init__(self, d_model, fourier_dim=32, dt_clip=1e6):
        super().__init__()
        self.dt_clip = dt_clip
        self.scalar_proj = nn.Linear(2, d_model, bias=False)  # [t_scaled, dt_scaled] -> D
        # Predefine a set of logarithmically spaced frequencies
        # (tune for your time units if needed)
        k = fourier_dim // 2
        freqs = torch.logspace(-4, 2, steps=k) * 2 * math.pi  # frequency coverage ~1e-4 to 1e2
        self.register_buffer("freqs", freqs, persistent=False)
        self.fourier_proj = nn.Linear(2 * k, d_model, bias=False)  # [sin, cos] -> D
        # Learnable gate, initialised at zero so tanh(gate) == 0 and the
        # Fourier features are introduced smoothly during training.
        self.gate = nn.Parameter(torch.zeros(1))
        self.ln = nn.LayerNorm(d_model)

    def forward(self, t):
        # t: (B, L) continuous timestamps/steps
        # Compute increments Δt. The first position is differenced against 0,
        # so its Δt equals t[:, 0] itself.
        dt = t - F.pad(t, (1, 0), value=0.)[:, :-1]
        dt = torch.clamp(dt, min=0.)  # ensure non-negative
        # normalize/stabilize with log compression
        t_scaled = torch.log1p(torch.clamp(torch.abs(t), max=self.dt_clip))
        dt_scaled = torch.log1p(torch.clamp(dt, max=self.dt_clip))
        scal = torch.stack([t_scaled, dt_scaled], dim=-1)  # (B, L, 2)
        scal_feat = self.scalar_proj(scal)                 # (B, L, D)
        # Fixed-frequency sin/cos to capture absolute/relative periodicity.
        # If t is in steps, use directly; if in seconds, ensure units are
        # consistent (e.g., divide by a time constant).
        wt = t[..., None] * self.freqs                              # (B, L, K)
        sincos = torch.cat([torch.sin(wt), torch.cos(wt)], dim=-1)  # (B, L, 2K)
        fourier_feat = self.fourier_proj(sincos)                    # (B, L, D)
        # gated fusion + layer norm
        h = scal_feat + torch.tanh(self.gate) * fourier_feat
        return self.ln(h)  # (B, L, D)


class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        self.n_head = n_head
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=pdrop, batch_first=True)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        self.resid_dropout = nn.Dropout(pdrop)

    def mlpf(self, x: torch.Tensor) -> torch.Tensor:
        """MLP forward: c_fc -> GELU -> c_proj -> dropout.

        This is a real method rather than a lambda stored on the instance.
        ``copy.deepcopy`` copies functions by reference, so a stored lambda
        would keep a deep-copied block bound to the ORIGINAL block's MLP
        weights, and it would also make the module unpicklable.
        """
        m = self.mlp
        return m.dropout(m.c_proj(m.act(m.c_fc(x))))

    @staticmethod
    def _additive_attn_mask(custom_mask: torch.Tensor, n_head: int, dtype: torch.dtype) -> torch.Tensor:
        """Convert a boolean "allowed" mask (B, L, L) to an additive float mask (B * n_head, L, L).

        Allowed positions get 0 and masked positions a large negative value.
        The value is -1e9 wherever the dtype can hold it. For narrower dtypes
        such as float16 it is limited to half of the dtype's most negative
        finite number, so the mask never contains inf/NaN and adding an
        attention score to it does not overflow.
        """
        blocked = (~custom_mask).repeat_interleave(n_head, dim=0)  # True where we want to mask
        fill_value = max(-1e9, torch.finfo(dtype).min / 2)
        mask = torch.zeros(blocked.shape, dtype=dtype, device=blocked.device)
        return mask.masked_fill(blocked, fill_value)

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        """Pre-LN attention + MLP with residual connections.

        Args:
            x: Input of shape (B, L, n_embd).
            custom_mask: Boolean tensor of shape (B, L, L); True means the
                query (row) may attend to the key (column).
        """
        normed_x = self.ln_1(x)
        # Build an additive attention mask to avoid backend issues with
        # boolean masks on some GPUs.
        attn_mask = self._additive_attn_mask(custom_mask, self.n_head, normed_x.dtype)
        attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
        x = x + self.resid_dropout(attn_output)
        x = x + self.mlpf(self.ln_2(x))
        return x


class AgeSinusoidalEncoding(nn.Module):
    """
    Encodes age using sinusoidal functions, similar to positional encodings
    in Transformers.

    This module creates a fixed-size embedding for an age value given in days.
    """

    def __init__(self, embedding_dim: int):
        """
        Initializes the AgeSinusoidalEncoding module.

        Args:
            embedding_dim (int): The dimensionality of the output embedding.
                Must be an even number.

        Raises:
            ValueError: If embedding_dim is not an even number.
        """
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(f"Embedding dimension must be an even number, but got {embedding_dim}")
        self.embedding_dim = embedding_dim

        # Pre-calculate the divisor term for the sinusoidal formula:
        # 10000 ** (i / embedding_dim) for i = 0, 2, 4, ...
        i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.embedding_dim)
        self.register_buffer('divisor', divisor)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the AgeSinusoidalEncoding.

        Args:
            t (torch.Tensor): A tensor of shape (batch_size, sequence_length)
                representing age in days.

        Returns:
            torch.Tensor: The encoded age tensor of shape
                (batch_size, sequence_length, embedding_dim). Even channels
                hold the cosine terms, odd channels the sine terms.
        """
        t_years = t / 365.25
        # NOTE: the age is DIVIDED by the divisor, as in the standard
        # sinusoidal formula pos / 10000^(2k/d). The previous code multiplied
        # by it, which inverted the frequency scale (up to 10000 rad/year).
        # Checkpoints trained with the old behaviour are not numerically
        # compatible with this encoding.
        args = t_years.unsqueeze(-1) / self.divisor
        # Interleave as [cos_0, sin_0, cos_1, sin_1, ...].
        return torch.stack([torch.cos(args), torch.sin(args)], dim=-1).flatten(-2)


class LearnableAgeEncoding(nn.Module):
    """Combines fixed sinusoidal age encodings with a learnable MLP projection."""

    def __init__(self, base_dim: int, hidden_dim: Optional[int] = None,
                 final_dim: Optional[int] = None, dropout: float = 0.0):
        """
        Args:
            base_dim: Dimension of the sinusoidal encoding (must be even).
            hidden_dim: Hidden width of the MLP; falls back to ``base_dim``
                when falsy.
            final_dim: Output dimension; falls back to ``base_dim`` when falsy.
            dropout: Dropout probability applied after the GELU; no dropout
                layer is created when it is not positive.

        Raises:
            ValueError: If the resolved hidden_dim or final_dim is not positive.
        """
        super().__init__()
        self.base_dim = base_dim
        self.final_dim = final_dim or base_dim
        hidden_dim = hidden_dim or base_dim
        if hidden_dim <= 0:
            raise ValueError("hidden_dim must be a positive integer.")
        if self.final_dim <= 0:
            raise ValueError("final_dim must be a positive integer.")

        self.sinusoidal = AgeSinusoidalEncoding(base_dim)
        mlp_layers = [
            nn.Linear(base_dim, hidden_dim),
            nn.GELU(),
        ]
        if dropout > 0.0:
            mlp_layers.append(nn.Dropout(dropout))
        mlp_layers.append(nn.Linear(hidden_dim, self.final_dim))
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Encode ages ``t`` and project them; output shape is ``t.shape + (final_dim,)``."""
        sin_embed = self.sinusoidal(t)
        flat_embed = sin_embed.reshape(-1, self.base_dim)
        projected = self.mlp(flat_embed)
        return projected.reshape(*sin_embed.shape[:-1], self.final_dim)


class PiecewiseLinearEncoder(nn.Module):
    """
    Encodes continuous variables using piecewise linear encoding.

    This module defines bins based on standard normal distribution quantiles,
    encodes an input by finding its bin, and calculates its position as a
    linear interpolation between boundaries. The result is projected to the
    final embedding dimension by a shared linear layer.
    """

    def __init__(self, num_bins: int, embedding_dim: int):
        """
        Initializes the PiecewiseLinearEncoder module.

        Args:
            num_bins (int): The number of bins for the encoding.
            embedding_dim (int): The dimensionality of the output embedding (D).

        Raises:
            ValueError: If num_bins is not positive.
        """
        super().__init__()
        if num_bins <= 0:
            raise ValueError("num_bins must be a positive integer.")
        self.num_bins = num_bins
        self.embedding_dim = embedding_dim

        if num_bins > 1:
            quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1)
            normal_dist = torch.distributions.normal.Normal(0, 1)
            boundaries = normal_dist.icdf(quantiles)
        else:
            boundaries = torch.tensor([])

        # Outer edges are infinite, giving num_bins + 1 boundaries in total.
        boundaries = torch.cat([
            torch.tensor([float('-inf')]),
            boundaries,
            torch.tensor([float('inf')])
        ])
        self.register_buffer('boundaries', boundaries)
        self.linear = nn.Linear(num_bins, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the piecewise linear encoding.

        Args:
            x (torch.Tensor): Input tensor of shape (*, N), where * is any
                number of batch dimensions and N is the number of continuous
                features. Assumed to be pre-scaled.

        Returns:
            torch.Tensor: Encoded tensor of shape (*, N, D).
        """
        original_shape = x.shape
        x = x.reshape(-1, original_shape[-1])

        bin_indices = torch.searchsorted(self.boundaries, x, right=True) - 1
        bin_indices = bin_indices.clamp(0, self.num_bins - 1)

        lower_bounds = self.boundaries[bin_indices]
        upper_bounds = self.boundaries[bin_indices + 1]
        delta = upper_bounds - lower_bounds + 1e-8
        weight_upper = (x - lower_bounds) / delta
        weight_lower = 1.0 - weight_upper

        # The first and last bins have an infinite edge, so interpolation is
        # meaningless there (inf / inf); they are overwritten with fixed
        # weights, which makes those bins one-hot.
        is_first_bin = (bin_indices == 0)
        is_last_bin = (bin_indices == self.num_bins - 1)
        weight_lower[is_first_bin] = 1.0
        weight_upper[is_first_bin] = 0.0
        weight_lower[is_last_bin] = 0.0
        weight_upper[is_last_bin] = 1.0

        encoded = torch.zeros(*x.shape, self.num_bins, device=x.device, dtype=x.dtype)
        encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1))
        # For the last bin the "upper" slot is clamped back onto the bin itself.
        upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1)
        encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1))

        encoded = encoded.view(*original_shape, self.num_bins)
        output = self.linear(encoded)
        return output


class TemporalConvEncoder(nn.Module):
    """
    Inputs:
        x: (B, L) - event/token ids
        t: (B, L) - timestamps (real-valued) or step indices
    Output:
        h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_layers: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.1,
        fourier_dim: int = 32,
        pad_id: int = 0
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim)
        self.fuse = nn.Linear(2 * d_model, d_model, bias=False)  # fuse token and time features
        self.ln_in = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        blocks = []
        for _ in range(n_layers):
            blocks.append(DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, t, attention_mask=None):
        """
        attention_mask: (B, L) 1=keep, 0=padding
        """
        tok = self.token_emb(x)            # (B, L, D)
        tim = self.time_proj(t)            # (B, L, D)
        h = torch.cat([tok, tim], dim=-1)  # (B, L, 2D)
        h = self.fuse(h)                   # (B, L, D)
        h = self.ln_in(h)
        h = self.dropout(h)
        # Optional: zero-out padding positions before convolutions to avoid leakage
        if attention_mask is not None:
            h = h * attention_mask.unsqueeze(-1).type_as(h)
        # Multi-layer causal temporal convolutions (no look-ahead) to form
        # relative position-aware context
        for blk in self.blocks:
            h = blk(h)  # (B, L, D)
            if attention_mask is not None:
                h = h * attention_mask.unsqueeze(-1).type_as(h)
        return h  # (B, L, D), directly usable as attention layer input


# =============================================================================
# 2. Main Model Architectures
# =============================================================================

# Defaults shared by the ``generate`` methods of the models below.
_DEFAULT_TERMINATION_TOKEN = 1269  # used when the caller passes no termination tokens
_MAX_WAITING_TIME = 365 * 80       # upper clamp for a sampled waiting time
_PADDED_TIME = -10000              # timestamp written into padded-out positions


def _build_time_aware_mask(event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
    """Build the (B, L, L) boolean attention mask shared by both models.

    Query i may attend to key j when ``time_seq[j] < time_seq[i]`` and token j
    is not padding (id 0). Non-padding queries that would otherwise attend to
    nothing are allowed to attend to themselves.
    """
    t_i = time_seq.unsqueeze(-1)
    t_j = time_seq.unsqueeze(1)
    is_not_padding = (event_seq != 0)
    combined_mask = (t_j < t_i) & is_not_padding.unsqueeze(1)

    is_row_all_zero = ~combined_mask.any(dim=-1)
    force_self_attention = is_row_all_zero & is_not_padding
    # ``diagonal`` returns a view, so this writes into ``combined_mask``.
    combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
    return combined_mask


def _mask_seen_tokens(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Return ``logits`` (B, V) with -inf at every id in ``tokens`` (B, K).

    Token id 1 is remapped to 0 first, so id 1 is never blocked while id 0
    (padding) always is.
    """
    fill = tokens.clone()
    fill[fill == 1] = 0
    return logits.scatter(1, fill, -torch.inf)


@torch.no_grad()
def _generate_events(model, x, t, max_new_tokens, max_age, no_repeat, termination_tokens):
    """Shared autoregressive sampling loop behind the ``generate`` methods.

    ``model`` must be callable as ``model(x, t) -> logits`` and expose an
    ``ignore_tokens`` list. Returns ``(x, t, final_logits)``.
    """
    model.eval()
    if termination_tokens is None:
        termination_tokens = [_DEFAULT_TERMINATION_TOKEN]
    termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)

    for _ in range(max_new_tokens):
        logits = model(x, t)[:, -1, :]
        if model.ignore_tokens:
            logits[:, model.ignore_tokens] = -torch.inf
        if no_repeat:
            logits = _mask_seen_tokens(logits, x)

        # Draw one exponential waiting time per token with rate exp(logit);
        # the token with the shortest waiting time is the next event.
        t_next_dist = torch.clamp(
            -torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(),
            min=0,
            max=_MAX_WAITING_TIME,
        )
        t_next_val, idx_next = t_next_dist.min(1)
        idx_next = idx_next.unsqueeze(1)
        age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)

        x = torch.cat((x, idx_next), dim=1)
        t = torch.cat((t, age_next), dim=1)

        # Stop once every sequence has terminated or exceeded max_age.
        finished = torch.logical_or(
            torch.isin(x, termination_tokens).any(-1),
            age_next.squeeze(-1) > max_age,
        )
        if finished.all():
            break

    # Pad everything strictly after the first termination token, plus every
    # position whose age exceeds max_age.
    pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)

    # Final logits are computed on the un-padded sequence.
    final_logits = model(x, t)
    x[pad] = 0
    t[pad] = _PADDED_TIME

    if no_repeat:
        fill = x.clone()
        fill[fill == 1] = 0
        # At position j, block every token seen at positions <= j.
        final_logits = torch.stack(
            [final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])]
        ).transpose(0, 1)

    return x, t, final_logits


class TimeAwareGPT2(nn.Module):
    """
    A time-aware GPT-2 model with custom temporal features.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
                 pdrop: float, token_pdrop: float, ignore_tokens: Optional[list[int]] = None):
        """
        Args:
            vocab_size: Number of event/token ids.
            n_embd: Embedding width.
            n_layer: Number of Transformer blocks.
            n_head: Number of attention heads.
            pdrop: Dropout probability for embeddings, attention and MLP.
            token_pdrop: Probability of zeroing a whole token embedding
                during training.
            ignore_tokens: Token ids whose logits are set to -inf in
                ``generate``.
        """
        super().__init__()
        self.token_pdrop = token_pdrop
        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.n_embd = n_embd

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the TimeAwareGPT2 model.

        Args:
            event_seq (torch.Tensor): Token indices of shape (B, L).
            time_seq (torch.Tensor): Timestamps for each event of shape (B, L).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        token_embeddings = self.wte(event_seq)
        if self.training and self.token_pdrop > 0:
            # Token-level dropout: zero the whole embedding of a random
            # subset of positions.
            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
            token_embeddings = token_embeddings.masked_fill(drop_mask.unsqueeze(-1), 0.0)

        pos_embeddings = self.age_encoder(time_seq.float())
        x = self.drop(token_embeddings + pos_embeddings)

        combined_mask = _build_time_aware_mask(event_seq, time_seq)
        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    def get_num_params(self) -> float:
        """
        Returns the number of trainable parameters in the model in millions.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

    @torch.no_grad()
    def generate(self, x, t, max_new_tokens=100, max_age=85 * 365.25, no_repeat=True,
                 termination_tokens=None, top_k=None):
        """
        Take a conditioning sequence of indices x (LongTensor of shape (b, t))
        with matching timestamps t and extend it up to max_new_tokens times,
        feeding the predictions back into the model each time.

        The next event is chosen by sampling an exponential waiting time for
        every token (rate ``exp(logit)``) and taking the token with the
        shortest one; that waiting time is added to the last timestamp.
        Generation stops early once every sequence contains a termination
        token or has passed ``max_age``. The model is put into eval mode.

        Args:
            x: Token ids of shape (B, L), dtype int64.
            t: Timestamps of shape (B, L).
            max_new_tokens: Maximum number of events to append.
            max_age: Timestamps above this value are padded out.
            no_repeat: If True, tokens already present in a sequence are
                blocked (token id 1 is exempt).
            termination_tokens: Ids that end a sequence; defaults to [1269].
            top_k: Currently unused; kept for interface compatibility.

        Returns:
            Tuple ``(x, t, final_logits)``: the extended tokens (padded
            positions set to 0), the extended timestamps (padded positions
            set to -10000) and the logits for the extended sequence.
        """
        return _generate_events(self, x, t, max_new_tokens, max_age, no_repeat, termination_tokens)


class TimeAwareGPT2Learnable(TimeAwareGPT2):
    """Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the fixed sinusoidal encoder created by the parent class.
        self.age_encoder = LearnableAgeEncoding(
            base_dim=self.n_embd,
            hidden_dim=2 * self.n_embd,
            final_dim=self.n_embd,
        )


class TimeAwareGPT2TemporalConv(nn.Module):
    """
    A TimeAware GPT-2 variant that uses TemporalConvEncoder to encode event
    and time sequences before Transformer attention blocks.

    Inputs:
        - event_seq: (B, L) token ids (0 treated as padding)
        - time_seq: (B, L) timestamps or step indices (float)
    Output:
        - logits: (B, L, vocab_size)
    """

    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_layer: int,
        n_head: int,
        pdrop: float,
        token_pdrop: float,
        ignore_tokens: Optional[list[int]] = None,
        *,
        conv_layers: int = 2,
        kernel_size: int = 5,
        conv_dropout: float = 0.1,
        fourier_dim: int = 32,
        pad_id: int = 0,
    ):
        super().__init__()
        # NOTE(review): token_pdrop is stored but not used by forward() in
        # this variant.
        self.token_pdrop = token_pdrop
        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
        self.n_embd = n_embd

        # Temporal convolutional encoder to build inputs_embeds
        self.temporal_encoder = TemporalConvEncoder(
            vocab_size=vocab_size,
            d_model=n_embd,
            n_layers=conv_layers,
            kernel_size=kernel_size,
            dropout=conv_dropout,
            fourier_dim=fourier_dim,
            pad_id=pad_id,
        )

        # Transformer stack on top of temporal features
        self.drop = nn.Dropout(pdrop)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, L, vocab_size) for the given events and timestamps."""
        # Encoder features as inputs_embeds
        attention_mask = (event_seq != 0)
        x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask)
        x = self.drop(x)

        # Same time-aware mask as TimeAwareGPT2, with forced self-attention
        # on non-padding rows that would otherwise see nothing.
        combined_mask = _build_time_aware_mask(event_seq, time_seq)
        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    def get_num_params(self) -> float:
        """Returns the number of trainable parameters in the model in millions."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

    @torch.no_grad()
    def generate(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        max_new_tokens: int = 100,
        max_age: float = 85 * 365.25,
        no_repeat: bool = True,
        termination_tokens: Optional[list[int]] = None,
        top_k: Optional[int] = None,
    ):
        """Sample a continuation of (x, t), with optional no-repeat and termination tokens.

        For every step an exponential waiting time is sampled per token
        (rate ``exp(logit)``); the token with the shortest waiting time is
        appended and its waiting time added to the last timestamp. The model
        is put into eval mode.

        Args:
            x: Token ids of shape (B, L), dtype int64.
            t: Timestamps of shape (B, L).
            max_new_tokens: Maximum number of events to append.
            max_age: Timestamps above this value are padded out.
            no_repeat: If True, tokens already present in a sequence are
                blocked (token id 1 is exempt).
            termination_tokens: Ids that end a sequence; defaults to [1269].
            top_k: Currently unused; kept for interface compatibility.

        Returns:
            Tuple ``(x, t, final_logits)``: the extended tokens (padded
            positions set to 0), the extended timestamps (padded positions
            set to -10000) and the logits for the extended sequence.
        """
        return _generate_events(self, x, t, max_new_tokens, max_age, no_repeat, termination_tokens)


# =============================================================================
# 3. Loss Function
# =============================================================================

class CombinedLoss(nn.Module):
    """
    Computes a two-part loss: a standard cross-entropy loss for event type
    prediction and a survival analysis loss for event timing.
    """

    def __init__(self, ignored_token_ids: list[int]):
        """
        Initializes the CombinedLoss module.

        Args:
            ignored_token_ids (list[int]): A list of event type IDs to be
                excluded from all loss calculations.
        """
        super().__init__()
        self.ignored_token_ids = ignored_token_ids

    def forward(self, logits: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates the combined cross-entropy and survival loss.

        Args:
            logits (torch.Tensor): Raw model outputs of shape (B, L, N).
            x (torch.Tensor): Ground-truth event labels of shape (B, L).
            t (torch.Tensor): True time duration for each event, shape (B, L).

        Returns:
            A tuple containing the two scalar loss tensors:
            (loss_ce, loss_survival). When every label is ignored, both are
            zero but remain attached to the graph of ``logits``.
        """
        mask = torch.ones_like(x, dtype=torch.bool)
        for token_id in self.ignored_token_ids:
            mask = mask & (x != token_id)

        if not mask.any():
            # Keep the zeros connected to ``logits`` so that a training loop
            # calling ``.backward()`` on them gets zero gradients instead of
            # a "does not require grad" error.
            zero = (logits * 0.0).sum()
            return zero, zero.clone()

        logits_for_ce = logits.permute(0, 2, 1)  # cross_entropy expects (B, N, L)
        per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
        loss_ce = per_element_ce[mask].mean()

        # Survival loss based on exponential log-likelihood. With
        # rate = exp(lse) and duration = exp(-ldt), the loss is
        # -(log(rate) - rate * duration); t_min bounds both terms.
        t_min = 0.1
        lse = torch.logsumexp(logits, dim=-1)
        lse = -torch.log(torch.exp(-lse) + t_min)
        ldt = -torch.log(t + t_min)
        loss_dt = -(lse - torch.exp(lse - ldt))
        loss_survival = loss_dt[mask].mean()

        return loss_ce, loss_survival