662 lines
24 KiB
Python
662 lines
24 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
from typing import Tuple, Optional
|
|
import math
|
|
|
|
# =============================================================================
|
|
# 1. Component Modules (Building Blocks)
|
|
# =============================================================================
|
|
class CausalConv1d(nn.Module):
    """1-D convolution that never looks at future time steps.

    The input is padded on the left only, by ``(kernel_size - 1) * dilation``
    zeros, so the output has the same length as the input and output position
    ``i`` depends only on input positions ``<= i``.

    Args:
        channels: Number of input channels; the output has the same number.
        kernel_size: Width of the convolution kernel (must be >= 1).
        groups: Forwarded to ``nn.Conv1d``. ``groups == channels`` yields a
            depthwise convolution.
        dilation: Spacing between kernel taps (must be >= 1). The default of 1
            reproduces a plain causal convolution.

    Raises:
        ValueError: If ``kernel_size`` or ``dilation`` is smaller than 1.
    """

    def __init__(self, channels, kernel_size, groups=1, dilation=1):
        super().__init__()
        if kernel_size < 1:
            raise ValueError(f"kernel_size must be >= 1, but got {kernel_size}")
        if dilation < 1:
            raise ValueError(f"dilation must be >= 1, but got {dilation}")
        # Receptive field minus the current step: that many zeros go on the left.
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            padding=0,  # padding is applied manually in forward() so it stays one-sided
            groups=groups,
            dilation=dilation,
        )

    def forward(self, x):  # x: (B, C, L)
        """Apply the causal convolution; returns a tensor of shape (B, C, L)."""
        x = F.pad(x, (self.pad, 0))  # pad only on the left to ensure causality
        return self.conv(x)
|
|
|
|
class DepthwiseSeparableCausalConvBlock(nn.Module):
    """Residual block: depthwise causal conv -> pointwise conv -> GELU -> dropout.

    The branch output is added back to the block input and the sum is
    layer-normalised. Input and output are both shaped (B, L, D).

    Args:
        d_model: Feature dimension D.
        kernel_size: Kernel width of the depthwise causal convolution.
        dropout: Dropout probability applied to the branch output.
    """

    def __init__(self, d_model, kernel_size=5, dropout=0.1):
        super().__init__()
        # groups == channels makes the causal convolution depthwise.
        self.dw = CausalConv1d(d_model, kernel_size, groups=d_model)
        # 1x1 convolution mixes information across channels.
        self.pw = nn.Conv1d(d_model, d_model, 1)
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # x: (B, L, D)
        """Return ``LayerNorm(x + branch(x))`` with shape (B, L, D)."""
        # Conv1d layers expect channels first: (B, L, D) -> (B, D, L).
        channels_first = x.transpose(1, 2)
        mixed = self.pw(self.dw(channels_first))
        # Back to (B, L, D) before the activation and dropout.
        branch = self.dropout(self.act(mixed.transpose(1, 2)))
        return self.ln(x + branch)
