diff --git a/models.py b/models.py
index 2bd26bb..070f793 100644
--- a/models.py
+++ b/models.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from typing import Tuple
+from typing import Tuple, Optional

 # =============================================================================
 # 1. Component Modules (Building Blocks)
@@ -85,6 +85,39 @@ class AgeSinusoidalEncoding(nn.Module):
         output[:, :, 1::2] = torch.sin(args)
         return output

+
+class LearnableAgeEncoding(nn.Module):
+    """Combines fixed sinusoidal age encodings with a learnable MLP projection."""
+
+    def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0):
+        super().__init__()
+        self.base_dim = base_dim
+        self.final_dim = final_dim or base_dim
+
+        hidden_dim = hidden_dim or base_dim
+        if hidden_dim <= 0:
+            raise ValueError("hidden_dim must be a positive integer.")
+        if self.final_dim <= 0:
+            raise ValueError("final_dim must be a positive integer.")
+
+        self.sinusoidal = AgeSinusoidalEncoding(base_dim)
+
+        mlp_layers = [
+            nn.Linear(base_dim, hidden_dim),
+            nn.GELU(),
+        ]
+        if dropout > 0.0:
+            mlp_layers.append(nn.Dropout(dropout))
+        mlp_layers.append(nn.Linear(hidden_dim, self.final_dim))
+
+        self.mlp = nn.Sequential(*mlp_layers)
+
+    def forward(self, t: torch.Tensor) -> torch.Tensor:
+        sin_embed = self.sinusoidal(t)
+        flat_embed = sin_embed.reshape(-1, self.base_dim)
+        projected = self.mlp(flat_embed)
+        return projected.reshape(*sin_embed.shape[:-1], self.final_dim)
+
 class PiecewiseLinearEncoder(nn.Module):
     """
     Encodes continuous variables using piecewise linear encoding.
@@ -287,94 +320,19 @@ class TimeAwareGPT2(nn.Module):
         return x, t, final_logits


-class CovariateAwareGPT2(nn.Module):
-    """
-    Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
-    """
-    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
-                 pdrop: float, token_pdrop: float, num_bins: int):
-        """
-        Initializes the CovariateAwareGPT2 model.
+class TimeAwareGPT2Learnable(TimeAwareGPT2):
+    """Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features."""

-        Args:
-            vocab_size (int): Size of the event vocabulary.
-            n_embd (int): Embedding dimensionality.
-            n_layer (int): Number of transformer layers.
-            n_head (int): Number of attention heads.
-            pdrop (float): Dropout probability for layers.
-            token_pdrop (float): Dropout probability for input token embeddings.
-            num_bins (int): Number of bins for the PiecewiseLinearEncoder.
-        """
-        super().__init__()
-        self.token_pdrop = token_pdrop
-
-        self.wte = nn.Embedding(vocab_size, n_embd)
-        self.age_encoder = AgeSinusoidalEncoding(n_embd)
-        self.drop = nn.Dropout(pdrop)
-        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
-        self.n_embd = n_embd
-        self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)
-
-        self.ln_f = nn.LayerNorm(2 * n_embd)
-        self.head = nn.Sequential(
-            nn.Linear(2 * n_embd, n_embd),
-            nn.GELU(),
-            nn.Linear(n_embd, vocab_size)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.age_encoder = LearnableAgeEncoding(
+            base_dim=self.n_embd,
+            hidden_dim=2 * self.n_embd,
+            final_dim=self.n_embd,
         )

-    def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for the CovariateAwareGPT2 model.
-        Args:
-            x (torch.Tensor): Event sequence tensor of shape (B, L).
-            t (torch.Tensor): Time sequence tensor of shape (B, L).
-            cov (torch.Tensor): Covariate tensor of shape (B, N).
-            cov_t (torch.Tensor): Covariate time tensor of shape (B).
-
-        Returns:
-            torch.Tensor: Logits of shape (B, L, vocab_size).
-        """
-        B, L = x.size()
-
-        cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
-        cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1))
-        cov_embed = cov_encoded + cov_t_encoded
-
-        token_embeddings = self.wte(x)
-        if self.training and self.token_pdrop > 0:
-            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
-            token_embeddings[drop_mask] = 0.0
-
-        pos_embeddings = self.age_encoder(t.float())
-        seq_embed = self.drop(token_embeddings + pos_embeddings)
-
-        t_i = t.unsqueeze(-1)
-        t_j = t.unsqueeze(1)
-        time_mask = (t_j < t_i)
-        padding_mask = (x != 0).unsqueeze(1)
-        combined_mask = time_mask & padding_mask
-        is_row_all_zero = ~combined_mask.any(dim=-1)
-        is_not_padding = (x != 0)
-        force_self_attention = is_row_all_zero & is_not_padding
-        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
-
-        block_output = seq_embed
-        for block in self.blocks:
-            block_output = block(block_output, custom_mask=combined_mask)
-
-        integrated_embed = torch.cat([block_output, cov_embed], dim=-1)
-
-        final_output = self.ln_f(integrated_embed)
-        logits = self.head(final_output)
-        return logits
-
-    def get_num_params(self) -> float:
-        """
-        Returns the number of trainable parameters in the model in millions.
-        """
-        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6


 # =============================================================================
 # 3. Loss Function

diff --git a/train.py b/train.py
index 1bcc023..e1f6e17 100644
--- a/train.py
+++ b/train.py
@@ -9,7 +9,7 @@
 import matplotlib.pyplot as plt
 import json
 import argparse
-from models import TimeAwareGPT2, CombinedLoss
+from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
 from utils import PatientEventDataset

 # --- Configuration ---
@@ -25,6 +25,7 @@ class TrainConfig:
     n_head = 12
     pdrop = 0.1
     token_pdrop = 0.1
+    model_name = 'TimeAwareGPT2'

     # Training parameters
     max_epoch = 200
@@ -59,6 +60,7 @@ def main():
     parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
     parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
     parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
+    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.')

     args = parser.parse_args()

@@ -76,10 +78,11 @@ def main():
     config.pdrop = args.pdrop
     config.token_pdrop = args.token_pdrop
     config.betas = tuple(args.betas)
+    config.model_name = args.model

-
-    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
-    checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
+    model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
+    model_filename = f"best_model_{model_suffix}.pt"
+    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"

     # --- 0. Save Configuration ---
     config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
@@ -105,7 +108,12 @@ def main():

     # --- 2. Model, Optimizer, and Loss Initialization ---
     print(f"Initializing model on {config.device}...")
-    model = TimeAwareGPT2(
+    model_cls = {
+        'TimeAwareGPT2': TimeAwareGPT2,
+        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
+    }[config.model_name]
+
+    model = model_cls(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
         n_layer=config.n_layer,
@@ -235,7 +243,7 @@ def main():
         print("\nTraining finished. No best model to save as validation loss never improved.")

     # --- Save losses to a txt file ---
-    losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
+    losses_filename = f"losses_{model_suffix}.txt"
     with open(losses_filename, 'w') as f:
         f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
         for i in range(len(train_losses_total)):
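Reviewer note, not part of the patch: a minimal sketch to sanity-check the shape contract of the new encoder. It assumes, as the sinusoidal code above suggests, that AgeSinusoidalEncoding maps a (B, L) tensor of ages to (B, L, base_dim); the dimensions and age scale below are placeholders, not the training defaults.

import torch
from models import LearnableAgeEncoding

enc = LearnableAgeEncoding(base_dim=64, hidden_dim=128, final_dim=64, dropout=0.1)
ages = torch.rand(4, 32) * 85.0   # hypothetical (B, L) ages in years
out = enc(ages)                   # fixed sinusoidal features -> learnable MLP projection
print(out.shape)                  # expected: torch.Size([4, 32, 64])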
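A similarly hedged smoke test for the subclass swap: TimeAwareGPT2's full constructor signature is not shown in this diff, so the keyword arguments below mirror what train.py passes and are placeholders. The only behavior checked is that the subclass replaces age_encoder with the learnable module; on the command line the variant is selected with the new flag, e.g. python train.py --model TimeAwareGPT2Learnable.

from models import TimeAwareGPT2Learnable, LearnableAgeEncoding

# Assumed signature and placeholder hyperparameters; adjust to the real __init__.
model = TimeAwareGPT2Learnable(vocab_size=100, n_embd=64, n_layer=2,
                               n_head=4, pdrop=0.1, token_pdrop=0.1)
assert isinstance(model.age_encoder, LearnableAgeEncoding)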