feat(model): add TimeAwareGPT2TemporalConv using TemporalConvEncoder; wire into train.py and utils.load_model; add model_name to configs; translate CN comments and add math import

This commit is contained in:
2025-10-22 17:34:06 +08:00
parent a81da36657
commit 3bef72f50b
5 changed files with 279 additions and 3 deletions

View File

@@ -1,4 +1,5 @@
{ {
"model_name": "TimeAwareGPT2",
"n_layer": 12, "n_layer": 12,
"n_embd": 120, "n_embd": 120,
"n_head": 12, "n_head": 12,

View File

@@ -1,4 +1,5 @@
{ {
"model_name": "TimeAwareGPT2",
"n_layer": 16, "n_layer": 16,
"n_embd": 256, "n_embd": 256,
"n_head": 16, "n_head": 16,

272
models.py
View File

@@ -2,10 +2,81 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from typing import Tuple, Optional from typing import Tuple, Optional
import math
# ============================================================================= # =============================================================================
# 1. Component Modules (Building Blocks) # 1. Component Modules (Building Blocks)
# ============================================================================= # =============================================================================
class CausalConv1d(nn.Module):
def __init__(self, channels, kernel_size, groups=1):
super().__init__()
self.pad = kernel_size - 1
self.conv = nn.Conv1d(
channels, channels, kernel_size,
padding=0, groups=groups
)
def forward(self, x): # x: (B, C, L)
x = F.pad(x, (self.pad, 0)) # pad only on the left to ensure causality
return self.conv(x)
class DepthwiseSeparableCausalConvBlock(nn.Module):
def __init__(self, d_model, kernel_size=5, dropout=0.1):
super().__init__()
self.dw = CausalConv1d(d_model, kernel_size, groups=d_model) # depthwise
self.pw = nn.Conv1d(d_model, d_model, 1) # pointwise
self.act = nn.GELU()
self.ln = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x): # x: (B, L, D)
y = x.transpose(1, 2) # (B, D, L)
y = self.dw(y) # (B, D, L)
y = self.pw(y) # (B, D, L)
y = y.transpose(1, 2) # (B, L, D)
y = self.act(y)
y = self.dropout(y)
return self.ln(x + y) # residual connection + layer norm (LN)
class TimeFeatureProjector(nn.Module):
"""
Projects scalar time t and its increment Δt into d_model dimensions.
Combines: linear-scale features + fixed-frequency sin/cos (Fourier time features).
"""
def __init__(self, d_model, fourier_dim=32, dt_clip=1e6):
super().__init__()
self.dt_clip = dt_clip
self.scalar_proj = nn.Linear(2, d_model, bias=False) # [t_scaled, dt_scaled] -> D
# Predefine a set of logarithmically spaced frequencies (tune for your time units if needed)
k = fourier_dim // 2
freqs = torch.logspace(-4, 2, steps=k) * 2 * math.pi # frequency coverage ~1e-4 to 1e2
self.register_buffer("freqs", freqs, persistent=False)
self.fourier_proj = nn.Linear(2*k, d_model, bias=False) # [sin, cos] -> D
self.gate = nn.Parameter(torch.zeros(1)) # learnable gate to smoothly introduce Fourier features
self.ln = nn.LayerNorm(d_model)
def forward(self, t): # t: (B, L) continuous timestamps/steps
# compute increments Δt and stabilize
dt = t - F.pad(t, (1, 0), value=0.)[:, :-1]
dt = torch.clamp(dt, min=0.) # ensure non-negative
# normalize/stabilize with log compression
t_scaled = torch.log1p(torch.clamp(torch.abs(t), max=self.dt_clip))
dt_scaled = torch.log1p(torch.clamp(dt, max=self.dt_clip))
scal = torch.stack([t_scaled, dt_scaled], dim=-1) # (B, L, 2)
scal_feat = self.scalar_proj(scal) # (B, L, D)
# Fixed-frequency sin/cos to capture absolute/relative periodicity
# If t is in steps, use directly; if in seconds, ensure units are consistent (e.g., divide by a time constant)
# (B, L, K)
wt = t[..., None] * self.freqs
sincos = torch.cat([torch.sin(wt), torch.cos(wt)], dim=-1) # (B, L, 2K)
fourier_feat = self.fourier_proj(sincos) # (B, L, D)
# gated fusion + layer norm
h = scal_feat + torch.tanh(self.gate) * fourier_feat
return self.ln(h) # (B, L, D)
class Block(nn.Module): class Block(nn.Module):
""" an unassuming Transformer block """ """ an unassuming Transformer block """
@@ -200,6 +271,61 @@ class PiecewiseLinearEncoder(nn.Module):
encoded = encoded.view(*original_shape, self.num_bins) encoded = encoded.view(*original_shape, self.num_bins)
output = self.linear(encoded) output = self.linear(encoded)
return output return output
class TemporalConvEncoder(nn.Module):
"""
Inputs:
x: (B, L) - event/token ids
t: (B, L) - timestamps (real-valued) or step indices
Output:
h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds
"""
def __init__(
self,
vocab_size: int,
d_model: int = 768,
n_layers: int = 2,
kernel_size: int = 5,
dropout: float = 0.1,
fourier_dim: int = 32,
pad_id: int = 0
):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim)
self.fuse = nn.Linear(2*d_model, d_model, bias=False) # fuse token and time features
self.ln_in = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
blocks = []
for _ in range(n_layers):
blocks.append(DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout))
self.blocks = nn.ModuleList(blocks)
def forward(self, x, t, attention_mask=None):
"""
attention_mask: (B, L) 1=keep, 0=padding
"""
tok = self.token_emb(x) # (B, L, D)
tim = self.time_proj(t) # (B, L, D)
h = torch.cat([tok, tim], dim=-1) # (B, L, 2D)
h = self.fuse(h) # (B, L, D)
h = self.ln_in(h)
h = self.dropout(h)
# Optional: zero-out padding positions before convolutions to avoid leakage
if attention_mask is not None:
h = h * attention_mask.unsqueeze(-1).type_as(h)
# Multi-layer causal temporal convolutions (no look-ahead) to form relative position-aware context
for blk in self.blocks:
h = blk(h) # (B, L, D)
if attention_mask is not None:
h = h * attention_mask.unsqueeze(-1).type_as(h)
return h # (B, L, D), directly usable as attention layer input
# ============================================================================= # =============================================================================
# 2. Main Model Architectures # 2. Main Model Architectures
@@ -338,6 +464,152 @@ class TimeAwareGPT2Learnable(TimeAwareGPT2):
# 3. Loss Function # 3. Loss Function
# ============================================================================= # =============================================================================
class TimeAwareGPT2TemporalConv(nn.Module):
"""
A TimeAware GPT-2 variant that uses TemporalConvEncoder to encode
event and time sequences before Transformer attention blocks.
Inputs:
- event_seq: (B, L) token ids (0 treated as padding)
- time_seq: (B, L) timestamps or step indices (float)
Output:
- logits: (B, L, vocab_size)
"""
def __init__(
self,
vocab_size: int,
n_embd: int,
n_layer: int,
n_head: int,
pdrop: float,
token_pdrop: float,
ignore_tokens: Optional[list[int]] = None,
*,
conv_layers: int = 2,
kernel_size: int = 5,
conv_dropout: float = 0.1,
fourier_dim: int = 32,
pad_id: int = 0,
):
super().__init__()
self.token_pdrop = token_pdrop
self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
self.n_embd = n_embd
# Temporal convolutional encoder to build inputs_embeds
self.temporal_encoder = TemporalConvEncoder(
vocab_size=vocab_size,
d_model=n_embd,
n_layers=conv_layers,
kernel_size=kernel_size,
dropout=conv_dropout,
fourier_dim=fourier_dim,
pad_id=pad_id,
)
# Transformer stack on top of temporal features
self.drop = nn.Dropout(pdrop)
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
B, L = event_seq.size()
# Encoder features as inputs_embeds
attention_mask = (event_seq != 0)
x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask)
x = self.drop(x)
# Time-aware causal mask as before
t_i = time_seq.unsqueeze(-1)
t_j = time_seq.unsqueeze(1)
time_mask = (t_j < t_i)
padding_mask = (event_seq != 0).unsqueeze(1)
combined_mask = time_mask & padding_mask
# Ensure at least self-attention on non-padding rows
is_row_all_zero = ~combined_mask.any(dim=-1)
is_not_padding = (event_seq != 0)
force_self_attention = is_row_all_zero & is_not_padding
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
for block in self.blocks:
x = block(x, custom_mask=combined_mask)
x = self.ln_f(x)
logits = self.head(x)
return logits
def get_num_params(self) -> float:
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
@torch.no_grad()
def generate(
self,
x: torch.Tensor,
t: torch.Tensor,
max_new_tokens: int = 100,
max_age: float = 85 * 365.25,
no_repeat: bool = True,
termination_tokens: Optional[list[int]] = None,
top_k: Optional[int] = None,
):
"""Greedy-like generation with optional no-repeat and termination tokens."""
self.eval()
if termination_tokens is None:
termination_tokens = [1269]
termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
mask_time = -10000
for _ in range(max_new_tokens):
logits = self(x, t)
logits = logits[:, -1, :]
if self.ignore_tokens:
logits[:, self.ignore_tokens] = -torch.inf
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
logits = logits.scatter(1, fill, -torch.inf)
# Sample a time increment proxy as in original implementation
t_next_dist = torch.clamp(
-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(),
min=0,
max=365 * 80,
)
t_next_val, idx_next = t_next_dist.min(1)
idx_next = idx_next.unsqueeze(1)
age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
x = torch.cat((x, idx_next), dim=1)
t = torch.cat((t, age_next), dim=1)
if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
break
pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
final_logits = self(x, t)
x[pad] = 0
t[pad] = mask_time
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
final_logits = torch.stack(
[final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])]
).transpose(0, 1)
return x, t, final_logits
class CombinedLoss(nn.Module): class CombinedLoss(nn.Module):
""" """
Computes a two-part loss: a standard cross-entropy loss for event type Computes a two-part loss: a standard cross-entropy loss for event type