|
|
|
|
class TimeFeatureProjector(nn.Module):
    """Embed a scalar time ``t`` and its increment into ``d_model`` dimensions.

    Two branches are summed and layer-normalised:

    * a linear projection of the log-compressed pair ``[log1p(|t|), log1p(dt)]``;
    * a linear projection of fixed-frequency ``sin``/``cos`` features of ``t``,
      scaled by ``tanh(gate)`` where ``gate`` is a learnable scalar that starts
      at zero (so this branch contributes nothing at initialisation).

    Args:
        d_model: Output feature dimension.
        fourier_dim: Total number of sin/cos features; ``fourier_dim // 2``
            frequencies are used.
        dt_clip: Upper clamp applied to ``|t|`` and ``dt`` before ``log1p``.
    """

    def __init__(self, d_model, fourier_dim=32, dt_clip=1e6):
        super().__init__()
        self.dt_clip = dt_clip
        # [t_scaled, dt_scaled] -> D
        self.scalar_proj = nn.Linear(2, d_model, bias=False)

        # Log-spaced angular frequencies covering roughly 1e-4 .. 1e2 cycles
        # per time unit; not stored in the state dict (persistent=False).
        n_freq = fourier_dim // 2
        angular = torch.logspace(-4, 2, steps=n_freq) * 2 * math.pi
        self.register_buffer("freqs", angular, persistent=False)

        # [sin, cos] -> D
        self.fourier_proj = nn.Linear(2 * n_freq, d_model, bias=False)
        # Learnable gate; tanh(0) == 0 so the Fourier branch is phased in smoothly.
        self.gate = nn.Parameter(torch.zeros(1))
        self.ln = nn.LayerNorm(d_model)

    def forward(self, t):  # t: (B, L) continuous timestamps/steps
        """Return time features of shape (B, L, D)."""
        # Increment relative to the previous step; the first step is compared
        # against 0. Negative increments are clamped to zero.
        previous = F.pad(t, (1, 0), value=0.)[:, :-1]
        dt = (t - previous).clamp(min=0.)

        # Log compression keeps large magnitudes numerically tame.
        t_feat = torch.log1p(t.abs().clamp(max=self.dt_clip))
        dt_feat = torch.log1p(dt.clamp(max=self.dt_clip))
        scalar_part = self.scalar_proj(torch.stack((t_feat, dt_feat), dim=-1))  # (B, L, D)

        # Fixed-frequency sin/cos features of the raw time value.
        phase = t.unsqueeze(-1) * self.freqs  # (B, L, K)
        periodic = torch.cat((phase.sin(), phase.cos()), dim=-1)  # (B, L, 2K)
        periodic_part = self.fourier_proj(periodic)  # (B, L, D)

        # Gated fusion followed by layer norm.
        return self.ln(scalar_part + self.gate.tanh() * periodic_part)
|
|
|
|
class Block(nn.Module):
    """Pre-norm Transformer block: masked self-attention followed by an MLP.

    Args:
        n_embd: Embedding / model dimension.
        n_head: Number of attention heads.
        pdrop: Dropout probability used in attention, the MLP and the
            attention residual branch.
    """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        self.n_head = n_head
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=pdrop, batch_first=True)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        self.resid_dropout = nn.Dropout(pdrop)

    def mlpf(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward sub-layer: fc -> GELU -> projection -> dropout.

        This is a real method rather than a lambda stored on the instance: a
        stored lambda closes over the *original* ``self.mlp``, so a
        ``copy.deepcopy`` of the block would keep using the source block's MLP
        weights, and the module could not be pickled.
        """
        m = self.mlp
        return m.dropout(m.c_proj(m.act(m.c_fc(x))))

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP with residual connections.

        Args:
            x: Input of shape (B, L, n_embd).
            custom_mask: Boolean tensor of shape (B, L, L) where ``True`` means
                "query row may attend to key column". It is not modified.

        Returns:
            Tensor of shape (B, L, n_embd).
        """
        normed_x = self.ln_1(x)

        # nn.MultiheadAttention uses the opposite convention: True == blocked.
        blocked = ~custom_mask

        # A query row with every key blocked makes the softmax degenerate
        # (NaN on PyTorch versions without safe-softmax), and the NaN then
        # spreads to other positions in subsequent layers. Let such rows
        # attend to themselves so the output stays finite.
        dead_rows = blocked.all(dim=-1, keepdim=True)  # (B, L, 1)
        eye = torch.eye(blocked.size(-1), device=blocked.device).bool()  # (L, L)
        blocked = blocked & ~(dead_rows & eye)

        # A 3-D mask must be (B * n_head, L, L); consecutive entries belong to
        # the heads of one batch element, hence repeat_interleave.
        blocked = blocked.repeat_interleave(self.n_head, dim=0)

        attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=blocked, need_weights=False)
        x = x + self.resid_dropout(attn_output)
        x = x + self.mlpf(self.ln_2(x))
        return x
