update models and training scripts
This commit is contained in:
126
models.py
126
models.py
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Optional
|
||||
|
||||
# =============================================================================
|
||||
# 1. Component Modules (Building Blocks)
|
||||
@@ -85,6 +85,39 @@ class AgeSinusoidalEncoding(nn.Module):
|
||||
output[:, :, 1::2] = torch.sin(args)
|
||||
return output
|
||||
|
||||
|
||||
class LearnableAgeEncoding(nn.Module):
|
||||
"""Combines fixed sinusoidal age encodings with a learnable MLP projection."""
|
||||
|
||||
def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
self.base_dim = base_dim
|
||||
self.final_dim = final_dim or base_dim
|
||||
|
||||
hidden_dim = hidden_dim or base_dim
|
||||
if hidden_dim <= 0:
|
||||
raise ValueError("hidden_dim must be a positive integer.")
|
||||
if self.final_dim <= 0:
|
||||
raise ValueError("final_dim must be a positive integer.")
|
||||
|
||||
self.sinusoidal = AgeSinusoidalEncoding(base_dim)
|
||||
|
||||
mlp_layers = [
|
||||
nn.Linear(base_dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
]
|
||||
if dropout > 0.0:
|
||||
mlp_layers.append(nn.Dropout(dropout))
|
||||
mlp_layers.append(nn.Linear(hidden_dim, self.final_dim))
|
||||
|
||||
self.mlp = nn.Sequential(*mlp_layers)
|
||||
|
||||
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
||||
sin_embed = self.sinusoidal(t)
|
||||
flat_embed = sin_embed.reshape(-1, self.base_dim)
|
||||
projected = self.mlp(flat_embed)
|
||||
return projected.reshape(*sin_embed.shape[:-1], self.final_dim)
|
||||
|
||||
class PiecewiseLinearEncoder(nn.Module):
|
||||
"""
|
||||
Encodes continuous variables using piecewise linear encoding.
|
||||
@@ -287,94 +320,19 @@ class TimeAwareGPT2(nn.Module):
|
||||
|
||||
return x, t, final_logits
|
||||
|
||||
class CovariateAwareGPT2(nn.Module):
|
||||
"""
|
||||
Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
|
||||
pdrop: float, token_pdrop: float, num_bins: int):
|
||||
"""
|
||||
Initializes the CovariateAwareGPT2 model.
|
||||
class TimeAwareGPT2Learnable(TimeAwareGPT2):
|
||||
"""Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features."""
|
||||
|
||||
Args:
|
||||
vocab_size (int): Size of the event vocabulary.
|
||||
n_embd (int): Embedding dimensionality.
|
||||
n_layer (int): Number of transformer layers.
|
||||
n_head (int): Number of attention heads.
|
||||
pdrop (float): Dropout probability for layers.
|
||||
token_pdrop (float): Dropout probability for input token embeddings.
|
||||
num_bins (int): Number of bins for the PiecewiseLinearEncoder.
|
||||
"""
|
||||
super().__init__()
|
||||
self.token_pdrop = token_pdrop
|
||||
|
||||
self.wte = nn.Embedding(vocab_size, n_embd)
|
||||
self.age_encoder = AgeSinusoidalEncoding(n_embd)
|
||||
self.drop = nn.Dropout(pdrop)
|
||||
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
|
||||
self.n_embd = n_embd
|
||||
self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)
|
||||
|
||||
self.ln_f = nn.LayerNorm(2 * n_embd)
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(2 * n_embd, n_embd),
|
||||
nn.GELU(),
|
||||
nn.Linear(n_embd, vocab_size)
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.age_encoder = LearnableAgeEncoding(
|
||||
base_dim=self.n_embd,
|
||||
hidden_dim=2 * self.n_embd,
|
||||
final_dim=self.n_embd,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the CovariateAwareGPT2 model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Event sequence tensor of shape (B, L).
|
||||
t (torch.Tensor): Time sequence tensor of shape (B, L).
|
||||
cov (torch.Tensor): Covariate tensor of shape (B, N).
|
||||
cov_t (torch.Tensor): Covariate time tensor of shape (B).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Logits of shape (B, L, vocab_size).
|
||||
"""
|
||||
B, L = x.size()
|
||||
|
||||
cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
|
||||
cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1))
|
||||
cov_embed = cov_encoded + cov_t_encoded
|
||||
|
||||
token_embeddings = self.wte(x)
|
||||
if self.training and self.token_pdrop > 0:
|
||||
drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
|
||||
token_embeddings[drop_mask] = 0.0
|
||||
|
||||
pos_embeddings = self.age_encoder(t.float())
|
||||
seq_embed = self.drop(token_embeddings + pos_embeddings)
|
||||
|
||||
t_i = t.unsqueeze(-1)
|
||||
t_j = t.unsqueeze(1)
|
||||
time_mask = (t_j < t_i)
|
||||
padding_mask = (x != 0).unsqueeze(1)
|
||||
combined_mask = time_mask & padding_mask
|
||||
is_row_all_zero = ~combined_mask.any(dim=-1)
|
||||
is_not_padding = (x != 0)
|
||||
force_self_attention = is_row_all_zero & is_not_padding
|
||||
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
|
||||
|
||||
block_output = seq_embed
|
||||
for block in self.blocks:
|
||||
block_output = block(block_output, custom_mask=combined_mask)
|
||||
|
||||
integrated_embed = torch.cat([block_output, cov_embed], dim=-1)
|
||||
|
||||
final_output = self.ln_f(integrated_embed)
|
||||
logits = self.head(final_output)
|
||||
return logits
|
||||
|
||||
def get_num_params(self) -> float:
|
||||
"""
|
||||
Returns the number of trainable parameters in the model in millions.
|
||||
"""
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
|
||||
|
||||
# =============================================================================
|
||||
# 3. Loss Function
|
||||
|
Reference in New Issue
Block a user