View File

@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
import json import json
import argparse import argparse
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv, CombinedLoss
from utils import PatientEventDataset from utils import PatientEventDataset
# --- Configuration --- # --- Configuration ---
@@ -60,7 +60,7 @@ def main():
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.') parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.') parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.') parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.') parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable', 'TimeAwareGPT2TemporalConv'], default='TimeAwareGPT2', help='Model architecture to train.')
args = parser.parse_args() args = parser.parse_args()
@@ -111,6 +111,7 @@ def main():
model_cls = { model_cls = {
'TimeAwareGPT2': TimeAwareGPT2, 'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable, 'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
}[config.model_name] }[config.model_name]
model = model_cls( model = model_cls(

View File

@@ -4,7 +4,7 @@ import numpy as np
import random import random
from collections import defaultdict from collections import defaultdict
import json import json
from models import TimeAwareGPT2, TimeAwareGPT2Learnable from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv
class PatientEventDataset(torch.utils.data.Dataset): class PatientEventDataset(torch.utils.data.Dataset):
@@ -151,6 +151,7 @@ def load_model(config_path: str, device: str = 'cpu'):
model_cls = { model_cls = {
'TimeAwareGPT2': TimeAwareGPT2, 'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable, 'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
}.get(model_name, TimeAwareGPT2) }.get(model_name, TimeAwareGPT2)
# 3) Infer checkpoint filename from config # 3) Infer checkpoint filename from config