|
|
|
|
class AgeSinusoidalEncoding(nn.Module):
    """
    Encodes age using sinusoidal functions, similar to positional encodings
    in Transformers. This module creates a fixed-size embedding for an age
    value given in days.
    """

    def __init__(self, embedding_dim: int):
        """
        Initializes the AgeSinusoidalEncoding module.

        Args:
            embedding_dim (int): The dimensionality of the output embedding.
                                 Must be an even number.

        Raises:
            ValueError: If embedding_dim is not an even number.
        """
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(f"Embedding dimension must be an even number, but got {embedding_dim}")

        self.embedding_dim = embedding_dim

        # Pre-calculate the divisor term 10000^(i / D) for i = 0, 2, 4, ...
        i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.embedding_dim)
        self.register_buffer('divisor', divisor)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the AgeSinusoidalEncoding.

        Args:
            t (torch.Tensor): Age in days. Typically of shape
                (batch_size, sequence_length); any shape is accepted.

        Returns:
            torch.Tensor: The encoded age tensor of shape
                          (*t.shape, embedding_dim). Even feature indices hold
                          cosines, odd indices hold sines.
        """
        t_years = t / 365.25
        # Divide by the precomputed term (the original code multiplied, which
        # produced frequencies up to 10000 rad/year instead of the intended
        # geometric ladder of wavelengths).
        args = t_years.unsqueeze(-1) / self.divisor
        # Interleave as (cos_0, sin_0, cos_1, sin_1, ...) along the last axis.
        return torch.stack((torch.cos(args), torch.sin(args)), dim=-1).flatten(-2)
|
|
|
|
|
|
class LearnableAgeEncoding(nn.Module):
    """Fixed sinusoidal age features followed by a learnable two-layer MLP.

    Args:
        base_dim: Dimension of the sinusoidal encoding fed to the MLP.
        hidden_dim: Width of the MLP hidden layer; a falsy value (``None`` or
            0) falls back to ``base_dim``.
        final_dim: Output dimension; a falsy value falls back to ``base_dim``.
        dropout: Dropout probability between the two linear layers; no dropout
            layer is created when it is not greater than 0.

    Raises:
        ValueError: If the resolved hidden or final dimension is not positive.
    """

    def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        self.base_dim = base_dim
        self.final_dim = final_dim if final_dim else base_dim

        width = hidden_dim if hidden_dim else base_dim
        if width <= 0:
            raise ValueError("hidden_dim must be a positive integer.")
        if self.final_dim <= 0:
            raise ValueError("final_dim must be a positive integer.")

        self.sinusoidal = AgeSinusoidalEncoding(base_dim)

        # Linear -> GELU -> (optional Dropout) -> Linear
        stages = [nn.Linear(base_dim, width), nn.GELU()]
        if dropout > 0.0:
            stages.append(nn.Dropout(dropout))
        stages.append(nn.Linear(width, self.final_dim))
        self.mlp = nn.Sequential(*stages)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Encode ``t`` and project it; output shape is (*leading dims, final_dim)."""
        encoded = self.sinusoidal(t)
        leading_dims = encoded.shape[:-1]
        # Collapse all leading dims so the MLP sees a 2-D batch, then restore them.
        projected = self.mlp(encoded.reshape(-1, self.base_dim))
        return projected.reshape(*leading_dims, self.final_dim)
