update models and training scripts

2025-10-22 08:36:55 +08:00
parent e348086e52
commit bd88daa8c2
2 changed files with 56 additions and 90 deletions

models.py (126 changed lines)

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from typing import Tuple
+from typing import Tuple, Optional
 
 # =============================================================================
 # 1. Component Modules (Building Blocks)
@@ -85,6 +85,39 @@ class AgeSinusoidalEncoding(nn.Module):
         output[:, :, 1::2] = torch.sin(args)
         return output
 
+
+class LearnableAgeEncoding(nn.Module):
+    """Combines fixed sinusoidal age encodings with a learnable MLP projection."""
+
+    def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0):
+        super().__init__()
+        self.base_dim = base_dim
+        self.final_dim = final_dim or base_dim
+        hidden_dim = hidden_dim or base_dim
+        if hidden_dim <= 0:
+            raise ValueError("hidden_dim must be a positive integer.")
+        if self.final_dim <= 0:
+            raise ValueError("final_dim must be a positive integer.")
+        self.sinusoidal = AgeSinusoidalEncoding(base_dim)
+        mlp_layers = [
+            nn.Linear(base_dim, hidden_dim),
+            nn.GELU(),
+        ]
+        if dropout > 0.0:
+            mlp_layers.append(nn.Dropout(dropout))
+        mlp_layers.append(nn.Linear(hidden_dim, self.final_dim))
+        self.mlp = nn.Sequential(*mlp_layers)
+
+    def forward(self, t: torch.Tensor) -> torch.Tensor:
+        sin_embed = self.sinusoidal(t)
+        flat_embed = sin_embed.reshape(-1, self.base_dim)
+        projected = self.mlp(flat_embed)
+        return projected.reshape(*sin_embed.shape[:-1], self.final_dim)
+
+
 class PiecewiseLinearEncoder(nn.Module):
     """
     Encodes continuous variables using piecewise linear encoding.
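For reference, a minimal usage sketch of the new LearnableAgeEncoding module. The (B, L) input shape is inferred from how AgeSinusoidalEncoding is used elsewhere in the file; the dimensions, dropout rate, batch size, and age scale below are hypothetical, and the sketch assumes models.py is importable as models.

    import torch
    from models import LearnableAgeEncoding

    # Hypothetical sizes: 64-dim sinusoidal base projected through a 128-dim hidden layer.
    encoder = LearnableAgeEncoding(base_dim=64, hidden_dim=128, dropout=0.1)

    ages = torch.rand(8, 32) * 80.0   # (B, L) ages on a hypothetical scale
    emb = encoder(ages)               # sinusoidal features passed through the learnable MLP
    print(emb.shape)                  # torch.Size([8, 32, 64]); final_dim defaults to base_dim

Because the MLP's last layer maps to final_dim (base_dim by default), the module produces the same (B, L, n_embd) shape as AgeSinusoidalEncoding and can stand in wherever that positional embedding is expected.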
@@ -287,94 +320,19 @@ class TimeAwareGPT2(nn.Module):
         return x, t, final_logits
 
-class CovariateAwareGPT2(nn.Module):
-    """
-    Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
-    """
-    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
-                 pdrop: float, token_pdrop: float, num_bins: int):
-        """
-        Initializes the CovariateAwareGPT2 model.
+class TimeAwareGPT2Learnable(TimeAwareGPT2):
+    """Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features."""
-        Args:
-            vocab_size (int): Size of the event vocabulary.
-            n_embd (int): Embedding dimensionality.
-            n_layer (int): Number of transformer layers.
-            n_head (int): Number of attention heads.
-            pdrop (float): Dropout probability for layers.
-            token_pdrop (float): Dropout probability for input token embeddings.
-            num_bins (int): Number of bins for the PiecewiseLinearEncoder.
-        """
-        super().__init__()
-        self.token_pdrop = token_pdrop
-        self.wte = nn.Embedding(vocab_size, n_embd)
-        self.age_encoder = AgeSinusoidalEncoding(n_embd)
-        self.drop = nn.Dropout(pdrop)
-        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
-        self.n_embd = n_embd
-        self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)
-        self.ln_f = nn.LayerNorm(2 * n_embd)
-        self.head = nn.Sequential(
-            nn.Linear(2 * n_embd, n_embd),
-            nn.GELU(),
-            nn.Linear(n_embd, vocab_size)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.age_encoder = LearnableAgeEncoding(
+            base_dim=self.n_embd,
+            hidden_dim=2 * self.n_embd,
+            final_dim=self.n_embd,
         )
-    def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for the CovariateAwareGPT2 model.
-        Args:
-            x (torch.Tensor): Event sequence tensor of shape (B, L).
-            t (torch.Tensor): Time sequence tensor of shape (B, L).
-            cov (torch.Tensor): Covariate tensor of shape (B, N).
-            cov_t (torch.Tensor): Covariate time tensor of shape (B).
-        Returns:
-            torch.Tensor: Logits of shape (B, L, vocab_size).
-        """
-        B, L = x.size()
-        cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
-        cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1))
-        cov_embed = cov_encoded + cov_t_encoded
-        token_embeddings = self.wte(x)
-        if self.training and self.token_pdrop > 0:
-            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
-            token_embeddings[drop_mask] = 0.0
-        pos_embeddings = self.age_encoder(t.float())
-        seq_embed = self.drop(token_embeddings + pos_embeddings)
-        t_i = t.unsqueeze(-1)
-        t_j = t.unsqueeze(1)
-        time_mask = (t_j < t_i)
-        padding_mask = (x != 0).unsqueeze(1)
-        combined_mask = time_mask & padding_mask
-        is_row_all_zero = ~combined_mask.any(dim=-1)
-        is_not_padding = (x != 0)
-        force_self_attention = is_row_all_zero & is_not_padding
-        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
-        block_output = seq_embed
-        for block in self.blocks:
-            block_output = block(block_output, custom_mask=combined_mask)
-        integrated_embed = torch.cat([block_output, cov_embed], dim=-1)
-        final_output = self.ln_f(integrated_embed)
-        logits = self.head(final_output)
-        return logits
-    def get_num_params(self) -> float:
-        """
-        Returns the number of trainable parameters in the model in millions.
-        """
-        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
 
 # =============================================================================
 # 3. Loss Function
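Since TimeAwareGPT2Learnable only swaps out self.age_encoder after calling the parent constructor, every other submodule of TimeAwareGPT2 is left untouched, so the variant should act as a drop-in replacement in the training scripts. A minimal sketch of the override in action, assuming TimeAwareGPT2's constructor accepts the hyperparameters shown below (the exact signature lives in models.py, and the values are hypothetical):

    import torch
    from models import TimeAwareGPT2Learnable

    # Hypothetical hyperparameters; the real ones come from the training scripts.
    model = TimeAwareGPT2Learnable(vocab_size=1000, n_embd=128, n_layer=4,
                                   n_head=4, pdrop=0.1, token_pdrop=0.1)

    ages = torch.rand(4, 48) * 80.0   # (B, L) ages on a hypothetical scale
    pos = model.age_encoder(ages)     # now a LearnableAgeEncoding, shape (4, 48, 128)
    print(type(model.age_encoder).__name__, pos.shape)

Because LearnableAgeEncoding maps a (B, L) age tensor to (B, L, n_embd) just as AgeSinusoidalEncoding does, the override changes only how temporal positions are embedded, not the rest of the architecture.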