update models and training scripts

2025-10-22 08:36:55 +08:00
parent e348086e52
commit bd88daa8c2
2 changed files with 56 additions and 90 deletions

models.py

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from typing import Tuple
+from typing import Tuple, Optional
 # =============================================================================
 # 1. Component Modules (Building Blocks)
@@ -85,6 +85,39 @@ class AgeSinusoidalEncoding(nn.Module):
         output[:, :, 1::2] = torch.sin(args)
         return output
+
+class LearnableAgeEncoding(nn.Module):
+    """Combines fixed sinusoidal age encodings with a learnable MLP projection."""
+    def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0):
+        super().__init__()
+        self.base_dim = base_dim
+        self.final_dim = final_dim or base_dim
+        hidden_dim = hidden_dim or base_dim
+        if hidden_dim <= 0:
+            raise ValueError("hidden_dim must be a positive integer.")
+        if self.final_dim <= 0:
+            raise ValueError("final_dim must be a positive integer.")
+        self.sinusoidal = AgeSinusoidalEncoding(base_dim)
+        mlp_layers = [
+            nn.Linear(base_dim, hidden_dim),
+            nn.GELU(),
+        ]
+        if dropout > 0.0:
+            mlp_layers.append(nn.Dropout(dropout))
+        mlp_layers.append(nn.Linear(hidden_dim, self.final_dim))
+        self.mlp = nn.Sequential(*mlp_layers)
+
+    def forward(self, t: torch.Tensor) -> torch.Tensor:
+        sin_embed = self.sinusoidal(t)
+        flat_embed = sin_embed.reshape(-1, self.base_dim)
+        projected = self.mlp(flat_embed)
+        return projected.reshape(*sin_embed.shape[:-1], self.final_dim)
+
 class PiecewiseLinearEncoder(nn.Module):
     """
     Encodes continuous variables using piecewise linear encoding.
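A minimal shape check of the new encoder, assuming ages are float tensors of shape (B, L) and that AgeSinusoidalEncoding maps them to (B, L, base_dim), as its indexing above suggests; all dimensions here are placeholders:

import torch
from models import LearnableAgeEncoding

enc = LearnableAgeEncoding(base_dim=128, hidden_dim=256, dropout=0.1)
t = torch.rand(2, 5) * 80.0  # ages in years, shape (B=2, L=5)
out = enc(t)                 # fixed sinusoidal basis, then MLP projection
print(out.shape)             # torch.Size([2, 5, 128]); final_dim defaults to base_dim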
@@ -287,94 +320,19 @@ class TimeAwareGPT2(nn.Module):
         return x, t, final_logits
-class CovariateAwareGPT2(nn.Module):
-    """
-    Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
-    """
-    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
-                 pdrop: float, token_pdrop: float, num_bins: int):
-        """
-        Initializes the CovariateAwareGPT2 model.
-        Args:
-            vocab_size (int): Size of the event vocabulary.
-            n_embd (int): Embedding dimensionality.
-            n_layer (int): Number of transformer layers.
-            n_head (int): Number of attention heads.
-            pdrop (float): Dropout probability for layers.
-            token_pdrop (float): Dropout probability for input token embeddings.
-            num_bins (int): Number of bins for the PiecewiseLinearEncoder.
-        """
-        super().__init__()
-        self.token_pdrop = token_pdrop
-        self.wte = nn.Embedding(vocab_size, n_embd)
-        self.age_encoder = AgeSinusoidalEncoding(n_embd)
-        self.drop = nn.Dropout(pdrop)
-        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
-        self.n_embd = n_embd
-        self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)
-        self.ln_f = nn.LayerNorm(2 * n_embd)
-        self.head = nn.Sequential(
-            nn.Linear(2 * n_embd, n_embd),
-            nn.GELU(),
-            nn.Linear(n_embd, vocab_size)
-        )
-    def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for the CovariateAwareGPT2 model.
-        Args:
-            x (torch.Tensor): Event sequence tensor of shape (B, L).
-            t (torch.Tensor): Time sequence tensor of shape (B, L).
-            cov (torch.Tensor): Covariate tensor of shape (B, N).
-            cov_t (torch.Tensor): Covariate time tensor of shape (B).
-        Returns:
-            torch.Tensor: Logits of shape (B, L, vocab_size).
-        """
-        B, L = x.size()
-        cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
-        cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1))
-        cov_embed = cov_encoded + cov_t_encoded
-        token_embeddings = self.wte(x)
-        if self.training and self.token_pdrop > 0:
-            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
-            token_embeddings[drop_mask] = 0.0
-        pos_embeddings = self.age_encoder(t.float())
-        seq_embed = self.drop(token_embeddings + pos_embeddings)
-        t_i = t.unsqueeze(-1)
-        t_j = t.unsqueeze(1)
-        time_mask = (t_j < t_i)
-        padding_mask = (x != 0).unsqueeze(1)
-        combined_mask = time_mask & padding_mask
-        is_row_all_zero = ~combined_mask.any(dim=-1)
-        is_not_padding = (x != 0)
-        force_self_attention = is_row_all_zero & is_not_padding
-        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
-        block_output = seq_embed
-        for block in self.blocks:
-            block_output = block(block_output, custom_mask=combined_mask)
-        integrated_embed = torch.cat([block_output, cov_embed], dim=-1)
-        final_output = self.ln_f(integrated_embed)
-        logits = self.head(final_output)
-        return logits
-    def get_num_params(self) -> float:
-        """
-        Returns the number of trainable parameters in the model in millions.
-        """
-        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
+
+class TimeAwareGPT2Learnable(TimeAwareGPT2):
+    """Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.age_encoder = LearnableAgeEncoding(
+            base_dim=self.n_embd,
+            hidden_dim=2 * self.n_embd,
+            final_dim=self.n_embd,
+        )
 # =============================================================================
 # 3. Loss Function
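Because the subclass pins final_dim to n_embd, the learnable encoder is a drop-in replacement for the sinusoidal one and forward() is inherited unchanged. A hedged construction sketch, assuming TimeAwareGPT2.__init__ accepts the keyword arguments the training script passes below and sets self.n_embd before the subclass swaps the encoder; the values are placeholders:

from models import TimeAwareGPT2Learnable

model = TimeAwareGPT2Learnable(
    vocab_size=1000,  # placeholder vocabulary size
    n_embd=128,       # placeholder embedding width
    n_layer=12,
    n_head=12,
    pdrop=0.1,
    token_pdrop=0.1,
)
print(type(model.age_encoder).__name__)  # LearnableAgeEncoding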


@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
 import json
 import argparse
-from models import TimeAwareGPT2, CombinedLoss
+from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
 from utils import PatientEventDataset
 # --- Configuration ---
@@ -25,6 +25,7 @@ class TrainConfig:
     n_head = 12
     pdrop = 0.1
     token_pdrop = 0.1
+    model_name = 'TimeAwareGPT2'
     # Training parameters
     max_epoch = 200
@@ -59,6 +60,7 @@ def main():
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.') parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.') parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.') parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.')
args = parser.parse_args() args = parser.parse_args()
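A quick sanity check of the new flag in isolation (the training script's filename is not shown in this view, so the parser is rebuilt here with just this argument):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str,
                    choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'],
                    default='TimeAwareGPT2', help='Model architecture to train.')
args = parser.parse_args(['--model', 'TimeAwareGPT2Learnable'])
print(args.model)  # TimeAwareGPT2Learnable
# parser.parse_args(['--model', 'Other']) would exit with an 'invalid choice' error.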
@@ -76,10 +78,11 @@ def main():
     config.pdrop = args.pdrop
     config.token_pdrop = args.token_pdrop
     config.betas = tuple(args.betas)
+    config.model_name = args.model
+    model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
-    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
-    checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
+    model_filename = f"best_model_{model_suffix}.pt"
+    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
     # --- 0. Save Configuration ---
     config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
@@ -105,7 +108,12 @@ def main():
     # --- 2. Model, Optimizer, and Loss Initialization ---
     print(f"Initializing model on {config.device}...")
-    model = TimeAwareGPT2(
+    model_cls = {
+        'TimeAwareGPT2': TimeAwareGPT2,
+        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
+    }[config.model_name]
+    model = model_cls(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
         n_layer=config.n_layer,
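The dict dispatch keeps the CLI choices and the constructable classes in one visible mapping; since argparse's choices already rejects unknown names, a bare KeyError here can only mean the two lists drifted apart. A hedged sketch of the same pattern with an explicit error (model_registry and resolve_model are illustrative names, not from the diff):

from models import TimeAwareGPT2, TimeAwareGPT2Learnable

model_registry = {
    'TimeAwareGPT2': TimeAwareGPT2,
    'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
}

def resolve_model(name: str):
    # Fail loudly if the CLI choices and this registry ever drift apart.
    if name not in model_registry:
        raise ValueError(f"Unknown model '{name}'; expected one of {sorted(model_registry)}")
    return model_registry[name]

print(resolve_model('TimeAwareGPT2Learnable').__name__)  # TimeAwareGPT2Learnable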
@@ -235,7 +243,7 @@ def main():
         print("\nTraining finished. No best model to save as validation loss never improved.")
     # --- Save losses to a txt file ---
-    losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
+    losses_filename = f"losses_{model_suffix}.txt"
     with open(losses_filename, 'w') as f:
         f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
         for i in range(len(train_losses_total)):
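The losses file is plain comma-separated text with a header row, so it loads directly into pandas; a hedged reading sketch, with the filename spelled out for placeholder dimensions (pandas assumed available):

import pandas as pd

losses = pd.read_csv("losses_TimeAwareGPT2Learnable_n_embd_128_n_layer_12_n_head_12.txt")
print(losses[['epoch', 'train_loss_total', 'val_loss_total']].tail())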