|
|
|
|
class PiecewiseLinearEncoder(nn.Module):
    """
    Encodes continuous variables using piecewise linear encoding.

    This module defines bins based on standard normal distribution quantiles,
    encodes an input by finding its bin, and calculates its position as a
    linear interpolation between boundaries. The result is projected to the
    final embedding dimension by a shared linear layer.

    The first and last bins are unbounded (their outer edges are -inf/+inf),
    so values falling in them are encoded as a pure one-hot on that bin.
    """

    def __init__(self, num_bins: int, embedding_dim: int):
        """
        Initializes the PiecewiseLinearEncoder module.

        Args:
            num_bins (int): The number of bins for the encoding.
            embedding_dim (int): The dimensionality of the output embedding (D).

        Raises:
            ValueError: If num_bins is not positive.
        """
        super().__init__()
        if num_bins <= 0:
            raise ValueError("num_bins must be a positive integer.")
        self.num_bins = num_bins
        self.embedding_dim = embedding_dim

        if num_bins > 1:
            # Interior edges sit at the k/num_bins quantiles of N(0, 1).
            quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1)
            interior = torch.distributions.normal.Normal(0, 1).icdf(quantiles)
        else:
            interior = torch.tensor([])

        boundaries = torch.cat([
            torch.tensor([float('-inf')]),
            interior,
            torch.tensor([float('inf')]),
        ])
        self.register_buffer('boundaries', boundaries)

        self.linear = nn.Linear(num_bins, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the piecewise linear encoding.

        Args:
            x (torch.Tensor): Input tensor of shape (*, N), where * is any
                number of batch dimensions and N is the number of continuous
                features. Assumed to be pre-scaled. Any numeric dtype is
                accepted; it is cast to the dtype of the module's buffers.

        Returns:
            torch.Tensor: Encoded tensor of shape (*, N, D).
        """
        boundaries = self.boundaries
        original_shape = x.shape
        # Cast to the module dtype so searchsorted and the linear layer both
        # see matching dtypes (float64 / half / integer inputs would otherwise
        # fail), and make the values contiguous for searchsorted.
        flat = x.reshape(-1, original_shape[-1]).to(boundaries.dtype).contiguous()

        bin_indices = torch.searchsorted(boundaries, flat, right=True) - 1
        bin_indices = bin_indices.clamp(0, self.num_bins - 1)

        lower_bounds = boundaries[bin_indices]
        upper_bounds = boundaries[bin_indices + 1]
        # Relative position inside the bin. For the unbounded edge bins this
        # is NaN or 0 and is overridden just below.
        frac = (flat - lower_bounds) / (upper_bounds - lower_bounds + 1e-8)

        is_first_bin = bin_indices == 0
        is_last_bin = bin_indices == self.num_bins - 1

        # Edge bins are one-hot. The last-bin rule is applied last so it wins
        # when num_bins == 1 and the single bin is both first and last.
        weight_upper = frac.masked_fill(is_first_bin, 0.0).masked_fill(is_last_bin, 1.0)
        weight_lower = (1.0 - frac).masked_fill(is_first_bin, 1.0).masked_fill(is_last_bin, 0.0)

        encoded = flat.new_zeros(*flat.shape, self.num_bins)
        encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1))

        # For the last bin the "upper" slot is clamped back onto the bin
        # itself, so its weight of 1 lands on the last bin.
        upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1)
        encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1))

        encoded = encoded.view(*original_shape, self.num_bins)
        return self.linear(encoded)
|
|
|
|
class TemporalConvEncoder(nn.Module):
    """
    Fuses token embeddings with time features and refines them with a stack
    of causal depthwise-separable convolution blocks.

    Inputs:
        x: (B, L)  - event/token ids
        t: (B, L)  - timestamps (real-valued) or step indices
    Output:
        h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Feature dimension D.
        n_layers: Number of causal convolution blocks.
        kernel_size: Kernel width of each causal convolution.
        dropout: Dropout probability for the input and the conv blocks.
        fourier_dim: Number of sin/cos time features.
        pad_id: Token id used as ``padding_idx`` of the embedding table.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_layers: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.1,
        fourier_dim: int = 32,
        pad_id: int = 0,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim)
        self.fuse = nn.Linear(2 * d_model, d_model, bias=False)  # fuse token and time features
        self.ln_in = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout) for _ in range(n_layers)]
        )

    def forward(self, x, t, attention_mask=None):
        """
        Encode a batch of event/time sequences.

        Args:
            x: (B, L) token ids.
            t: (B, L) timestamps.
            attention_mask: optional (B, L), 1=keep, 0=padding. When given,
                padded positions are zeroed before the first convolution block
                and again after every block.

        Returns:
            (B, L, D) features.
        """
        tok = self.token_emb(x)  # (B, L, D)
        tim = self.time_proj(t)  # (B, L, D)

        h = self.fuse(torch.cat([tok, tim], dim=-1))  # (B, L, 2D) -> (B, L, D)
        h = self.dropout(self.ln_in(h))

        keep = None
        if attention_mask is not None:
            keep = attention_mask.unsqueeze(-1).type_as(h)  # (B, L, 1)
            # Zero-out padding positions before the convolutions to avoid leakage.
            h = h * keep

        # Multi-layer causal temporal convolutions (no look-ahead).
        for blk in self.blocks:
            h = blk(h)  # (B, L, D)
            if keep is not None:
                # Each block ends in a LayerNorm, which turns zeroed padding
                # rows into non-zero vectors. Re-mask so the next causal
                # convolution cannot read them when padding precedes real
                # tokens.
                h = h * keep

        return h  # (B, L, D), directly usable as attention layer input
