diff --git a/config_n_embd_120_n_layer_12_n_head_12.json b/config_n_embd_120_n_layer_12_n_head_12.json
index 633a6e1..16e581a 100644
--- a/config_n_embd_120_n_layer_12_n_head_12.json
+++ b/config_n_embd_120_n_layer_12_n_head_12.json
@@ -1,4 +1,5 @@
 {
+    "model_name": "TimeAwareGPT2",
     "n_layer": 12,
     "n_embd": 120,
     "n_head": 12,
diff --git a/config_n_embd_256_n_layer_16_n_head_16.json b/config_n_embd_256_n_layer_16_n_head_16.json
index b49d775..1e34c3d 100644
--- a/config_n_embd_256_n_layer_16_n_head_16.json
+++ b/config_n_embd_256_n_layer_16_n_head_16.json
@@ -1,4 +1,5 @@
 {
+    "model_name": "TimeAwareGPT2",
     "n_layer": 16,
     "n_embd": 256,
     "n_head": 16,
diff --git a/models.py b/models.py
index 070f793..be27593 100644
--- a/models.py
+++ b/models.py
@@ -2,10 +2,81 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from typing import Tuple, Optional
+import math
 
 # =============================================================================
 # 1. Component Modules (Building Blocks)
 # =============================================================================
+
+class CausalConv1d(nn.Module):
+    def __init__(self, channels, kernel_size, groups=1):
+        super().__init__()
+        self.pad = kernel_size - 1
+        self.conv = nn.Conv1d(
+            channels, channels, kernel_size,
+            padding=0, groups=groups
+        )
+
+    def forward(self, x):  # x: (B, C, L)
+        x = F.pad(x, (self.pad, 0))  # pad only on the left to ensure causality
+        return self.conv(x)
+
+class DepthwiseSeparableCausalConvBlock(nn.Module):
+    def __init__(self, d_model, kernel_size=5, dropout=0.1):
+        super().__init__()
+        self.dw = CausalConv1d(d_model, kernel_size, groups=d_model)  # depthwise
+        self.pw = nn.Conv1d(d_model, d_model, 1)  # pointwise
+        self.act = nn.GELU()
+        self.ln = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):  # x: (B, L, D)
+        y = x.transpose(1, 2)  # (B, D, L)
+        y = self.dw(y)         # (B, D, L)
+        y = self.pw(y)         # (B, D, L)
+        y = y.transpose(1, 2)  # (B, L, D)
+        y = self.act(y)
+        y = self.dropout(y)
+        return self.ln(x + y)  # residual connection + layer norm
+
+class TimeFeatureProjector(nn.Module):
+    """
+    Projects scalar time t and its increment Δt into d_model dimensions.
+    Combines log-compressed scalar features with fixed-frequency sin/cos
+    (Fourier time features).
+    """
+    def __init__(self, d_model, fourier_dim=32, dt_clip=1e6):
+        super().__init__()
+        self.dt_clip = dt_clip
+        self.scalar_proj = nn.Linear(2, d_model, bias=False)  # [t_scaled, dt_scaled] -> D
+
+        # Predefine a set of logarithmically spaced frequencies (tune for your time units if needed)
+        k = fourier_dim // 2
+        freqs = torch.logspace(-4, 2, steps=k) * 2 * math.pi  # base frequencies ~1e-4 to 1e2 cycles per time unit
+        self.register_buffer("freqs", freqs, persistent=False)
+
+        self.fourier_proj = nn.Linear(2 * k, d_model, bias=False)  # [sin, cos] -> D
+        self.gate = nn.Parameter(torch.zeros(1))  # learnable gate to smoothly introduce Fourier features
+        self.ln = nn.LayerNorm(d_model)
+
+    def forward(self, t):  # t: (B, L) continuous timestamps/steps
+        # compute increments Δt and stabilize
+        dt = t - F.pad(t, (1, 0), value=0.)[:, :-1]
+        dt = torch.clamp(dt, min=0.)  # ensure non-negative
+        # normalize/stabilize with log compression
+        t_scaled = torch.log1p(torch.clamp(torch.abs(t), max=self.dt_clip))
+        dt_scaled = torch.log1p(torch.clamp(dt, max=self.dt_clip))
+
+        scal = torch.stack([t_scaled, dt_scaled], dim=-1)  # (B, L, 2)
+        scal_feat = self.scalar_proj(scal)  # (B, L, D)
+
+        # Fixed-frequency sin/cos to capture absolute/relative periodicity.
+        # If t is in steps, use it directly; if in seconds, keep units consistent
+        # (e.g., divide by a time constant).
+        wt = t[..., None] * self.freqs  # (B, L, K)
+        sincos = torch.cat([torch.sin(wt), torch.cos(wt)], dim=-1)  # (B, L, 2K)
+        fourier_feat = self.fourier_proj(sincos)  # (B, L, D)
+
+        # gated fusion + layer norm
+        h = scal_feat + torch.tanh(self.gate) * fourier_feat
+        return self.ln(h)  # (B, L, D)
 
 class Block(nn.Module):
     """ an unassuming Transformer block """
@@ -200,6 +271,61 @@ class PiecewiseLinearEncoder(nn.Module):
         encoded = encoded.view(*original_shape, self.num_bins)
         output = self.linear(encoded)
         return output
+
+class TemporalConvEncoder(nn.Module):
+    """
+    Fuses token and time features, then applies stacked causal
+    depthwise-separable convolutions.
+
+    Inputs:
+        x: (B, L) - event/token ids
+        t: (B, L) - timestamps (real-valued) or step indices
+    Output:
+        h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds
+    """
+    def __init__(
+        self,
+        vocab_size: int,
+        d_model: int = 768,
+        n_layers: int = 2,
+        kernel_size: int = 5,
+        dropout: float = 0.1,
+        fourier_dim: int = 32,
+        pad_id: int = 0
+    ):
+        super().__init__()
+        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
+        self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim)
+        self.fuse = nn.Linear(2 * d_model, d_model, bias=False)  # fuse token and time features
+        self.ln_in = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+        blocks = []
+        for _ in range(n_layers):
+            blocks.append(DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout))
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, x, t, attention_mask=None):
+        """
+        attention_mask: (B, L) 1=keep, 0=padding
+        """
+        tok = self.token_emb(x)  # (B, L, D)
+        tim = self.time_proj(t)  # (B, L, D)
+
+        h = torch.cat([tok, tim], dim=-1)  # (B, L, 2D)
+        h = self.fuse(h)  # (B, L, D)
+        h = self.ln_in(h)
+        h = self.dropout(h)
+
+        # Optional: zero out padding positions before the convolutions to avoid leakage
+        if attention_mask is not None:
+            h = h * attention_mask.unsqueeze(-1).type_as(h)
+
+        # Stacked causal temporal convolutions (no look-ahead) build relative, position-aware context
+        for blk in self.blocks:
+            h = blk(h)  # (B, L, D)
+
+        if attention_mask is not None:
+            h = h * attention_mask.unsqueeze(-1).type_as(h)
+
+        return h  # (B, L, D), directly usable as attention layer input
 
 # =============================================================================
 # 2. Main Model Architectures
 # =============================================================================
@@ -338,6 +464,152 @@ class TimeAwareGPT2Learnable(TimeAwareGPT2):
 # 3. Loss Function
 # =============================================================================
 
+class TimeAwareGPT2TemporalConv(nn.Module):
+    """
+    A time-aware GPT-2 variant that uses TemporalConvEncoder to encode
+    event and time sequences before the Transformer attention blocks.
+
+    Inputs:
+        - event_seq: (B, L) token ids (0 treated as padding)
+        - time_seq: (B, L) timestamps or step indices (float)
+
+    Output:
+        - logits: (B, L, vocab_size)
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        n_embd: int,
+        n_layer: int,
+        n_head: int,
+        pdrop: float,
+        token_pdrop: float,
+        ignore_tokens: Optional[list[int]] = None,
+        *,
+        conv_layers: int = 2,
+        kernel_size: int = 5,
+        conv_dropout: float = 0.1,
+        fourier_dim: int = 32,
+        pad_id: int = 0,
+    ):
+        super().__init__()
+        self.token_pdrop = token_pdrop  # stored for parity with TimeAwareGPT2; not applied in this forward
+        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
+        self.n_embd = n_embd
+
+        # Temporal convolutional encoder to build inputs_embeds
+        self.temporal_encoder = TemporalConvEncoder(
+            vocab_size=vocab_size,
+            d_model=n_embd,
+            n_layers=conv_layers,
+            kernel_size=kernel_size,
+            dropout=conv_dropout,
+            fourier_dim=fourier_dim,
+            pad_id=pad_id,
+        )
+
+        # Transformer stack on top of the temporal features
+        self.drop = nn.Dropout(pdrop)
+        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
+        self.ln_f = nn.LayerNorm(n_embd)
+        self.head = nn.Linear(n_embd, vocab_size, bias=False)
+
+    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
+        B, L = event_seq.size()
+
+        # Encoder features as inputs_embeds
+        attention_mask = (event_seq != 0)
+        x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask)
+        x = self.drop(x)
+
+        # Time-aware causal mask, as in the base model
+        t_i = time_seq.unsqueeze(-1)
+        t_j = time_seq.unsqueeze(1)
+        time_mask = (t_j < t_i)
+        padding_mask = (event_seq != 0).unsqueeze(1)
+        combined_mask = time_mask & padding_mask
+
+        # Ensure at least self-attention on non-padding rows
+        is_row_all_zero = ~combined_mask.any(dim=-1)
+        is_not_padding = (event_seq != 0)
+        force_self_attention = is_row_all_zero & is_not_padding
+        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
+
+        for block in self.blocks:
+            x = block(x, custom_mask=combined_mask)
+
+        x = self.ln_f(x)
+        logits = self.head(x)
+        return logits
+
+    def get_num_params(self) -> float:
+        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
+
+    @torch.no_grad()
+    def generate(
+        self,
+        x: torch.Tensor,
+        t: torch.Tensor,
+        max_new_tokens: int = 100,
+        max_age: float = 85 * 365.25,
+        no_repeat: bool = True,
+        termination_tokens: Optional[list[int]] = None,
+        top_k: Optional[int] = None,
+    ):
+        """Autoregressive generation with optional no-repeat filtering and termination tokens."""
+        self.eval()
+
+        if termination_tokens is None:
+            termination_tokens = [1269]
+
+        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
+        mask_time = -10000
+
+        for _ in range(max_new_tokens):
+            logits = self(x, t)
+            logits = logits[:, -1, :]
+
+            if self.ignore_tokens:
+                logits[:, self.ignore_tokens] = -torch.inf
+
+            if no_repeat:
+                fill = x.clone()
+                fill[fill == 1] = 0
+                logits = logits.scatter(1, fill, -torch.inf)
+
+            # Sample a time increment proxy, as in the original implementation
+            t_next_dist = torch.clamp(
+                -torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(),
+                min=0,
+                max=365 * 80,
+            )
+            t_next_val, idx_next = t_next_dist.min(1)
+
+            idx_next = idx_next.unsqueeze(1)
+            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
+
+            x = torch.cat((x, idx_next), dim=1)
+            t = torch.cat((t, age_next), dim=1)
+
+            if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
+                break
+
+        pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
+
+        final_logits = self(x, t)
+        x[pad] = 0
+        t[pad] = mask_time
+
+        if no_repeat:
+            fill = x.clone()
+            fill[fill == 1] = 0
+            final_logits = torch.stack(
+                [final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])]
+            ).transpose(0, 1)
+
+        return x, t, final_logits
+
 class CombinedLoss(nn.Module):
     """
     Computes a two-part loss: a standard cross-entropy loss for event type
diff --git a/train.py b/train.py
index e1f6e17..9fea9ae 100644
--- a/train.py
+++ b/train.py
@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
 import json
 import argparse
 
-from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
+from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv, CombinedLoss
 from utils import PatientEventDataset
 
 # --- Configuration ---
@@ -60,7 +60,7 @@ def main():
     parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
     parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
     parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
-    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.')
+    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable', 'TimeAwareGPT2TemporalConv'], default='TimeAwareGPT2', help='Model architecture to train.')
 
     args = parser.parse_args()
 
@@ -111,6 +111,7 @@ def main():
     model_cls = {
         'TimeAwareGPT2': TimeAwareGPT2,
         'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
+        'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
     }[config.model_name]
 
     model = model_cls(
diff --git a/utils.py b/utils.py
index f63b37d..58ec980 100644
--- a/utils.py
+++ b/utils.py
@@ -4,7 +4,7 @@ import numpy as np
 import random
 from collections import defaultdict
 import json
-from models import TimeAwareGPT2, TimeAwareGPT2Learnable
+from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv
 
 class PatientEventDataset(torch.utils.data.Dataset):
 
@@ -151,6 +151,7 @@ def load_model(config_path: str, device: str = 'cpu'):
     model_cls = {
         'TimeAwareGPT2': TimeAwareGPT2,
         'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
+        'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
     }.get(model_name, TimeAwareGPT2)
 
     # 3) Infer checkpoint filename from config
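
Below the patch, a minimal causality smoke test for the new TemporalConvEncoder (a sketch, not part of the diff; the vocabulary and sequence sizes are arbitrary). Because CausalConv1d pads only on the left, perturbing future positions must leave earlier outputs unchanged once dropout is disabled:

import torch
from models import TemporalConvEncoder

enc = TemporalConvEncoder(vocab_size=100, d_model=32, n_layers=2, kernel_size=5)
enc.eval()  # disable dropout so outputs are deterministic

x = torch.randint(1, 100, (1, 10))          # (B, L) token ids, 0 reserved for padding
t = torch.cumsum(torch.rand(1, 10), dim=1)  # increasing timestamps

with torch.no_grad():
    h_full = enc(x, t)
    x_pert = x.clone()
    x_pert[0, 7:] = 1                       # perturb only the future positions 7..9
    h_pert = enc(x_pert, t)

# positions 0..6 must be unaffected by changes at positions 7..9
assert torch.allclose(h_full[0, :7], h_pert[0, :7])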
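
And a hedged end-to-end sketch of TimeAwareGPT2TemporalConv (also not part of the diff): vocab_size=1270 is an assumption chosen to cover the default termination token 1269, and the dropout values mirror the train.py defaults.

import json
import torch
from models import TimeAwareGPT2TemporalConv

with open("config_n_embd_120_n_layer_12_n_head_12.json") as f:
    cfg = json.load(f)

model = TimeAwareGPT2TemporalConv(
    vocab_size=1270,  # assumption: the real value comes from the dataset vocabulary
    n_embd=cfg["n_embd"],
    n_layer=cfg["n_layer"],
    n_head=cfg["n_head"],
    pdrop=0.1,
    token_pdrop=0.1,
)

event_seq = torch.randint(1, 1270, (2, 16))               # (B, L), 0 is padding
time_seq = torch.cumsum(torch.rand(2, 16) * 30.0, dim=1)  # increasing ages in days

logits = model(event_seq, time_seq)  # (2, 16, 1270)
x_gen, t_gen, final_logits = model.generate(event_seq, time_seq, max_new_tokens=8)
print(logits.shape, x_gen.shape, f"{model.get_num_params():.1f}M params")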