feat(model): add TimeAwareGPT2TemporalConv using TemporalConvEncoder; wire into train.py and utils.load_model; add model_name to configs; translate CN comments and add math import
This commit is contained in:
272
models.py
272
models.py
@@ -2,10 +2,81 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Tuple, Optional
|
||||
import math
|
||||
|
||||
# =============================================================================
|
||||
# 1. Component Modules (Building Blocks)
|
||||
# =============================================================================
|
||||
class CausalConv1d(nn.Module):
|
||||
def __init__(self, channels, kernel_size, groups=1):
|
||||
super().__init__()
|
||||
self.pad = kernel_size - 1
|
||||
self.conv = nn.Conv1d(
|
||||
channels, channels, kernel_size,
|
||||
padding=0, groups=groups
|
||||
)
|
||||
def forward(self, x): # x: (B, C, L)
|
||||
x = F.pad(x, (self.pad, 0)) # pad only on the left to ensure causality
|
||||
return self.conv(x)
|
||||
|
||||
class DepthwiseSeparableCausalConvBlock(nn.Module):
|
||||
def __init__(self, d_model, kernel_size=5, dropout=0.1):
|
||||
super().__init__()
|
||||
self.dw = CausalConv1d(d_model, kernel_size, groups=d_model) # depthwise
|
||||
self.pw = nn.Conv1d(d_model, d_model, 1) # pointwise
|
||||
self.act = nn.GELU()
|
||||
self.ln = nn.LayerNorm(d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x): # x: (B, L, D)
|
||||
y = x.transpose(1, 2) # (B, D, L)
|
||||
y = self.dw(y) # (B, D, L)
|
||||
y = self.pw(y) # (B, D, L)
|
||||
y = y.transpose(1, 2) # (B, L, D)
|
||||
y = self.act(y)
|
||||
y = self.dropout(y)
|
||||
return self.ln(x + y) # residual connection + layer norm (LN)
|
||||
|
||||
class TimeFeatureProjector(nn.Module):
|
||||
"""
|
||||
Projects scalar time t and its increment Δt into d_model dimensions.
|
||||
Combines: linear-scale features + fixed-frequency sin/cos (Fourier time features).
|
||||
"""
|
||||
def __init__(self, d_model, fourier_dim=32, dt_clip=1e6):
|
||||
super().__init__()
|
||||
self.dt_clip = dt_clip
|
||||
self.scalar_proj = nn.Linear(2, d_model, bias=False) # [t_scaled, dt_scaled] -> D
|
||||
|
||||
# Predefine a set of logarithmically spaced frequencies (tune for your time units if needed)
|
||||
k = fourier_dim // 2
|
||||
freqs = torch.logspace(-4, 2, steps=k) * 2 * math.pi # frequency coverage ~1e-4 to 1e2
|
||||
self.register_buffer("freqs", freqs, persistent=False)
|
||||
|
||||
self.fourier_proj = nn.Linear(2*k, d_model, bias=False) # [sin, cos] -> D
|
||||
self.gate = nn.Parameter(torch.zeros(1)) # learnable gate to smoothly introduce Fourier features
|
||||
self.ln = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, t): # t: (B, L) continuous timestamps/steps
|
||||
# compute increments Δt and stabilize
|
||||
dt = t - F.pad(t, (1, 0), value=0.)[:, :-1]
|
||||
dt = torch.clamp(dt, min=0.) # ensure non-negative
|
||||
# normalize/stabilize with log compression
|
||||
t_scaled = torch.log1p(torch.clamp(torch.abs(t), max=self.dt_clip))
|
||||
dt_scaled = torch.log1p(torch.clamp(dt, max=self.dt_clip))
|
||||
|
||||
scal = torch.stack([t_scaled, dt_scaled], dim=-1) # (B, L, 2)
|
||||
scal_feat = self.scalar_proj(scal) # (B, L, D)
|
||||
|
||||
# Fixed-frequency sin/cos to capture absolute/relative periodicity
|
||||
# If t is in steps, use directly; if in seconds, ensure units are consistent (e.g., divide by a time constant)
|
||||
# (B, L, K)
|
||||
wt = t[..., None] * self.freqs
|
||||
sincos = torch.cat([torch.sin(wt), torch.cos(wt)], dim=-1) # (B, L, 2K)
|
||||
fourier_feat = self.fourier_proj(sincos) # (B, L, D)
|
||||
|
||||
# gated fusion + layer norm
|
||||
h = scal_feat + torch.tanh(self.gate) * fourier_feat
|
||||
return self.ln(h) # (B, L, D)
|
||||
|
||||
class Block(nn.Module):
|
||||
""" an unassuming Transformer block """
|
||||
@@ -200,6 +271,61 @@ class PiecewiseLinearEncoder(nn.Module):
|
||||
encoded = encoded.view(*original_shape, self.num_bins)
|
||||
output = self.linear(encoded)
|
||||
return output
|
||||
|
||||
class TemporalConvEncoder(nn.Module):
|
||||
"""
|
||||
Inputs:
|
||||
x: (B, L) - event/token ids
|
||||
t: (B, L) - timestamps (real-valued) or step indices
|
||||
Output:
|
||||
h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
d_model: int = 768,
|
||||
n_layers: int = 2,
|
||||
kernel_size: int = 5,
|
||||
dropout: float = 0.1,
|
||||
fourier_dim: int = 32,
|
||||
pad_id: int = 0
|
||||
):
|
||||
super().__init__()
|
||||
self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
|
||||
self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim)
|
||||
self.fuse = nn.Linear(2*d_model, d_model, bias=False) # fuse token and time features
|
||||
self.ln_in = nn.LayerNorm(d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
blocks = []
|
||||
for _ in range(n_layers):
|
||||
blocks.append(DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout))
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
def forward(self, x, t, attention_mask=None):
|
||||
"""
|
||||
attention_mask: (B, L) 1=keep, 0=padding
|
||||
"""
|
||||
tok = self.token_emb(x) # (B, L, D)
|
||||
tim = self.time_proj(t) # (B, L, D)
|
||||
|
||||
h = torch.cat([tok, tim], dim=-1) # (B, L, 2D)
|
||||
h = self.fuse(h) # (B, L, D)
|
||||
h = self.ln_in(h)
|
||||
h = self.dropout(h)
|
||||
|
||||
# Optional: zero-out padding positions before convolutions to avoid leakage
|
||||
if attention_mask is not None:
|
||||
h = h * attention_mask.unsqueeze(-1).type_as(h)
|
||||
|
||||
# Multi-layer causal temporal convolutions (no look-ahead) to form relative position-aware context
|
||||
for blk in self.blocks:
|
||||
h = blk(h) # (B, L, D)
|
||||
|
||||
if attention_mask is not None:
|
||||
h = h * attention_mask.unsqueeze(-1).type_as(h)
|
||||
|
||||
return h # (B, L, D), directly usable as attention layer input
|
||||
|
||||
# =============================================================================
|
||||
# 2. Main Model Architectures
|
||||
@@ -338,6 +464,152 @@ class TimeAwareGPT2Learnable(TimeAwareGPT2):
|
||||
# 3. Loss Function
|
||||
# =============================================================================
|
||||
|
||||
class TimeAwareGPT2TemporalConv(nn.Module):
|
||||
"""
|
||||
A TimeAware GPT-2 variant that uses TemporalConvEncoder to encode
|
||||
event and time sequences before Transformer attention blocks.
|
||||
|
||||
Inputs:
|
||||
- event_seq: (B, L) token ids (0 treated as padding)
|
||||
- time_seq: (B, L) timestamps or step indices (float)
|
||||
|
||||
Output:
|
||||
- logits: (B, L, vocab_size)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
n_embd: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
pdrop: float,
|
||||
token_pdrop: float,
|
||||
ignore_tokens: Optional[list[int]] = None,
|
||||
*,
|
||||
conv_layers: int = 2,
|
||||
kernel_size: int = 5,
|
||||
conv_dropout: float = 0.1,
|
||||
fourier_dim: int = 32,
|
||||
pad_id: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.token_pdrop = token_pdrop
|
||||
self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
|
||||
self.n_embd = n_embd
|
||||
|
||||
# Temporal convolutional encoder to build inputs_embeds
|
||||
self.temporal_encoder = TemporalConvEncoder(
|
||||
vocab_size=vocab_size,
|
||||
d_model=n_embd,
|
||||
n_layers=conv_layers,
|
||||
kernel_size=kernel_size,
|
||||
dropout=conv_dropout,
|
||||
fourier_dim=fourier_dim,
|
||||
pad_id=pad_id,
|
||||
)
|
||||
|
||||
# Transformer stack on top of temporal features
|
||||
self.drop = nn.Dropout(pdrop)
|
||||
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
|
||||
self.ln_f = nn.LayerNorm(n_embd)
|
||||
self.head = nn.Linear(n_embd, vocab_size, bias=False)
|
||||
|
||||
def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
|
||||
B, L = event_seq.size()
|
||||
|
||||
# Encoder features as inputs_embeds
|
||||
attention_mask = (event_seq != 0)
|
||||
x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask)
|
||||
x = self.drop(x)
|
||||
|
||||
# Time-aware causal mask as before
|
||||
t_i = time_seq.unsqueeze(-1)
|
||||
t_j = time_seq.unsqueeze(1)
|
||||
time_mask = (t_j < t_i)
|
||||
padding_mask = (event_seq != 0).unsqueeze(1)
|
||||
combined_mask = time_mask & padding_mask
|
||||
|
||||
# Ensure at least self-attention on non-padding rows
|
||||
is_row_all_zero = ~combined_mask.any(dim=-1)
|
||||
is_not_padding = (event_seq != 0)
|
||||
force_self_attention = is_row_all_zero & is_not_padding
|
||||
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, custom_mask=combined_mask)
|
||||
|
||||
x = self.ln_f(x)
|
||||
logits = self.head(x)
|
||||
return logits
|
||||
|
||||
def get_num_params(self) -> float:
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
max_new_tokens: int = 100,
|
||||
max_age: float = 85 * 365.25,
|
||||
no_repeat: bool = True,
|
||||
termination_tokens: Optional[list[int]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
):
|
||||
"""Greedy-like generation with optional no-repeat and termination tokens."""
|
||||
self.eval()
|
||||
|
||||
if termination_tokens is None:
|
||||
termination_tokens = [1269]
|
||||
|
||||
termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
|
||||
mask_time = -10000
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
logits = self(x, t)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if self.ignore_tokens:
|
||||
logits[:, self.ignore_tokens] = -torch.inf
|
||||
|
||||
if no_repeat:
|
||||
fill = x.clone()
|
||||
fill[fill == 1] = 0
|
||||
logits = logits.scatter(1, fill, -torch.inf)
|
||||
|
||||
# Sample a time increment proxy as in original implementation
|
||||
t_next_dist = torch.clamp(
|
||||
-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(),
|
||||
min=0,
|
||||
max=365 * 80,
|
||||
)
|
||||
t_next_val, idx_next = t_next_dist.min(1)
|
||||
|
||||
idx_next = idx_next.unsqueeze(1)
|
||||
age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
|
||||
|
||||
x = torch.cat((x, idx_next), dim=1)
|
||||
t = torch.cat((t, age_next), dim=1)
|
||||
|
||||
if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
|
||||
break
|
||||
|
||||
pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
|
||||
|
||||
final_logits = self(x, t)
|
||||
x[pad] = 0
|
||||
t[pad] = mask_time
|
||||
|
||||
if no_repeat:
|
||||
fill = x.clone()
|
||||
fill[fill == 1] = 0
|
||||
final_logits = torch.stack(
|
||||
[final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])]
|
||||
).transpose(0, 1)
|
||||
|
||||
return x, t, final_logits
|
||||
|
||||
class CombinedLoss(nn.Module):
|
||||
"""
|
||||
Computes a two-part loss: a standard cross-entropy loss for event type
|
||||
|
Reference in New Issue
Block a user