|
|
|
|
# =============================================================================
|
|
# 2. Main Model Architectures
|
|
# =============================================================================
|
|
|
|
class TimeAwareGPT2(nn.Module):
    """
    A time-aware GPT-2 model with custom temporal features.

    Token embeddings are summed with an age encoding of the event timestamps,
    and attention is restricted by timestamps rather than by position: a
    query may attend to keys with a strictly smaller timestamp that are not
    padding (token id 0).

    Args:
        vocab_size: Size of the token vocabulary.
        n_embd: Embedding / model dimension.
        n_layer: Number of Transformer blocks.
        n_head: Number of attention heads per block.
        pdrop: Dropout probability for embeddings and blocks.
        token_pdrop: Probability of zeroing a whole token embedding during
            training.
        ignore_tokens: Token ids whose logits are set to -inf while sampling
            in ``generate``. ``None`` means no tokens are suppressed.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: Optional[list[int]] = None):
        super().__init__()
        self.token_pdrop = token_pdrop
        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []

        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.n_embd = n_embd

    @staticmethod
    def _build_attention_mask(event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """Return a (B, L, L) boolean mask; True means query i may attend to key j."""
        is_not_padding = event_seq != 0
        # Key j is visible to query i only if it happened strictly earlier.
        time_mask = time_seq.unsqueeze(1) < time_seq.unsqueeze(-1)
        combined_mask = time_mask & is_not_padding.unsqueeze(1)

        # Non-padding queries with no visible key (e.g. the earliest event)
        # are allowed to attend to themselves.
        force_self_attention = ~combined_mask.any(dim=-1) & is_not_padding
        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
        return combined_mask

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the TimeAwareGPT2 model.

        Args:
            event_seq (torch.Tensor): Token indices of shape (B, L).
            time_seq (torch.Tensor): Timestamps for each event of shape (B, L).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        token_embeddings = self.wte(event_seq)
        if self.training and self.token_pdrop > 0:
            # Token-level dropout: zero the whole embedding of randomly chosen positions.
            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
            token_embeddings = token_embeddings.masked_fill(drop_mask.unsqueeze(-1), 0.0)

        pos_embeddings = self.age_encoder(time_seq.float())
        x = self.drop(token_embeddings + pos_embeddings)

        combined_mask = self._build_attention_mask(event_seq, time_seq)
        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)

        return self.head(self.ln_f(x))

    def get_num_params(self) -> float:
        """
        Returns the number of trainable parameters in the model in millions.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

    @torch.no_grad()
    def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
        """
        Take a conditioning sequence of indices x (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.

        The model is switched to eval mode and left in it. The input tensors
        are not modified.

        Args:
            x: (B, L) conditioning token ids.
            t: (B, L) timestamps matching ``x``.
            max_new_tokens: Maximum number of tokens appended per sequence.
            max_age: Generation stops for a sequence once its latest timestamp
                exceeds this value; later positions are padded.
            no_repeat: If True, tokens already present in a sequence (other
                than token id 1) cannot be sampled again.
            termination_tokens: Token ids that end a sequence. Defaults to
                ``[1269]``.
            top_k: Accepted for interface compatibility; currently unused.

        Returns:
            Tuple ``(x, t, final_logits)``: the extended token ids and
            timestamps, with everything after the first termination token or
            beyond ``max_age`` set to 0 / -10000, and the logits of the model
            on the full un-padded sequence.
        """
        self.eval()

        if termination_tokens is None:
            termination_tokens = [1269]

        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
        mask_time = -10000

        # The padding step below writes in place. Work on copies so the
        # caller's tensors are untouched even when no token gets appended.
        x = x.clone()
        t = t.clone()

        for _ in range(max_new_tokens):
            logits = self(x, t)[:, -1, :]

            if self.ignore_tokens:
                logits[:, self.ignore_tokens] = -torch.inf

            if no_repeat:
                # Token 1 is exempt from the no-repeat rule: its occurrences
                # are redirected to index 0, so index 0 is suppressed instead.
                fill = x.masked_fill(x == 1, 0)
                logits = logits.scatter(1, fill, -torch.inf)

            # Draw one waiting time per token (exponential with rate
            # exp(logit)); the token with the smallest waiting time is next.
            t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
            t_next_val, idx_next = t_next_dist.min(1)

            idx_next = idx_next.unsqueeze(1)
            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)

            x = torch.cat((x, idx_next), dim=1)
            t = torch.cat((t, age_next), dim=1)

            # Stop once every sequence has terminated or exceeded max_age.
            finished = torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze(-1) > max_age)
            if finished.all():
                break

        # Positions strictly after the first termination token, or beyond max_age.
        after_termination = torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1
        pad = after_termination | (t > max_age)

        # Logits are computed on the un-padded sequence, then padding is applied.
        final_logits = self(x, t)
        x[pad] = 0
        t[pad] = mask_time

        if no_repeat:
            fill = x.masked_fill(x == 1, 0)
            # At step j, suppress every token seen in positions 0..j.
            final_logits = torch.stack(
                [final_logits[:, j].scatter(1, fill[:, :j + 1], -torch.inf) for j in range(fill.shape[1])]
            ).transpose(0, 1)

        return x, t, final_logits
