feat(model): add TimeAwareGPT2TemporalConv using TemporalConvEncoder; wire into train.py and utils.load_model; add model_name to configs; translate CN comments and add math import
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
{
|
{
|
||||||
|
"model_name": "TimeAwareGPT2",
|
||||||
"n_layer": 12,
|
"n_layer": 12,
|
||||||
"n_embd": 120,
|
"n_embd": 120,
|
||||||
"n_head": 12,
|
"n_head": 12,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
{
|
{
|
||||||
|
"model_name": "TimeAwareGPT2",
|
||||||
"n_layer": 16,
|
"n_layer": 16,
|
||||||
"n_embd": 256,
|
"n_embd": 256,
|
||||||
"n_head": 16,
|
"n_head": 16,
|
||||||
|
|||||||
272
models.py
272
models.py
@@ -2,10 +2,81 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
|
import math
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 1. Component Modules (Building Blocks)
|
# 1. Component Modules (Building Blocks)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
class CausalConv1d(nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size, groups=1):
|
||||||
|
super().__init__()
|
||||||
|
self.pad = kernel_size - 1
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
channels, channels, kernel_size,
|
||||||
|
padding=0, groups=groups
|
||||||
|
)
|
||||||
|
def forward(self, x): # x: (B, C, L)
|
||||||
|
x = F.pad(x, (self.pad, 0)) # pad only on the left to ensure causality
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
class DepthwiseSeparableCausalConvBlock(nn.Module):
|
||||||
|
def __init__(self, d_model, kernel_size=5, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.dw = CausalConv1d(d_model, kernel_size, groups=d_model) # depthwise
|
||||||
|
self.pw = nn.Conv1d(d_model, d_model, 1) # pointwise
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.ln = nn.LayerNorm(d_model)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x): # x: (B, L, D)
|
||||||
|
y = x.transpose(1, 2) # (B, D, L)
|
||||||
|
y = self.dw(y) # (B, D, L)
|
||||||
|
y = self.pw(y) # (B, D, L)
|
||||||
|
y = y.transpose(1, 2) # (B, L, D)
|
||||||
|
y = self.act(y)
|
||||||
|
y = self.dropout(y)
|
||||||
|
return self.ln(x + y) # residual connection + layer norm (LN)
|
||||||
|
|
||||||
|
class TimeFeatureProjector(nn.Module):
|
||||||
|
"""
|
||||||
|
Projects scalar time t and its increment Δt into d_model dimensions.
|
||||||
|
Combines: linear-scale features + fixed-frequency sin/cos (Fourier time features).
|
||||||
|
"""
|
||||||
|
def __init__(self, d_model, fourier_dim=32, dt_clip=1e6):
|
||||||
|
super().__init__()
|
||||||
|
self.dt_clip = dt_clip
|
||||||
|
self.scalar_proj = nn.Linear(2, d_model, bias=False) # [t_scaled, dt_scaled] -> D
|
||||||
|
|
||||||
|
# Predefine a set of logarithmically spaced frequencies (tune for your time units if needed)
|
||||||
|
k = fourier_dim // 2
|
||||||
|
freqs = torch.logspace(-4, 2, steps=k) * 2 * math.pi # frequency coverage ~1e-4 to 1e2
|
||||||
|
self.register_buffer("freqs", freqs, persistent=False)
|
||||||
|
|
||||||
|
self.fourier_proj = nn.Linear(2*k, d_model, bias=False) # [sin, cos] -> D
|
||||||
|
self.gate = nn.Parameter(torch.zeros(1)) # learnable gate to smoothly introduce Fourier features
|
||||||
|
self.ln = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
def forward(self, t): # t: (B, L) continuous timestamps/steps
|
||||||
|
# compute increments Δt and stabilize
|
||||||
|
dt = t - F.pad(t, (1, 0), value=0.)[:, :-1]
|
||||||
|
dt = torch.clamp(dt, min=0.) # ensure non-negative
|
||||||
|
# normalize/stabilize with log compression
|
||||||
|
t_scaled = torch.log1p(torch.clamp(torch.abs(t), max=self.dt_clip))
|
||||||
|
dt_scaled = torch.log1p(torch.clamp(dt, max=self.dt_clip))
|
||||||
|
|
||||||
|
scal = torch.stack([t_scaled, dt_scaled], dim=-1) # (B, L, 2)
|
||||||
|
scal_feat = self.scalar_proj(scal) # (B, L, D)
|
||||||
|
|
||||||
|
# Fixed-frequency sin/cos to capture absolute/relative periodicity
|
||||||
|
# If t is in steps, use directly; if in seconds, ensure units are consistent (e.g., divide by a time constant)
|
||||||
|
# (B, L, K)
|
||||||
|
wt = t[..., None] * self.freqs
|
||||||
|
sincos = torch.cat([torch.sin(wt), torch.cos(wt)], dim=-1) # (B, L, 2K)
|
||||||
|
fourier_feat = self.fourier_proj(sincos) # (B, L, D)
|
||||||
|
|
||||||
|
# gated fusion + layer norm
|
||||||
|
h = scal_feat + torch.tanh(self.gate) * fourier_feat
|
||||||
|
return self.ln(h) # (B, L, D)
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
""" an unassuming Transformer block """
|
""" an unassuming Transformer block """
|
||||||
@@ -200,6 +271,61 @@ class PiecewiseLinearEncoder(nn.Module):
|
|||||||
encoded = encoded.view(*original_shape, self.num_bins)
|
encoded = encoded.view(*original_shape, self.num_bins)
|
||||||
output = self.linear(encoded)
|
output = self.linear(encoded)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
class TemporalConvEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Inputs:
|
||||||
|
x: (B, L) - event/token ids
|
||||||
|
t: (B, L) - timestamps (real-valued) or step indices
|
||||||
|
Output:
|
||||||
|
h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size: int,
|
||||||
|
d_model: int = 768,
|
||||||
|
n_layers: int = 2,
|
||||||
|
kernel_size: int = 5,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
fourier_dim: int = 32,
|
||||||
|
pad_id: int = 0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
|
||||||
|
self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim)
|
||||||
|
self.fuse = nn.Linear(2*d_model, d_model, bias=False) # fuse token and time features
|
||||||
|
self.ln_in = nn.LayerNorm(d_model)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
for _ in range(n_layers):
|
||||||
|
blocks.append(DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout))
|
||||||
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
|
||||||
|
def forward(self, x, t, attention_mask=None):
|
||||||
|
"""
|
||||||
|
attention_mask: (B, L) 1=keep, 0=padding
|
||||||
|
"""
|
||||||
|
tok = self.token_emb(x) # (B, L, D)
|
||||||
|
tim = self.time_proj(t) # (B, L, D)
|
||||||
|
|
||||||
|
h = torch.cat([tok, tim], dim=-1) # (B, L, 2D)
|
||||||
|
h = self.fuse(h) # (B, L, D)
|
||||||
|
h = self.ln_in(h)
|
||||||
|
h = self.dropout(h)
|
||||||
|
|
||||||
|
# Optional: zero-out padding positions before convolutions to avoid leakage
|
||||||
|
if attention_mask is not None:
|
||||||
|
h = h * attention_mask.unsqueeze(-1).type_as(h)
|
||||||
|
|
||||||
|
# Multi-layer causal temporal convolutions (no look-ahead) to form relative position-aware context
|
||||||
|
for blk in self.blocks:
|
||||||
|
h = blk(h) # (B, L, D)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
h = h * attention_mask.unsqueeze(-1).type_as(h)
|
||||||
|
|
||||||
|
return h # (B, L, D), directly usable as attention layer input
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 2. Main Model Architectures
|
# 2. Main Model Architectures
|
||||||
@@ -338,6 +464,152 @@ class TimeAwareGPT2Learnable(TimeAwareGPT2):
|
|||||||
# 3. Loss Function
|
# 3. Loss Function
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
class TimeAwareGPT2TemporalConv(nn.Module):
|
||||||
|
"""
|
||||||
|
A TimeAware GPT-2 variant that uses TemporalConvEncoder to encode
|
||||||
|
event and time sequences before Transformer attention blocks.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- event_seq: (B, L) token ids (0 treated as padding)
|
||||||
|
- time_seq: (B, L) timestamps or step indices (float)
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- logits: (B, L, vocab_size)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size: int,
|
||||||
|
n_embd: int,
|
||||||
|
n_layer: int,
|
||||||
|
n_head: int,
|
||||||
|
pdrop: float,
|
||||||
|
token_pdrop: float,
|
||||||
|
ignore_tokens: Optional[list[int]] = None,
|
||||||
|
*,
|
||||||
|
conv_layers: int = 2,
|
||||||
|
kernel_size: int = 5,
|
||||||
|
conv_dropout: float = 0.1,
|
||||||
|
fourier_dim: int = 32,
|
||||||
|
pad_id: int = 0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.token_pdrop = token_pdrop
|
||||||
|
self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
|
||||||
|
self.n_embd = n_embd
|
||||||
|
|
||||||
|
# Temporal convolutional encoder to build inputs_embeds
|
||||||
|
self.temporal_encoder = TemporalConvEncoder(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
d_model=n_embd,
|
||||||
|
n_layers=conv_layers,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
dropout=conv_dropout,
|
||||||
|
fourier_dim=fourier_dim,
|
||||||
|
pad_id=pad_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Transformer stack on top of temporal features
|
||||||
|
self.drop = nn.Dropout(pdrop)
|
||||||
|
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
|
||||||
|
self.ln_f = nn.LayerNorm(n_embd)
|
||||||
|
self.head = nn.Linear(n_embd, vocab_size, bias=False)
|
||||||
|
|
||||||
|
def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
|
||||||
|
B, L = event_seq.size()
|
||||||
|
|
||||||
|
# Encoder features as inputs_embeds
|
||||||
|
attention_mask = (event_seq != 0)
|
||||||
|
x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask)
|
||||||
|
x = self.drop(x)
|
||||||
|
|
||||||
|
# Time-aware causal mask as before
|
||||||
|
t_i = time_seq.unsqueeze(-1)
|
||||||
|
t_j = time_seq.unsqueeze(1)
|
||||||
|
time_mask = (t_j < t_i)
|
||||||
|
padding_mask = (event_seq != 0).unsqueeze(1)
|
||||||
|
combined_mask = time_mask & padding_mask
|
||||||
|
|
||||||
|
# Ensure at least self-attention on non-padding rows
|
||||||
|
is_row_all_zero = ~combined_mask.any(dim=-1)
|
||||||
|
is_not_padding = (event_seq != 0)
|
||||||
|
force_self_attention = is_row_all_zero & is_not_padding
|
||||||
|
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, custom_mask=combined_mask)
|
||||||
|
|
||||||
|
x = self.ln_f(x)
|
||||||
|
logits = self.head(x)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def get_num_params(self) -> float:
|
||||||
|
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
max_new_tokens: int = 100,
|
||||||
|
max_age: float = 85 * 365.25,
|
||||||
|
no_repeat: bool = True,
|
||||||
|
termination_tokens: Optional[list[int]] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""Greedy-like generation with optional no-repeat and termination tokens."""
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
if termination_tokens is None:
|
||||||
|
termination_tokens = [1269]
|
||||||
|
|
||||||
|
termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
|
||||||
|
mask_time = -10000
|
||||||
|
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
logits = self(x, t)
|
||||||
|
logits = logits[:, -1, :]
|
||||||
|
|
||||||
|
if self.ignore_tokens:
|
||||||
|
logits[:, self.ignore_tokens] = -torch.inf
|
||||||
|
|
||||||
|
if no_repeat:
|
||||||
|
fill = x.clone()
|
||||||
|
fill[fill == 1] = 0
|
||||||
|
logits = logits.scatter(1, fill, -torch.inf)
|
||||||
|
|
||||||
|
# Sample a time increment proxy as in original implementation
|
||||||
|
t_next_dist = torch.clamp(
|
||||||
|
-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(),
|
||||||
|
min=0,
|
||||||
|
max=365 * 80,
|
||||||
|
)
|
||||||
|
t_next_val, idx_next = t_next_dist.min(1)
|
||||||
|
|
||||||
|
idx_next = idx_next.unsqueeze(1)
|
||||||
|
age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
|
||||||
|
|
||||||
|
x = torch.cat((x, idx_next), dim=1)
|
||||||
|
t = torch.cat((t, age_next), dim=1)
|
||||||
|
|
||||||
|
if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
|
||||||
|
break
|
||||||
|
|
||||||
|
pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
|
||||||
|
|
||||||
|
final_logits = self(x, t)
|
||||||
|
x[pad] = 0
|
||||||
|
t[pad] = mask_time
|
||||||
|
|
||||||
|
if no_repeat:
|
||||||
|
fill = x.clone()
|
||||||
|
fill[fill == 1] = 0
|
||||||
|
final_logits = torch.stack(
|
||||||
|
[final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])]
|
||||||
|
).transpose(0, 1)
|
||||||
|
|
||||||
|
return x, t, final_logits
|
||||||
|
|
||||||
class CombinedLoss(nn.Module):
|
class CombinedLoss(nn.Module):
|
||||||
"""
|
"""
|
||||||
Computes a two-part loss: a standard cross-entropy loss for event type
|
Computes a two-part loss: a standard cross-entropy loss for event type
|
||||||
|
|||||||
5
train.py
5
train.py
@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
|
|||||||
import json
|
import json
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
|
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv, CombinedLoss
|
||||||
from utils import PatientEventDataset
|
from utils import PatientEventDataset
|
||||||
|
|
||||||
# --- Configuration ---
|
# --- Configuration ---
|
||||||
@@ -60,7 +60,7 @@ def main():
|
|||||||
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
|
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
|
||||||
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
|
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
|
||||||
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
|
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
|
||||||
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.')
|
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable', 'TimeAwareGPT2TemporalConv'], default='TimeAwareGPT2', help='Model architecture to train.')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -111,6 +111,7 @@ def main():
|
|||||||
model_cls = {
|
model_cls = {
|
||||||
'TimeAwareGPT2': TimeAwareGPT2,
|
'TimeAwareGPT2': TimeAwareGPT2,
|
||||||
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
|
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
|
||||||
|
'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
|
||||||
}[config.model_name]
|
}[config.model_name]
|
||||||
|
|
||||||
model = model_cls(
|
model = model_cls(
|
||||||
|
|||||||
3
utils.py
3
utils.py
@@ -4,7 +4,7 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import json
|
import json
|
||||||
from models import TimeAwareGPT2, TimeAwareGPT2Learnable
|
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv
|
||||||
|
|
||||||
|
|
||||||
class PatientEventDataset(torch.utils.data.Dataset):
|
class PatientEventDataset(torch.utils.data.Dataset):
|
||||||
@@ -151,6 +151,7 @@ def load_model(config_path: str, device: str = 'cpu'):
|
|||||||
model_cls = {
|
model_cls = {
|
||||||
'TimeAwareGPT2': TimeAwareGPT2,
|
'TimeAwareGPT2': TimeAwareGPT2,
|
||||||
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
|
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
|
||||||
|
'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
|
||||||
}.get(model_name, TimeAwareGPT2)
|
}.get(model_name, TimeAwareGPT2)
|
||||||
|
|
||||||
# 3) Infer checkpoint filename from config
|
# 3) Infer checkpoint filename from config
|
||||||
|
|||||||
Reference in New Issue
Block a user