feat: Implement time-aware GPT-2 for patient event prediction

This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data.

Key components include:

- `models.py`:
  - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings.
  - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing.

- `utils.py`:
  - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation.

- `train.py`:
  - A comprehensive training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a constant warmup followed by cosine annealing of the learning rate, per-epoch validation, and early stopping based on validation loss.

- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.

- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
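
A minimal usage sketch of how the pieces fit together (illustrative only; hyperparameters mirror the `train.py` defaults, and `train.py` remains the authoritative entry point):

```python
import numpy as np
from torch.utils.data import DataLoader

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# Load the preprocessed (patient_id, day, label) triples written by prepare_data.py.
data = np.memmap('ukb_real_train.bin', dtype=np.uint32, mode='r').reshape(-1, 3)
dataset = PatientEventDataset(data, block_length=256)
loader = DataLoader(dataset, batch_size=128, shuffle=True)

model = TimeAwareGPT2(vocab_size=int(data[:, 2].max()) + 1, n_embd=256,
                      n_layer=8, n_head=8, pdrop=0.1, token_pdrop=0.1)
loss_fn = CombinedLoss(ignored_token_ids=[0, 1])

# One step: predict the next event and its waiting time from a shifted sequence.
event_seq, time_seq = next(iter(loader))
logits = model(event_seq[:, :-1], time_seq[:, :-1])
loss_ce, loss_surv = loss_fn(logits, event_seq[:, 1:],
                             (time_seq[:, 1:] - time_seq[:, :-1]).float())
```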
Commit 589d4d0bd2 (parent 1d4731ae42), 2025-10-16 14:21:36 +08:00
5 changed files with 750 additions and 0 deletions

GEMINI.md (new file)

@@ -0,0 +1,59 @@
# DeepHealth Project
This is a deep learning project based on PyTorch. It adheres to specific code style and file structure conventions to ensure clarity, maintainability, and reproducibility.
## 1. Project Structure
To maintain a clean and modular project, we adopt the following file organization:
DeepHealth/
|-train.py
|-models.py
|-utils.py
|-data/
|-requirements.txt
|-README.md
### File Descriptions
* **`train.py`**:
* **Core training script**. It contains the control flow for the entire training process.
* Responsible for initializing the model, optimizer, DataLoader, etc.
* Executes the training and validation loops.
* Handles saving and loading checkpoints, logging, and other related tasks.
* **`models.py`**:
* **Model and Loss Function Definitions**. This file stores the architecture for all neural network models.
* All subclasses of `torch.nn.Module` should be defined in this file.
* Custom loss functions should also be implemented here.
* **`utils.py`**:
* **Utility Functions Module**. It contains reusable helper functions for the project.
* Primarily responsible for data I/O, data preprocessing, performance metric calculations, logger configuration, and other logic that doesn't belong in the core model or training framework.
* **`data/`**:
* **Data Storage Directory**. Used to store the datasets required for the project.
* `data/raw/` stores the original, unprocessed data.
* `data/processed/` stores data after it has been preprocessed.
* **`requirements.txt`**:
* **Project Dependencies**. Lists all the Python packages and their versions required to run this project.
* **`README.md`**:
* **Project Documentation**. Provides a high-level overview of the project, setup instructions, and usage guidelines.
## 2. Core Framework
* **Deep Learning Framework**: `PyTorch`
## 3. Coding Style
This project uniformly adopts the **Google Python Style Guide**. All submitted code should adhere to this standard to ensure consistency and readability.
Key features include:
* Using `yapf` or `black` for automatic code formatting.
* Following detailed naming conventions (`module_name`, `package_name`, `ClassName`, `method_name`, `ExceptionName`, `function_name`, `GLOBAL_CONSTANT_NAME`).
* Using Google-style docstrings.
Please refer to the official documentation: [Google Python Style Guide](http://google.github.io/styleguide/pyguide.html)
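For illustration, a hypothetical helper (not part of the codebase) following these conventions:

```python
MAX_SEQ_LEN = 256  # GLOBAL_CONSTANT_NAME


def compute_mean_loss(losses: list[float]) -> float:
    """Computes the mean of a list of per-batch losses.

    Args:
        losses: Per-batch loss values.

    Returns:
        The arithmetic mean, or 0.0 for an empty list.
    """
    return sum(losses) / len(losses) if losses else 0.0
```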

models.py (new file)

@@ -0,0 +1,284 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
import math


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(pdrop)
        self.resid_dropout = nn.Dropout(pdrop)
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        B, L, D = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)
        q = q.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)
        v = v.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)
        # causal self-attention; Self-attend: (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Apply the time-based causal mask
        att = att.masked_fill(custom_mask.unsqueeze(1) == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # Rows with no valid keys (e.g. a sequence's earliest event) are all
        # -inf before the softmax and come out as NaN; zero them so such
        # positions simply receive no attention context.
        att = torch.nan_to_num(att)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs)
        y = y.transpose(1, 2).contiguous().view(B, L, D)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, pdrop)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))  # MLP forward

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x), custom_mask=custom_mask)
        x = x + self.mlpf(self.ln_2(x))
        return x


class AgeSinusoidalEncoding(nn.Module):
    """
    Encodes age using sinusoidal functions, similar to positional encodings
    in Transformers. This module creates a fixed-size embedding for an age
    value given in days.
    """

    def __init__(self, embedding_dim: int):
        """
        Initializes the AgeSinusoidalEncoding module.

        Args:
            embedding_dim (int): The dimensionality of the output embedding.
                Must be an even number.

        Raises:
            ValueError: If embedding_dim is not an even number.
        """
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(f"Embedding dimension must be an even number, but got {embedding_dim}")
        self.embedding_dim = embedding_dim
        # Pre-calculate the divisor term for the sinusoidal formula.
        # The formula for the divisor is 10000^(2i/D), where D is the
        # embedding_dim and i is the index for each pair of dimensions.
        # i ranges from 0 to D/2 - 1.
        i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.embedding_dim)
        # Register the divisor as a non-trainable buffer. This ensures it is
        # moved to the correct device (e.g., GPU) along with the model.
        self.register_buffer('divisor', divisor)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the AgeSinusoidalEncoding.

        Args:
            t (torch.Tensor): A tensor of shape (batch_size, sequence_length)
                with dtype=torch.float32, representing age in days.

        Returns:
            torch.Tensor: The encoded age tensor of shape
                (batch_size, sequence_length, embedding_dim).
        """
        # 1. Unit Conversion: Convert age from days to years.
        # We use 365.25 to account for leap years.
        t_years = t / 365.25
        # 2. Argument Calculation: Divide the age by the divisor term to get
        # the arguments for the sin/cos functions.
        # The shapes are broadcast to (B, L, D/2).
        # Input t_years: (B, L) -> unsqueezed to (B, L, 1)
        # Divisor: (D/2) -> viewed as (1, 1, D/2)
        args = t_years.unsqueeze(-1) / self.divisor.view(1, 1, -1)
        # 3. Sinusoidal Application: Create the final output tensor.
        # Initialize an empty tensor to store the embeddings.
        output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)
        # Assign cosine of the arguments to the even indices.
        output[:, :, 0::2] = torch.cos(args)
        # Assign sine of the arguments to the odd indices.
        output[:, :, 1::2] = torch.sin(args)
        return output
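
# Illustrative check of the encoding (hypothetical values, not part of the
# module): with embedding_dim=4, the divisor is [1.0, 100.0], so an age of
# 365.25 days (one year) maps to
#   [cos(1.0), sin(1.0), cos(0.01), sin(0.01)]
# i.e. low dimensions oscillate quickly with age, high dimensions slowly.
#   enc = AgeSinusoidalEncoding(embedding_dim=4)
#   emb = enc(torch.tensor([[0.0, 365.25]]))  # shape (1, 2, 4)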
class TimeAwareGPT2(nn.Module):
    """
    A time-aware GPT-2 model with custom temporal features.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float):
        super().__init__()
        self.token_pdrop = token_pdrop
        # Token and positional embeddings
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)
        # Transformer blocks
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
        # Final layer norm and linear head
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.n_embd = n_embd

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the TimeAwareGPT2 model.

        Args:
            event_seq (torch.Tensor): Token indices of shape (B, L).
            time_seq (torch.Tensor): Timestamps for each event of shape (B, L).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        B, L = event_seq.size()
        # 1. Get token embeddings
        token_embeddings = self.wte(event_seq)
        # 2. Apply token dropout (only during training)
        if self.training and self.token_pdrop > 0:
            # Create a mask to randomly zero out entire token embedding vectors
            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
            token_embeddings[drop_mask] = 0.0
        # 3. Get positional embeddings from time sequence
        pos_embeddings = self.age_encoder(time_seq.float())
        # 4. Combine embeddings and apply dropout
        x = self.drop(token_embeddings + pos_embeddings)
        # 5. Generate the attention mask, which combines two conditions:
        #    a) Time-based causality: a query at position i may attend to a
        #       key at position j only if time_seq[j] < time_seq[i].
        #    b) Padding mask: do not attend to positions whose event token is 0.
        t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
        t_j = time_seq.unsqueeze(1)   # (B, 1, L)
        time_mask = (t_j < t_i)
        padding_mask = (event_seq != 0).unsqueeze(1)  # (B, 1, L)
        # A key position j is visible to query i only if it is strictly in the
        # past (time_mask) AND it is not a padding token (padding_mask).
        combined_mask = time_mask & padding_mask
        # 6. Pass through transformer blocks
        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)
        # 7. Final layer norm and projection to vocab size
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    def get_num_params(self) -> float:
        """
        Returns the number of trainable parameters in the model in millions.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
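
# Illustrative forward pass (hypothetical sizes, not part of the model):
#   model = TimeAwareGPT2(vocab_size=1270, n_embd=256, n_layer=8, n_head=8,
#                         pdrop=0.1, token_pdrop=0.1)
#   events = torch.randint(2, 1270, (4, 64))  # (B, L) event codes
#   times, _ = torch.sort(torch.randint(0, 36525, (4, 64)), dim=1)  # days, ascending
#   logits = model(events, times)  # shape (4, 64, 1270)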
class CombinedLoss(nn.Module):
    """
    Computes a two-part loss: a standard cross-entropy loss for event type
    prediction and a survival analysis loss for event timing.
    """

    def __init__(self, ignored_token_ids: list[int]):
        """
        Initializes the CombinedLoss module.

        Args:
            ignored_token_ids (list[int]): A list of event type IDs to be
                excluded from all loss calculations.
        """
        super().__init__()
        self.ignored_token_ids = ignored_token_ids

    def forward(self, logits: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates the combined cross-entropy and survival loss.

        Args:
            logits (torch.Tensor): Raw model outputs of shape (B, L, N).
            x (torch.Tensor): Ground-truth event labels of shape (B, L).
            t (torch.Tensor): True time duration for each event, shape (B, L).

        Returns:
            A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
        """
        # 1. Create a mask to filter out ignored token IDs from loss calculation.
        # An element is True if the corresponding label in x is NOT in the ignored list.
        mask = torch.ones_like(x, dtype=torch.bool)
        for token_id in self.ignored_token_ids:
            mask = mask & (x != token_id)
        # If the mask is all False (all tokens are ignored), return zero for both losses.
        if not mask.any():
            return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)
        # 2. Part 1: Cross-Entropy Loss (loss_ce)
        # Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy.
        logits_for_ce = logits.permute(0, 2, 1)
        # Calculate per-element loss without reduction.
        per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
        # Apply the mask and compute the mean of valid elements.
        loss_ce = per_element_ce[mask].mean()
        # 3. Part 2: Survival Loss (loss_survival)
        # Calculate event intensity (lambda) as the sum of exponentiated logits.
        intensity = torch.sum(torch.exp(logits), dim=2)
        # Per-element negative log-likelihood of an exponential waiting time:
        # -log p(t) = -(log(lambda) - lambda * t). A small epsilon keeps the
        # log numerically stable.
        per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)
        # Apply the mask and compute the mean of valid elements.
        loss_survival = per_element_survival[mask].mean()
        return loss_ce, loss_survival
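
# Illustrative call (hypothetical shapes): with B=2 patients, L=16 steps, and
# N=1270 event types,
#   loss_fn = CombinedLoss(ignored_token_ids=[0, 1])
#   logits = torch.randn(2, 16, 1270)
#   labels = torch.randint(0, 1270, (2, 16))
#   waits = torch.rand(2, 16) * 365.0  # days until the next event
#   loss_ce, loss_surv = loss_fn(logits, labels, waits)
# Both returned tensors are scalars and can be summed for backpropagation.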

prepare_data.py (new file)

@@ -0,0 +1,133 @@
import pandas as pd
import tqdm
import numpy as np

label_files = 'labels.csv'
ukb_field_to_icd10_file = 'icd10_codes_mod.tsv'
ukb_basket_file = 'ukb_delphi.txt'
train_proportion = 0.8
output_prefix = 'ukb_real'

icdict = {}
icdcodes = []
with open(ukb_field_to_icd10_file) as f:
    for line in f:
        parts = line.strip().split()
        icdict[parts[0]] = parts[5]
        icdcodes.append(parts[5])

# Using enumerate for cleaner, safer label assignment starting from 0
label_dict = {}
with open(label_files) as f:
    for i, line in enumerate(f):
        label_dict[line.strip().split(' ')[0]] = i

icdict['f.31.0.0'] = "sex"
icdict['f.34.0.0'] = "YEAR"
icdict['f.52.0.0'] = "MONTH"
icdict['f.40000.0.0'] = "Death"
for j in range(17):
    icdict[f'f.40005.{j}.0'] = f'cancer_date_{j}'
    icdict[f'f.40006.{j}.0'] = f'cancer_type_{j}'
icdict['f.53.0.0'] = "assessment_date"
icdict['f.21001.0.0'] = "BMI"
icdict['f.1239.0.0'] = "smoking"
icdict['f.1558.0.0'] = "alcohol"

len_icd = len(icdcodes)
# Corrected typo 'aseessment_date' to 'assessment_date'
icdcodes.extend(['Death', 'assessment_date'] + [f'cancer_date_{j}' for j in range(17)])

data_list = []
ukb_iterator = pd.read_csv(ukb_basket_file, sep=',', chunksize=10000, index_col=0, low_memory=False)
for _, dd in tqdm.tqdm(enumerate(ukb_iterator)):
    dd = dd.rename(columns=icdict)
    dd.dropna(subset=['sex'], inplace=True)
    dd['sex'] += 1
    dd = dd[[col for col in dd.columns if not col.startswith('f.')]]
    dd['dob'] = pd.to_datetime(dd[['YEAR', 'MONTH']].assign(DAY=1))
    present_icdcodes = [c for c in icdcodes if c in dd.columns]
    if present_icdcodes:
        # Convert date columns to days from date of birth
        date_cols = dd[present_icdcodes].apply(pd.to_datetime, format="%Y-%m-%d", errors='coerce')
        date_cols_days = date_cols.sub(dd['dob'], axis=0)
        dd[present_icdcodes] = date_cols_days.apply(lambda x: x.dt.days)
    # Process ICD codes efficiently using melt
    cols_to_process = [col for col in icdcodes[:len_icd + 1] if col in dd.columns]
    if cols_to_process:
        melted_df = dd.reset_index().melt(
            id_vars=['f.eid'],
            value_vars=cols_to_process,
            var_name='event_code',
            value_name='days'
        )
        melted_df.dropna(subset=['days'], inplace=True)
        if not melted_df.empty:
            melted_df['label'] = melted_df['event_code'].map(label_dict)
            data_list.append(melted_df[['f.eid', 'days', 'label']].dropna().astype(int).to_numpy())
    # Process sex
    X = dd['sex'].reset_index().to_numpy().astype(int)
    data_list.append(np.c_[X[:, 0], np.zeros(X.shape[0]), X[:, 1]].astype(int))
    # Process cancer data efficiently using wide_to_long
    df_res = dd.reset_index()
    rename_dict = {f'cancer_date_{j}': f'cancerdate{j}' for j in range(17)}
    rename_dict.update({f'cancer_type_{j}': f'cancertype{j}' for j in range(17)})
    df_renamed = df_res.rename(columns=rename_dict)
    stubs_to_use = []
    if any('cancerdate' in col for col in df_renamed.columns):
        stubs_to_use.append('cancerdate')
    if any('cancertype' in col for col in df_renamed.columns):
        stubs_to_use.append('cancertype')
    if len(stubs_to_use) == 2:
        long_cancer = pd.wide_to_long(df_renamed,
                                      stubnames=stubs_to_use,
                                      i=['f.eid'],
                                      j='cancer_num').dropna()
        if not long_cancer.empty:
            long_cancer['cancer'] = long_cancer['cancertype'].str.slice(0, 3)
            long_cancer['cancer_label'] = long_cancer['cancer'].map(label_dict)
            cancer_array = long_cancer.reset_index()[['f.eid', 'cancerdate', 'cancer_label']].dropna().astype(int).to_numpy()
            if cancer_array.size > 0:
                data_list.append(cancer_array)
    # Process BMI, smoking, and alcohol
    dd_bmi = dd[['assessment_date', 'BMI']].dropna().reset_index()
    if not dd_bmi.empty:
        dd_bmi['bmi_status'] = np.select([dd_bmi['BMI'] > 28, dd_bmi['BMI'] > 22], [5, 4], default=3)
        data_list.append(dd_bmi[['f.eid', 'assessment_date', 'bmi_status']].astype(int).to_numpy())
    dd_sm = dd[['assessment_date', 'smoking']].dropna().reset_index()
    dd_sm = dd_sm[dd_sm['smoking'] != -3]
    if not dd_sm.empty:
        dd_sm['smoking_status'] = np.select([dd_sm['smoking'] == 1, dd_sm['smoking'] == 2], [8, 7], default=6)
        data_list.append(dd_sm[['f.eid', 'assessment_date', 'smoking_status']].astype(int).to_numpy())
    dd_al = dd[['assessment_date', 'alcohol']].dropna().reset_index()
    dd_al = dd_al[dd_al['alcohol'] != -3]
    if not dd_al.empty:
        dd_al['alcohol_status'] = np.select([dd_al['alcohol'] == 1, dd_al['alcohol'] < 4], [11, 10], default=9)
        data_list.append(dd_al[['f.eid', 'assessment_date', 'alcohol_status']].astype(int).to_numpy())

data = np.vstack(data_list)
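# The pipeline below sorts rows by patient id (lexsort treats its last key as
# the primary one), pushes rows carrying the maximum label value to the end of
# each patient's history, and orders the rest by day. Rows with negative days
# (events dated before birth) are dropped, and drop_duplicates keeps only the
# first occurrence of each (patient, label) pair.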
data = data[np.lexsort((data[:, 1], data[:, 2] == data[:, 2].max(), data[:, 0]))]
data = data[data[:, 1] >= 0]
data = pd.DataFrame(data).drop_duplicates([0, 2]).values
data = data.astype(np.uint32)
data.tofile(output_prefix + '.bin')
# Split train/validation sets by patient id so no patient appears in both sets
unique_ids = np.unique(data[:, 0])
split_id = unique_ids[int(len(unique_ids) * train_proportion)]
train_val_split = data[:, 0] <= split_id
data[train_val_split].tofile(output_prefix + '_train.bin')
data[~train_val_split].tofile(output_prefix + '_val.bin')
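
# Illustrative read-back of the binary layout written above (rows of three
# uint32 values: patient_id, days, label), matching how train.py loads it:
#   arr = np.memmap(output_prefix + '_train.bin', dtype=np.uint32, mode='r')
#   arr = arr.reshape(-1, 3)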

train.py (new file)

@@ -0,0 +1,170 @@
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 256  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Main Training Script ---
def main():
    config = TrainConfig()

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")
    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)
    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0
    print("Starting training...")
    for epoch in range(config.max_epoch):
        # --- Learning Rate Scheduling ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
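        # Illustrative schedule values (with the defaults above): epochs 1-10
        # hold lr at 6e-4 (a constant warmup, not a ramp); at progress = 0.5
        # the cosine term gives lr = 6e-5 + 0.5*(6e-4 - 6e-5)*(1 + cos(pi/2))
        # = 3.3e-4, decaying to 6e-5 by the final epoch.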
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0
        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        for event_seq, time_seq in pbar:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1
            pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
        avg_train_loss_ce = train_loss_ce_acc / train_steps
        avg_train_loss_surv = train_loss_surv_acc / train_steps

        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0
        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
            for event_seq, time_seq in pbar_val:
                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1
                pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
        avg_val_loss_ce = val_loss_ce_acc / val_steps
        avg_val_loss_surv = val_loss_surv_acc / val_steps
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv
        print(f"Epoch {epoch+1} Summary: \n"
              f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
              f"  Val Loss:   {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
              f"  Learning Rate: {lr:.6f}")

        # --- Early Stopping Check ---
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            patience_counter = 0
            print(f"Validation loss improved to {best_val_loss:.4f}. Resetting patience.")
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
        if patience_counter >= config.early_stopping_patience:
            print("\nEarly stopping triggered due to no improvement in validation loss.")
            break
if __name__ == '__main__':
    main()

utils.py (new file)

@@ -0,0 +1,104 @@
import torch
import numpy as np
import random
from collections import defaultdict


class PatientEventDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for handling temporal sequences of patient events.

    This class processes a raw NumPy array of patient records, groups them by
    patient ID, and prepares them for training by imputing gaps, padding, or
    truncating sequences to a fixed length.
    """

    def __init__(self, data: np.ndarray, block_length: int):
        """
        Initializes the dataset by pre-processing the patient event data.

        Args:
            data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32.
                The columns represent (patient_id, time_in_days, event_code).
            block_length (int): The fixed length for the output sequences.
        """
        self.block_length = block_length
        # Group (time_in_days, event_code) pairs by patient_id.
        # This pre-processing step allows for efficient lookups in __getitem__.
        patient_events = defaultdict(list)
        for patient_id, time, event in data:
            patient_events[patient_id].append((time, event))
        # Store a list of unique patient_ids to map indices to patients.
        self.patient_ids = list(patient_events.keys())
        self.patient_events = dict(patient_events)

    def __len__(self) -> int:
        """
        Returns the total number of unique patients in the dataset.
        """
        return len(self.patient_ids)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves, processes, and returns a single patient's event sequence.

        Args:
            idx (int): The index of the patient to retrieve.

        Returns:
            A tuple of two torch.long tensors: (event_sequence, time_sequence),
            both of shape (block_length,).
        """
        # 1. Retrieve and Sort
        patient_id = self.patient_ids[idx]
        records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
        # 2. Impute "No Event" Gaps
        # An empty record list is left as-is (the constructor normally
        # prevents this) and is handled by the padding step below.
        imputed_sequence = []
        if records:
            imputed_sequence.append(records[0])
            for i in range(len(records) - 1):
                prev_time, _ = records[i]
                next_time, _ = records[i + 1]
                time_gap = next_time - prev_time
                # If the gap is 5 years (1826 days) or more, insert "no event" records.
                if time_gap >= 1826:
                    num_no_event_intervals = time_gap // 1826
                    for j in range(1, num_no_event_intervals + 1):
                        no_event_time = prev_time + j * 1826
                        imputed_sequence.append((no_event_time, 1))  # event_code=1 for "no event"
                imputed_sequence.append(records[i + 1])
        # 3. Adjust Sequence Length
        seq_len = len(imputed_sequence)
        if seq_len > self.block_length:
            # If longer, randomly select a contiguous sub-sequence.
            start_index = random.randint(0, seq_len - self.block_length)
            final_sequence = imputed_sequence[start_index : start_index + self.block_length]
        elif seq_len < self.block_length:
            # If shorter, pad the sequence at the end.
            padding_needed = self.block_length - seq_len
            # Use event_code=0 and time_in_days=36525 for padding.
            padding = [(36525, 0)] * padding_needed
            final_sequence = imputed_sequence + padding
        else:
            # If equal, use the sequence as is.
            final_sequence = imputed_sequence
        # 4. Return Tensors
        # Separate the sequence into event codes and time, then convert to tensors.
        event_codes = [item[1] for item in final_sequence]
        time_stamps = [item[0] for item in final_sequence]
        event_tensor = torch.tensor(event_codes, dtype=torch.long)
        time_tensor = torch.tensor(time_stamps, dtype=torch.long)
        return event_tensor, time_tensor
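
# Illustrative usage (hypothetical file name; the binary rows are
# (patient_id, time_in_days, event_code) triples as written by prepare_data.py):
#   data = np.memmap('ukb_real_train.bin', dtype=np.uint32, mode='r').reshape(-1, 3)
#   dataset = PatientEventDataset(data, block_length=256)
#   events, times = dataset[0]  # two (256,) torch.long tensors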