|
|
|
|
|
|
class TimeAwareGPT2Learnable(TimeAwareGPT2):
    """TimeAwareGPT2 whose age encoder is a LearnableAgeEncoding.

    Accepts exactly the constructor arguments of ``TimeAwareGPT2``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the encoder installed by the parent constructor; the MLP
        # hidden layer is twice the model width and the output matches it.
        width = self.n_embd
        self.age_encoder = LearnableAgeEncoding(
            base_dim=width,
            hidden_dim=2 * width,
            final_dim=width,
        )
|
|
|
|
|
|
|
|
# =============================================================================
|
|
# 3. Temporal-Conv Model Variant and Loss Function
|
|
# =============================================================================
|
|
|
|
class TimeAwareGPT2TemporalConv(nn.Module):
    """
    A TimeAware GPT-2 variant that uses TemporalConvEncoder to encode
    event and time sequences before Transformer attention blocks.

    Inputs:
        - event_seq: (B, L) token ids (``pad_id`` treated as padding)
        - time_seq: (B, L) timestamps or step indices (float)

    Output:
        - logits: (B, L, vocab_size)

    Args:
        vocab_size: Size of the token vocabulary.
        n_embd: Embedding / model dimension.
        n_layer: Number of Transformer blocks.
        n_head: Number of attention heads per block.
        pdrop: Dropout probability for the Transformer stack.
        token_pdrop: Stored on the instance; not applied by this class's
            ``forward``.
        ignore_tokens: Token ids whose logits are set to -inf while sampling
            in ``generate``.
        conv_layers: Number of causal convolution blocks in the encoder.
        kernel_size: Kernel width of the causal convolutions.
        conv_dropout: Dropout probability inside the encoder.
        fourier_dim: Number of sin/cos time features in the encoder.
        pad_id: Token id that marks padding (default 0).
    """

    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_layer: int,
        n_head: int,
        pdrop: float,
        token_pdrop: float,
        ignore_tokens: Optional[list[int]] = None,
        *,
        conv_layers: int = 2,
        kernel_size: int = 5,
        conv_dropout: float = 0.1,
        fourier_dim: int = 32,
        pad_id: int = 0,
    ):
        super().__init__()
        self.token_pdrop = token_pdrop
        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
        self.n_embd = n_embd
        # Remembered so masks and generation use the same padding id as the encoder.
        self.pad_id = pad_id

        # Temporal convolutional encoder to build inputs_embeds
        self.temporal_encoder = TemporalConvEncoder(
            vocab_size=vocab_size,
            d_model=n_embd,
            n_layers=conv_layers,
            kernel_size=kernel_size,
            dropout=conv_dropout,
            fourier_dim=fourier_dim,
            pad_id=pad_id,
        )

        # Transformer stack on top of temporal features
        self.drop = nn.Dropout(pdrop)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def _build_attention_mask(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """Return a (B, L, L) boolean mask; True means query i may attend to key j."""
        is_not_padding = event_seq != self.pad_id
        # Key j is visible to query i only if it happened strictly earlier.
        time_mask = time_seq.unsqueeze(1) < time_seq.unsqueeze(-1)
        combined_mask = time_mask & is_not_padding.unsqueeze(1)

        # Ensure at least self-attention on non-padding rows.
        force_self_attention = ~combined_mask.any(dim=-1) & is_not_padding
        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
        return combined_mask

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """
        Compute next-event logits.

        Args:
            event_seq (torch.Tensor): Token indices of shape (B, L).
            time_seq (torch.Tensor): Timestamps for each event of shape (B, L).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        # Encoder features as inputs_embeds
        attention_mask = event_seq != self.pad_id
        x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask)
        x = self.drop(x)

        combined_mask = self._build_attention_mask(event_seq, time_seq)
        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)

        return self.head(self.ln_f(x))

    def get_num_params(self) -> float:
        """Returns the number of trainable parameters in the model in millions."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

    @torch.no_grad()
    def generate(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        max_new_tokens: int = 100,
        max_age: float = 85 * 365.25,
        no_repeat: bool = True,
        termination_tokens: Optional[list[int]] = None,
        top_k: Optional[int] = None,
    ):
        """Sample continuations with optional no-repeat and termination tokens.

        The model is switched to eval mode and left in it. The input tensors
        are not modified.

        Args:
            x: (B, L) conditioning token ids.
            t: (B, L) timestamps matching ``x``.
            max_new_tokens: Maximum number of tokens appended per sequence.
            max_age: A sequence stops once its latest timestamp exceeds this
                value; later positions are padded.
            no_repeat: If True, tokens already present in a sequence (other
                than token id 1) cannot be sampled again.
            termination_tokens: Token ids that end a sequence. Defaults to
                ``[1269]``.
            top_k: Accepted for interface compatibility; currently unused.

        Returns:
            Tuple ``(x, t, final_logits)``: the extended token ids and
            timestamps, with everything after the first termination token or
            beyond ``max_age`` set to ``pad_id`` / -10000, and the logits of
            the model on the full un-padded sequence.
        """
        self.eval()

        if termination_tokens is None:
            termination_tokens = [1269]

        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
        mask_time = -10000

        # The padding step below writes in place. Work on copies so the
        # caller's tensors are untouched even when no token gets appended.
        x = x.clone()
        t = t.clone()

        for _ in range(max_new_tokens):
            logits = self(x, t)[:, -1, :]

            if self.ignore_tokens:
                logits[:, self.ignore_tokens] = -torch.inf

            if no_repeat:
                # Token 1 is exempt from the no-repeat rule: its occurrences
                # are redirected to the padding id, which is suppressed instead.
                fill = x.masked_fill(x == 1, self.pad_id)
                logits = logits.scatter(1, fill, -torch.inf)

            # Draw one waiting time per token (exponential with rate
            # exp(logit)); the token with the smallest waiting time is next.
            t_next_dist = torch.clamp(
                -torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(),
                min=0,
                max=365 * 80,
            )
            t_next_val, idx_next = t_next_dist.min(1)

            idx_next = idx_next.unsqueeze(1)
            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)

            x = torch.cat((x, idx_next), dim=1)
            t = torch.cat((t, age_next), dim=1)

            # Stop once every sequence has terminated or exceeded max_age.
            finished = torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze(-1) > max_age)
            if finished.all():
                break

        # Positions strictly after the first termination token, or beyond max_age.
        after_termination = torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1
        pad = after_termination | (t > max_age)

        # Logits are computed on the un-padded sequence, then padding is applied.
        final_logits = self(x, t)
        x[pad] = self.pad_id
        t[pad] = mask_time

        if no_repeat:
            fill = x.masked_fill(x == 1, self.pad_id)
            # At step j, suppress every token seen in positions 0..j.
            final_logits = torch.stack(
                [final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])]
            ).transpose(0, 1)

        return x, t, final_logits
|
|
|
|
class CombinedLoss(nn.Module):
    """
    Computes a two-part loss: a standard cross-entropy loss for event type
    prediction and a survival analysis loss for event timing.
    """

    def __init__(self, ignored_token_ids: list[int]):
        """
        Initializes the CombinedLoss module.

        Args:
            ignored_token_ids (list[int]): A list of event type IDs to be
                excluded from all loss calculations.
        """
        super().__init__()
        self.ignored_token_ids = ignored_token_ids

    def _valid_positions(self, x: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask, True where the label is not an ignored id."""
        ignored = list(self.ignored_token_ids)
        if not ignored:
            return torch.ones_like(x, dtype=torch.bool)
        ignored_ids = torch.as_tensor(ignored, dtype=x.dtype, device=x.device)
        return ~torch.isin(x, ignored_ids)

    def forward(self, logits: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates the combined cross-entropy and survival loss.

        Args:
            logits (torch.Tensor): Raw model outputs of shape (B, L, N).
            x (torch.Tensor): Ground-truth event labels of shape (B, L).
            t (torch.Tensor): True time duration for each event, shape (B, L).

        Returns:
            A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
            Both are means over the positions whose label is not ignored. If
            every position is ignored, both are zero but remain attached to
            ``logits`` so that ``backward()`` still works.
        """
        mask = self._valid_positions(x)

        if not mask.any():
            # Zero that is still part of the autograd graph.
            zero = logits.sum() * 0.0
            return zero, zero.clone()

        # Select the valid positions *before* computing anything. Computing on
        # all positions and masking afterwards lets NaN/inf at ignored
        # positions (e.g. log of a negative padding time) leak into the
        # gradients as 0 * NaN, and makes cross_entropy fail on out-of-range
        # labels that were going to be ignored anyway.
        valid_logits = logits[mask]  # (M, N)
        valid_labels = x[mask]  # (M,)
        valid_t = t[mask]  # (M,)

        loss_ce = F.cross_entropy(valid_logits, valid_labels, reduction='none').mean()

        # Survival loss based on exponential log-likelihood. With total rate
        # exp(lse), the per-position term is -(log_rate - rate * dt); t_min
        # regularises both the rate and the duration.
        t_min = 0.1
        lse = torch.logsumexp(valid_logits, dim=-1)
        lse = -torch.log(torch.exp(-lse) + t_min)
        ldt = -torch.log(valid_t + t_min)
        loss_survival = (-(lse - torch.exp(lse - ldt))).mean()

        return loss_ce, loss_survival
|