update models and training scripts
models.py (126 lines changed)
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from typing import Tuple
+from typing import Tuple, Optional
 
 # =============================================================================
 # 1. Component Modules (Building Blocks)
@@ -85,6 +85,39 @@ class AgeSinusoidalEncoding(nn.Module):
         output[:, :, 1::2] = torch.sin(args)
         return output
 
 
+class LearnableAgeEncoding(nn.Module):
+    """Combines fixed sinusoidal age encodings with a learnable MLP projection."""
+
+    def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0):
+        super().__init__()
+        self.base_dim = base_dim
+        self.final_dim = final_dim or base_dim
+
+        hidden_dim = hidden_dim or base_dim
+        if hidden_dim <= 0:
+            raise ValueError("hidden_dim must be a positive integer.")
+        if self.final_dim <= 0:
+            raise ValueError("final_dim must be a positive integer.")
+
+        self.sinusoidal = AgeSinusoidalEncoding(base_dim)
+
+        mlp_layers = [
+            nn.Linear(base_dim, hidden_dim),
+            nn.GELU(),
+        ]
+        if dropout > 0.0:
+            mlp_layers.append(nn.Dropout(dropout))
+        mlp_layers.append(nn.Linear(hidden_dim, self.final_dim))
+
+        self.mlp = nn.Sequential(*mlp_layers)
+
+    def forward(self, t: torch.Tensor) -> torch.Tensor:
+        sin_embed = self.sinusoidal(t)
+        flat_embed = sin_embed.reshape(-1, self.base_dim)
+        projected = self.mlp(flat_embed)
+        return projected.reshape(*sin_embed.shape[:-1], self.final_dim)
+
+
 class PiecewiseLinearEncoder(nn.Module):
     """
     Encodes continuous variables using piecewise linear encoding.
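A minimal usage sketch of the new encoder. The tensor sizes and dropout value are illustrative, and it assumes AgeSinusoidalEncoding maps a (batch, sequence) age tensor to (batch, sequence, base_dim), as its use elsewhere in models.py suggests:

    import torch
    from models import LearnableAgeEncoding

    # Hypothetical dimensions; ages have shape (batch, sequence) and the
    # encoder returns (batch, sequence, final_dim).
    encoder = LearnableAgeEncoding(base_dim=64, hidden_dim=128, final_dim=64, dropout=0.1)
    ages = 100.0 * torch.rand(2, 16)
    embeddings = encoder(ages)
    assert embeddings.shape == (2, 16, 64)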
@@ -287,94 +320,19 @@ class TimeAwareGPT2(nn.Module):
 
         return x, t, final_logits
 
 
-class CovariateAwareGPT2(nn.Module):
-    """
-    Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
-    """
-
-    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
-                 pdrop: float, token_pdrop: float, num_bins: int):
-        """
-        Initializes the CovariateAwareGPT2 model.
-
-        Args:
-            vocab_size (int): Size of the event vocabulary.
-            n_embd (int): Embedding dimensionality.
-            n_layer (int): Number of transformer layers.
-            n_head (int): Number of attention heads.
-            pdrop (float): Dropout probability for layers.
-            token_pdrop (float): Dropout probability for input token embeddings.
-            num_bins (int): Number of bins for the PiecewiseLinearEncoder.
-        """
-        super().__init__()
-        self.token_pdrop = token_pdrop
-
-        self.wte = nn.Embedding(vocab_size, n_embd)
-        self.age_encoder = AgeSinusoidalEncoding(n_embd)
-        self.drop = nn.Dropout(pdrop)
-        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
-        self.n_embd = n_embd
-        self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)
-
-        self.ln_f = nn.LayerNorm(2 * n_embd)
-        self.head = nn.Sequential(
-            nn.Linear(2 * n_embd, n_embd),
-            nn.GELU(),
-            nn.Linear(n_embd, vocab_size)
-        )
-
-    def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for the CovariateAwareGPT2 model.
-
-        Args:
-            x (torch.Tensor): Event sequence tensor of shape (B, L).
-            t (torch.Tensor): Time sequence tensor of shape (B, L).
-            cov (torch.Tensor): Covariate tensor of shape (B, N).
-            cov_t (torch.Tensor): Covariate time tensor of shape (B).
-
-        Returns:
-            torch.Tensor: Logits of shape (B, L, vocab_size).
-        """
-        B, L = x.size()
-
-        cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
-        cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1))
-        cov_embed = cov_encoded + cov_t_encoded
-
-        token_embeddings = self.wte(x)
-        if self.training and self.token_pdrop > 0:
-            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
-            token_embeddings[drop_mask] = 0.0
-
-        pos_embeddings = self.age_encoder(t.float())
-        seq_embed = self.drop(token_embeddings + pos_embeddings)
-
-        t_i = t.unsqueeze(-1)
-        t_j = t.unsqueeze(1)
-        time_mask = (t_j < t_i)
-        padding_mask = (x != 0).unsqueeze(1)
-        combined_mask = time_mask & padding_mask
-        is_row_all_zero = ~combined_mask.any(dim=-1)
-        is_not_padding = (x != 0)
-        force_self_attention = is_row_all_zero & is_not_padding
-        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
-
-        block_output = seq_embed
-        for block in self.blocks:
-            block_output = block(block_output, custom_mask=combined_mask)
-
-        integrated_embed = torch.cat([block_output, cov_embed], dim=-1)
-
-        final_output = self.ln_f(integrated_embed)
-        logits = self.head(final_output)
-        return logits
-
-    def get_num_params(self) -> float:
-        """
-        Returns the number of trainable parameters in the model in millions.
-        """
-        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
+class TimeAwareGPT2Learnable(TimeAwareGPT2):
+    """Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.age_encoder = LearnableAgeEncoding(
+            base_dim=self.n_embd,
+            hidden_dim=2 * self.n_embd,
+            final_dim=self.n_embd,
+        )
 
 
 # =============================================================================
 # 3. Loss Function
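For context, a sketch of constructing the new variant directly. The keyword arguments mirror what train.py passes; the exact TimeAwareGPT2 constructor signature and the sizes used here are assumptions:

    from models import TimeAwareGPT2Learnable

    # Hypothetical hyperparameters. The subclass only swaps age_encoder after the
    # parent __init__, so everything else behaves like TimeAwareGPT2.
    model = TimeAwareGPT2Learnable(
        vocab_size=1270, n_embd=120, n_layer=12, n_head=12,
        pdrop=0.1, token_pdrop=0.1,
    )
    print(type(model.age_encoder).__name__)  # LearnableAgeEncoding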
train.py (20 lines changed)
@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
 import json
 import argparse
 
-from models import TimeAwareGPT2, CombinedLoss
+from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
 from utils import PatientEventDataset
 
 # --- Configuration ---
@@ -25,6 +25,7 @@ class TrainConfig:
     n_head = 12
     pdrop = 0.1
     token_pdrop = 0.1
+    model_name = 'TimeAwareGPT2'
 
     # Training parameters
     max_epoch = 200
@@ -59,6 +60,7 @@ def main():
     parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
     parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
     parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
+    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.')
 
     args = parser.parse_args()
 
@@ -76,10 +78,11 @@ def main():
     config.pdrop = args.pdrop
     config.token_pdrop = args.token_pdrop
     config.betas = tuple(args.betas)
+    config.model_name = args.model
 
-    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
-    checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
+    model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
+    model_filename = f"best_model_{model_suffix}.pt"
+    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
 
     # --- 0. Save Configuration ---
     config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
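A quick worked example of the new naming scheme, with made-up hyperparameter values:

    # Assuming --model TimeAwareGPT2Learnable, n_embd=120, n_layer=12, n_head=12:
    model_filename      = "best_model_TimeAwareGPT2Learnable_n_embd_120_n_layer_12_n_head_12.pt"
    checkpoint_filename = "best_model_checkpoint_TimeAwareGPT2Learnable_n_embd_120_n_layer_12_n_head_12.pt"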
@@ -105,7 +108,12 @@ def main():
 
     # --- 2. Model, Optimizer, and Loss Initialization ---
     print(f"Initializing model on {config.device}...")
-    model = TimeAwareGPT2(
+    model_cls = {
+        'TimeAwareGPT2': TimeAwareGPT2,
+        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
+    }[config.model_name]
+
+    model = model_cls(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
         n_layer=config.n_layer,
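A usage example for selecting the new architecture from the command line; only flags defined above are shown, and all other options keep their defaults:

    python train.py --model TimeAwareGPT2Learnable

With the default (--model TimeAwareGPT2) the script selects the original architecture.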
@@ -235,7 +243,7 @@ def main():
         print("\nTraining finished. No best model to save as validation loss never improved.")
 
     # --- Save losses to a txt file ---
-    losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
+    losses_filename = f"losses_{model_suffix}.txt"
     with open(losses_filename, 'w') as f:
         f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
         for i in range(len(train_losses_total)):