feat(model): add TimeAwareGPT2TemporalConv using TemporalConvEncoder; wire into train.py and utils.load_model; add model_name to configs; translate CN comments and add math import

2025-10-22 17:34:06 +08:00
parent a81da36657
commit 3bef72f50b
5 changed files with 279 additions and 3 deletions


@@ -1,4 +1,5 @@
{
"model_name": "TimeAwareGPT2",
"n_layer": 12,
"n_embd": 120,
"n_head": 12,


@@ -1,4 +1,5 @@
{
"model_name": "TimeAwareGPT2",
"n_layer": 16,
"n_embd": 256,
"n_head": 16,

models.py

@@ -2,10 +2,81 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple, Optional
import math
# =============================================================================
# 1. Component Modules (Building Blocks)
# =============================================================================
class CausalConv1d(nn.Module):
def __init__(self, channels, kernel_size, groups=1):
super().__init__()
self.pad = kernel_size - 1
self.conv = nn.Conv1d(
channels, channels, kernel_size,
padding=0, groups=groups
)
def forward(self, x): # x: (B, C, L)
x = F.pad(x, (self.pad, 0)) # pad only on the left to ensure causality
return self.conv(x)
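# Illustrative check (not part of this commit): with left-only padding, the
# output at position i depends only on inputs at positions <= i, and the
# sequence length is preserved.
#
#   conv = CausalConv1d(channels=4, kernel_size=3)
#   x = torch.randn(2, 4, 10)   # (B, C, L)
#   y = conv(x)                 # (B, C, L), same length; no look-ahead
#   assert y.shape == x.shape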
class DepthwiseSeparableCausalConvBlock(nn.Module):
def __init__(self, d_model, kernel_size=5, dropout=0.1):
super().__init__()
self.dw = CausalConv1d(d_model, kernel_size, groups=d_model) # depthwise
self.pw = nn.Conv1d(d_model, d_model, 1) # pointwise
self.act = nn.GELU()
self.ln = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x): # x: (B, L, D)
y = x.transpose(1, 2) # (B, D, L)
y = self.dw(y) # (B, D, L)
y = self.pw(y) # (B, D, L)
y = y.transpose(1, 2) # (B, L, D)
y = self.act(y)
y = self.dropout(y)
return self.ln(x + y) # residual connection + layer norm (LN)
class TimeFeatureProjector(nn.Module):
"""
Projects scalar time t and its increment Δt into d_model dimensions.
Combines: linear-scale features + fixed-frequency sin/cos (Fourier time features).
"""
def __init__(self, d_model, fourier_dim=32, dt_clip=1e6):
super().__init__()
self.dt_clip = dt_clip
self.scalar_proj = nn.Linear(2, d_model, bias=False) # [t_scaled, dt_scaled] -> D
# Predefine a set of logarithmically spaced frequencies (tune for your time units if needed)
k = fourier_dim // 2
freqs = torch.logspace(-4, 2, steps=k) * 2 * math.pi # frequency coverage ~1e-4 to 1e2
self.register_buffer("freqs", freqs, persistent=False)
self.fourier_proj = nn.Linear(2*k, d_model, bias=False) # [sin, cos] -> D
self.gate = nn.Parameter(torch.zeros(1)) # learnable gate to smoothly introduce Fourier features
self.ln = nn.LayerNorm(d_model)
def forward(self, t): # t: (B, L) continuous timestamps/steps
# compute increments Δt and stabilize
dt = t - F.pad(t, (1, 0), value=0.)[:, :-1]
dt = torch.clamp(dt, min=0.) # ensure non-negative
# normalize/stabilize with log compression
t_scaled = torch.log1p(torch.clamp(torch.abs(t), max=self.dt_clip))
dt_scaled = torch.log1p(torch.clamp(dt, max=self.dt_clip))
scal = torch.stack([t_scaled, dt_scaled], dim=-1) # (B, L, 2)
scal_feat = self.scalar_proj(scal) # (B, L, D)
# Fixed-frequency sin/cos to capture absolute/relative periodicity
# If t is in steps, use directly; if in seconds, ensure units are consistent (e.g., divide by a time constant)
# (B, L, K)
wt = t[..., None] * self.freqs
sincos = torch.cat([torch.sin(wt), torch.cos(wt)], dim=-1) # (B, L, 2K)
fourier_feat = self.fourier_proj(sincos) # (B, L, D)
# gated fusion + layer norm
h = scal_feat + torch.tanh(self.gate) * fourier_feat
return self.ln(h) # (B, L, D)
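# Usage sketch (illustrative, not part of this commit); t may be step indices
# or real timestamps, as long as the units are consistent with the logspace
# frequencies defined in __init__.
#
#   proj = TimeFeatureProjector(d_model=128, fourier_dim=32)
#   t = torch.cumsum(torch.rand(2, 50), dim=1)   # (B, L), non-decreasing times
#   feats = proj(t)                              # (B, L, 128)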
class Block(nn.Module):
""" an unassuming Transformer block """
@@ -200,6 +271,61 @@ class PiecewiseLinearEncoder(nn.Module):
encoded = encoded.view(*original_shape, self.num_bins)
output = self.linear(encoded)
return output
class TemporalConvEncoder(nn.Module):
"""
Inputs:
x: (B, L) - event/token ids
t: (B, L) - timestamps (real-valued) or step indices
Output:
h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds
"""
def __init__(
self,
vocab_size: int,
d_model: int = 768,
n_layers: int = 2,
kernel_size: int = 5,
dropout: float = 0.1,
fourier_dim: int = 32,
pad_id: int = 0
):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim)
self.fuse = nn.Linear(2*d_model, d_model, bias=False) # fuse token and time features
self.ln_in = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
blocks = []
for _ in range(n_layers):
blocks.append(DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout))
self.blocks = nn.ModuleList(blocks)
def forward(self, x, t, attention_mask=None):
"""
attention_mask: (B, L) 1=keep, 0=padding
"""
tok = self.token_emb(x) # (B, L, D)
tim = self.time_proj(t) # (B, L, D)
h = torch.cat([tok, tim], dim=-1) # (B, L, 2D)
h = self.fuse(h) # (B, L, D)
h = self.ln_in(h)
h = self.dropout(h)
# Optional: zero-out padding positions before convolutions to avoid leakage
if attention_mask is not None:
h = h * attention_mask.unsqueeze(-1).type_as(h)
# Multi-layer causal temporal convolutions (no look-ahead) to form relative position-aware context
for blk in self.blocks:
h = blk(h) # (B, L, D)
if attention_mask is not None:
h = h * attention_mask.unsqueeze(-1).type_as(h)
return h # (B, L, D), directly usable as attention layer input
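# Usage sketch (illustrative, not part of this commit); vocab_size=1270 is an
# assumed placeholder. The output is shaped for use as GPT-2 inputs_embeds.
#
#   enc = TemporalConvEncoder(vocab_size=1270, d_model=256, n_layers=2)
#   x = torch.randint(1, 1270, (2, 50))          # (B, L) token ids, 0 = padding
#   t = torch.cumsum(torch.rand(2, 50), dim=1)   # (B, L) timestamps
#   h = enc(x, t, attention_mask=(x != 0))       # (B, L, 256)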
# =============================================================================
# 2. Main Model Architectures
@@ -338,6 +464,152 @@ class TimeAwareGPT2Learnable(TimeAwareGPT2):
# 3. Loss Function
# =============================================================================
class TimeAwareGPT2TemporalConv(nn.Module):
"""
A TimeAware GPT-2 variant that uses TemporalConvEncoder to encode
event and time sequences before Transformer attention blocks.
Inputs:
- event_seq: (B, L) token ids (0 treated as padding)
- time_seq: (B, L) timestamps or step indices (float)
Output:
- logits: (B, L, vocab_size)
"""
def __init__(
self,
vocab_size: int,
n_embd: int,
n_layer: int,
n_head: int,
pdrop: float,
token_pdrop: float,
ignore_tokens: Optional[list[int]] = None,
*,
conv_layers: int = 2,
kernel_size: int = 5,
conv_dropout: float = 0.1,
fourier_dim: int = 32,
pad_id: int = 0,
):
super().__init__()
self.token_pdrop = token_pdrop
self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
self.n_embd = n_embd
# Temporal convolutional encoder to build inputs_embeds
self.temporal_encoder = TemporalConvEncoder(
vocab_size=vocab_size,
d_model=n_embd,
n_layers=conv_layers,
kernel_size=kernel_size,
dropout=conv_dropout,
fourier_dim=fourier_dim,
pad_id=pad_id,
)
# Transformer stack on top of temporal features
self.drop = nn.Dropout(pdrop)
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
B, L = event_seq.size()
# Encoder features as inputs_embeds
attention_mask = (event_seq != 0)
x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask)
x = self.drop(x)
# Time-aware causal mask as before
t_i = time_seq.unsqueeze(-1)
t_j = time_seq.unsqueeze(1)
time_mask = (t_j < t_i)
padding_mask = (event_seq != 0).unsqueeze(1)
combined_mask = time_mask & padding_mask
# Ensure at least self-attention on non-padding rows
is_row_all_zero = ~combined_mask.any(dim=-1)
is_not_padding = (event_seq != 0)
force_self_attention = is_row_all_zero & is_not_padding
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
for block in self.blocks:
x = block(x, custom_mask=combined_mask)
x = self.ln_f(x)
logits = self.head(x)
return logits
def get_num_params(self) -> float:
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
@torch.no_grad()
def generate(
self,
x: torch.Tensor,
t: torch.Tensor,
max_new_tokens: int = 100,
max_age: float = 85 * 365.25,
no_repeat: bool = True,
termination_tokens: Optional[list[int]] = None,
top_k: Optional[int] = None,
):
"""Greedy-like generation with optional no-repeat and termination tokens."""
self.eval()
if termination_tokens is None:
termination_tokens = [1269]
termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
mask_time = -10000
for _ in range(max_new_tokens):
logits = self(x, t)
logits = logits[:, -1, :]
if self.ignore_tokens:
logits[:, self.ignore_tokens] = -torch.inf
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
logits = logits.scatter(1, fill, -torch.inf)
# Exponential race, as in the original implementation: draw a waiting time
# t_j ~ Exp(rate=exp(logit_j)) per token via inverse-CDF sampling; the argmin
# below jointly picks the next token and its time increment.
t_next_dist = torch.clamp(
-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(),
min=0,
max=365 * 80,
)
t_next_val, idx_next = t_next_dist.min(1)
idx_next = idx_next.unsqueeze(1)
age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
x = torch.cat((x, idx_next), dim=1)
t = torch.cat((t, age_next), dim=1)
if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
break
pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
final_logits = self(x, t)
x[pad] = 0
t[pad] = mask_time
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
final_logits = torch.stack(
[final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])]
).transpose(0, 1)
return x, t, final_logits
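# End-to-end sketch (illustrative, not part of this commit); sizes follow the
# smaller config in this commit (n_layer=12, n_embd=120, n_head=12), and
# vocab_size=1270 is an assumed placeholder.
#
#   model = TimeAwareGPT2TemporalConv(
#       vocab_size=1270, n_embd=120, n_layer=12, n_head=12,
#       pdrop=0.1, token_pdrop=0.1,
#   )
#   logits = model(event_seq, time_seq)          # (B, L, vocab_size)
#   x_gen, t_gen, final_logits = model.generate(event_seq, time_seq.float())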
class CombinedLoss(nn.Module):
"""
Computes a two-part loss: a standard cross-entropy loss for event type


@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
import json
import argparse
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
@@ -60,7 +60,7 @@ def main():
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.')
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable', 'TimeAwareGPT2TemporalConv'], default='TimeAwareGPT2', help='Model architecture to train.')
args = parser.parse_args()
@@ -111,6 +111,7 @@ def main():
model_cls = {
'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
}[config.model_name]
model = model_cls(


@@ -4,7 +4,7 @@ import numpy as np
import random
from collections import defaultdict
import json
from models import TimeAwareGPT2, TimeAwareGPT2Learnable
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv
class PatientEventDataset(torch.utils.data.Dataset):
@@ -151,6 +151,7 @@ def load_model(config_path: str, device: str = 'cpu'):
model_cls = {
'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
}.get(model_name, TimeAwareGPT2)
# 3) Infer checkpoint filename from config
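# Usage sketch (illustrative): load_model dispatches on the config's
# model_name, falling back to TimeAwareGPT2 for configs written before this
# commit; the config path below is hypothetical.
#
#   model = load_model('configs/config.json', device='cuda')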