feat: Implement time-aware GPT-2 for patient event prediction
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data. Key components include:

- `models.py`:
  - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings.
  - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing.
- `utils.py`:
  - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation.
- `train.py`:
  - A comprehensive training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss.
- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.
- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
GEMINI.md (new file, 59 lines)
@@ -0,0 +1,59 @@
# DeepHealth Project

This is a deep learning project based on PyTorch. This project adheres to specific code style and file structure conventions to ensure clarity, maintainability, and reproducibility.

## 1. Project Structure

To maintain a clean and modular project, we adopt the following file organization:

DeepHealth/
|-train.py
|-models.py
|-utils.py
|-data/
|-requirements.txt
|-README.md

### File Descriptions

* **`train.py`**:
    * **Core training script**. It contains the control flow for the entire training process.
    * Responsible for initializing the model, optimizer, DataLoader, etc.
    * Executes the training and validation loops.
    * Handles saving and loading checkpoints, logging, and other related tasks.

* **`models.py`**:
    * **Model and Loss Function Definitions**. This file stores the architecture for all neural network models.
    * All subclasses of `torch.nn.Module` should be defined in this file.
    * Custom loss functions should also be implemented here.

* **`utils.py`**:
    * **Utility Functions Module**. It contains reusable helper functions for the project.
    * Primarily responsible for data I/O operations, data preprocessing, performance metric calculations, logger configuration, or other logic that doesn't belong in the core model or training framework.

* **`data/`**:
    * **Data Storage Directory**. Used to store the datasets required for the project.
    * `data/raw/` stores the original, unprocessed data.
    * `data/processed/` stores data after it has been preprocessed.

* **`requirements.txt`**:
    * **Project Dependencies**. Lists all the Python packages and their versions required to run this project.

* **`README.md`**:
    * **Project Documentation**. Provides a high-level overview of the project, setup instructions, and usage guidelines.

## 2. Core Framework

* **Deep Learning Framework**: `PyTorch`

## 3. Coding Style

This project uniformly adopts the **Google Python Style Guide**. All submitted code should adhere to this standard to ensure consistency and readability.

Key features include:

* Using `yapf` or `black` for automatic code formatting.
* Following detailed naming conventions (`module_name`, `package_name`, `ClassName`, `method_name`, `ExceptionName`, `function_name`, `GLOBAL_CONSTANT_NAME`).
* Using Google-style docstrings (see the sketch below).

Please refer to the official documentation: [Google Python Style Guide](http://google.github.io/styleguide/pyguide.html)
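To make the conventions above concrete, here is a small hedged illustration of a Google-style docstring with Google-style naming; the helper function is hypothetical and not part of this commit:

```python
def count_valid_events(event_codes: list[int], ignored_codes: set[int]) -> int:
    """Counts the events in a sequence, excluding ignored codes.

    Args:
        event_codes: Event codes for a single patient, in temporal order.
        ignored_codes: Codes to exclude from the count (e.g., padding).

    Returns:
        The number of events whose code is not in ignored_codes.
    """
    return sum(1 for code in event_codes if code not in ignored_codes)
```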
models.py (new file, 284 lines)
@@ -0,0 +1,284 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
import math


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(pdrop)
        self.resid_dropout = nn.Dropout(pdrop)
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        B, L, D = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)
        q = q.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)
        v = v.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)

        # causal self-attention; Self-attend: (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # Apply the time-based causal mask
        att = att.masked_fill(custom_mask.unsqueeze(1) == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs)
        y = y.transpose(1, 2).contiguous().view(B, L, D)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, pdrop)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc = nn.Linear(n_embd, 4 * n_embd),
            c_proj = nn.Linear(4 * n_embd, n_embd),
            act = nn.GELU(),
            dropout = nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))  # MLP forward

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x), custom_mask=custom_mask)
        x = x + self.mlpf(self.ln_2(x))
        return x


class AgeSinusoidalEncoding(nn.Module):
    """
    Encodes age using sinusoidal functions, similar to positional encodings
    in Transformers. This module creates a fixed-size embedding for an age
    value given in days.
    """

    def __init__(self, embedding_dim: int):
        """
        Initializes the AgeSinusoidalEncoding module.

        Args:
            embedding_dim (int): The dimensionality of the output embedding.
                Must be an even number.

        Raises:
            ValueError: If embedding_dim is not an even number.
        """
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(f"Embedding dimension must be an even number, but got {embedding_dim}")

        self.embedding_dim = embedding_dim

        # Pre-calculate the divisor term for the sinusoidal formula.
        # The formula for the divisor is 10000^(2i/D), where D is the
        # embedding_dim and i is the index for each pair of dimensions.
        # i ranges from 0 to D/2 - 1.
        i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.embedding_dim)

        # Register the divisor as a non-trainable buffer. This ensures it is
        # moved to the correct device (e.g., GPU) along with the model.
        self.register_buffer('divisor', divisor)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the AgeSinusoidalEncoding.

        Args:
            t (torch.Tensor): A tensor of shape (batch_size, sequence_length)
                with dtype=torch.float32, representing age in days.

        Returns:
            torch.Tensor: The encoded age tensor of shape
                (batch_size, sequence_length, embedding_dim).
        """
        # 1. Unit Conversion: Convert age from days to years.
        # We use 365.25 to account for leap years.
        t_years = t / 365.25

        # 2. Argument Calculation: Calculate the arguments for the sin/cos functions.
        # The shapes are broadcast to (B, L, D/2).
        # Input t_years: (B, L) -> unsqueezed to (B, L, 1)
        # Divisor: (D/2) -> viewed as (1, 1, D/2)
        args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1)

        # 3. Sinusoidal Application: Create the final output tensor.
        # Initialize an empty tensor to store the embeddings.
        output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)

        # Assign cosine of the arguments to the even indices.
        output[:, :, 0::2] = torch.cos(args)

        # Assign sine of the arguments to the odd indices.
        output[:, :, 1::2] = torch.sin(args)

        return output


class TimeAwareGPT2(nn.Module):
    """
    A time-aware GPT-2 model with custom temporal features.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float):
        super().__init__()
        self.token_pdrop = token_pdrop

        # Token and positional embeddings
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)

        # Transformer blocks
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])

        # Final layer norm and linear head
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        self.n_embd = n_embd

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the TimeAwareGPT2 model.

        Args:
            event_seq (torch.Tensor): Token indices of shape (B, L).
            time_seq (torch.Tensor): Timestamps for each event of shape (B, L).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        B, L = event_seq.size()

        # 1. Get token embeddings
        token_embeddings = self.wte(event_seq)

        # 2. Apply token dropout (only during training)
        if self.training and self.token_pdrop > 0:
            # Create a mask to randomly zero out entire token embedding vectors
            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
            token_embeddings[drop_mask] = 0.0

        # 3. Get positional embeddings from time sequence
        pos_embeddings = self.age_encoder(time_seq.float())

        # 4. Combine embeddings and apply dropout
        x = self.drop(token_embeddings + pos_embeddings)

        # 5. Generate attention mask
        # The attention mask combines two conditions:
        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] < time_seq[i].
        # b) Padding mask: Do not attend to positions where the event token is 0.

        # a) Time-based causal mask
        t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
        t_j = time_seq.unsqueeze(1)   # (B, 1, L)
        time_mask = (t_j < t_i)

        # b) Padding mask (prevents attending to key positions that are padding)
        padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)

        # Combine the masks. A position (j) can be attended to by a query (i) only if
        # it's in the past (time_mask) AND it's not a padding token (padding_mask).
        combined_mask = time_mask & padding_mask

        # 6. Pass through transformer blocks
        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)

        # 7. Final layer norm and projection to vocab size
        x = self.ln_f(x)
        logits = self.head(x)

        return logits

    def get_num_params(self) -> float:
        """
        Returns the number of trainable parameters in the model in millions.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6


class CombinedLoss(nn.Module):
    """
    Computes a two-part loss: a standard cross-entropy loss for event type
    prediction and a survival analysis loss for event timing.
    """

    def __init__(self, ignored_token_ids: list[int]):
        """
        Initializes the CombinedLoss module.

        Args:
            ignored_token_ids (list[int]): A list of event type IDs to be
                excluded from all loss calculations.
        """
        super().__init__()
        self.ignored_token_ids = ignored_token_ids

    def forward(self, logits: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates the combined cross-entropy and survival loss.

        Args:
            logits (torch.Tensor): Raw model outputs of shape (B, L, N).
            x (torch.Tensor): Ground-truth event labels of shape (B, L).
            t (torch.Tensor): True time duration for each event, shape (B, L).

        Returns:
            A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
        """
        # 1. Create a mask to filter out ignored token IDs from loss calculation.
        # An element is True if the corresponding label in x is NOT in the ignored list.
        mask = torch.ones_like(x, dtype=torch.bool)
        for token_id in self.ignored_token_ids:
            mask = mask & (x != token_id)

        # If the mask is all False (all tokens are ignored), return zero for both losses.
        if not mask.any():
            return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)

        # 2. Part 1: Cross-Entropy Loss (loss_ce)
        # Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy.
        logits_for_ce = logits.permute(0, 2, 1)

        # Calculate per-element loss without reduction.
        per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')

        # Apply the mask and compute the mean of valid elements.
        loss_ce = per_element_ce[mask].mean()

        # 3. Part 2: Survival Loss (loss_survival)
        # Calculate event intensity (lambda) as the sum of exponentiated logits.
        intensity = torch.sum(torch.exp(logits), dim=2)

        # Calculate per-element survival loss (negative log-likelihood of exponential dist).
        # We add a small epsilon for numerical stability with the log.
        per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)

        # Apply the mask and compute the mean of valid elements.
        loss_survival = per_element_survival[mask].mean()

        return loss_ce, loss_survival
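For orientation, a minimal sketch of how `AgeSinusoidalEncoding` and `CombinedLoss` can be exercised in isolation, assuming `models.py` is importable; the tensor sizes and values below are illustrative only, not taken from this commit:

```python
import torch
from models import AgeSinusoidalEncoding, CombinedLoss

# Encode ages (in days) for 2 patients x 3 events into 8-dimensional embeddings.
ages_days = torch.tensor([[0.0, 365.25, 3652.5],
                          [7305.0, 14610.0, 18262.5]])
embeddings = AgeSinusoidalEncoding(embedding_dim=8)(ages_days)
print(embeddings.shape)  # torch.Size([2, 3, 8])

# Evaluate the two-part loss on random logits; labels 0 and 1 are ignored,
# matching the padding / "no event" convention used in train.py.
logits = torch.randn(2, 3, 10)         # (B, L, vocab_size)
labels = torch.randint(2, 10, (2, 3))  # ground-truth event codes
waits = torch.rand(2, 3)               # time to the next event (arbitrary units)
loss_ce, loss_surv = CombinedLoss(ignored_token_ids=[0, 1])(logits, labels, waits)
print(loss_ce.item(), loss_surv.item())
```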
prepare_data.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import pandas as pd
import tqdm
import numpy as np

label_files = 'labels.csv'
ukb_field_to_icd10_file = 'icd10_codes_mod.tsv'
ukb_basket_file = 'ukb_delphi.txt'
train_proportion = 0.8
output_prefix = 'ukb_real'

icdict = {}
icdcodes = []
with open(ukb_field_to_icd10_file) as f:
    for line in f:
        parts = line.strip().split()
        icdict[parts[0]] = parts[5]
        icdcodes.append(parts[5])

# Using enumerate for cleaner, safer label assignment starting from 0
label_dict = {}
with open(label_files) as f:
    for i, line in enumerate(f):
        label_dict[line.strip().split(' ')[0]] = i

icdict['f.31.0.0'] = "sex"
icdict['f.34.0.0'] = "YEAR"
icdict['f.52.0.0'] = "MONTH"
icdict['f.40000.0.0'] = "Death"

for j in range(17):
    icdict[f'f.40005.{j}.0'] = f'cancer_date_{j}'
    icdict[f'f.40006.{j}.0'] = f'cancer_type_{j}'

icdict['f.53.0.0'] = "assessment_date"
icdict['f.21001.0.0'] = "BMI"
icdict['f.1239.0.0'] = "smoking"
icdict['f.1558.0.0'] = "alcohol"

len_icd = len(icdcodes)

# Corrected typo 'aseessment_date' to 'assessment_date'
icdcodes.extend(['Death', 'assessment_date'] + [f'cancer_date_{j}' for j in range(17)])

data_list = []
ukb_iterator = pd.read_csv(ukb_basket_file, sep=',', chunksize=10000, index_col=0, low_memory=False)

for _, dd in tqdm.tqdm(enumerate(ukb_iterator)):
    dd = dd.rename(columns=icdict)
    dd.dropna(subset=['sex'], inplace=True)
    dd['sex'] += 1
    dd = dd[[col for col in dd.columns if not col.startswith('f.')]]
    dd['dob'] = pd.to_datetime(dd[['YEAR', 'MONTH']].assign(DAY=1))

    present_icdcodes = [c for c in icdcodes if c in dd.columns]
    if present_icdcodes:
        # Convert date columns to days from date of birth
        date_cols = dd[present_icdcodes].apply(pd.to_datetime, format="%Y-%m-%d", errors='coerce')
        date_cols_days = date_cols.sub(dd['dob'], axis=0)
        dd[present_icdcodes] = date_cols_days.apply(lambda x: x.dt.days)

    # Process ICD codes efficiently using melt
    cols_to_process = [col for col in icdcodes[:len_icd + 1] if col in dd.columns]
    if cols_to_process:
        melted_df = dd.reset_index().melt(
            id_vars=['f.eid'],
            value_vars=cols_to_process,
            var_name='event_code',
            value_name='days'
        )
        melted_df.dropna(subset=['days'], inplace=True)
        if not melted_df.empty:
            melted_df['label'] = melted_df['event_code'].map(label_dict)
            data_list.append(melted_df[['f.eid', 'days', 'label']].dropna().astype(int).to_numpy())

    # Process sex
    X = dd['sex'].reset_index().to_numpy().astype(int)
    data_list.append(np.c_[X[:, 0], np.zeros(X.shape[0]), X[:, 1]].astype(int))

    # Process cancer data efficiently using wide_to_long
    df_res = dd.reset_index()
    rename_dict = {f'cancer_date_{j}': f'cancerdate{j}' for j in range(17)}
    rename_dict.update({f'cancer_type_{j}': f'cancertype{j}' for j in range(17)})
    df_renamed = df_res.rename(columns=rename_dict)

    stubs_to_use = []
    if any('cancerdate' in col for col in df_renamed.columns): stubs_to_use.append('cancerdate')
    if any('cancertype' in col for col in df_renamed.columns): stubs_to_use.append('cancertype')

    if len(stubs_to_use) == 2:
        long_cancer = pd.wide_to_long(df_renamed,
                                      stubnames=stubs_to_use,
                                      i=['f.eid'],
                                      j='cancer_num'
                                      ).dropna()
        if not long_cancer.empty:
            long_cancer['cancer'] = long_cancer['cancertype'].str.slice(0, 3)
            long_cancer['cancer_label'] = long_cancer['cancer'].map(label_dict)
            cancer_array = long_cancer.reset_index()[['f.eid', 'cancerdate', 'cancer_label']].dropna().astype(int).to_numpy()
            if cancer_array.size > 0:
                data_list.append(cancer_array)

    # Process BMI, smoking, and alcohol
    dd_bmi = dd[['assessment_date', 'BMI']].dropna().reset_index()
    if not dd_bmi.empty:
        dd_bmi['bmi_status'] = np.select([dd_bmi['BMI'] > 28, dd_bmi['BMI'] > 22], [5, 4], default=3)
        data_list.append(dd_bmi[['f.eid', 'assessment_date', 'bmi_status']].astype(int).to_numpy())

    dd_sm = dd[['assessment_date', 'smoking']].dropna().reset_index()
    dd_sm = dd_sm[dd_sm['smoking'] != -3]
    if not dd_sm.empty:
        dd_sm['smoking_status'] = np.select([dd_sm['smoking'] == 1, dd_sm['smoking'] == 2], [8, 7], default=6)
        data_list.append(dd_sm[['f.eid', 'assessment_date', 'smoking_status']].astype(int).to_numpy())

    dd_al = dd[['assessment_date', 'alcohol']].dropna().reset_index()
    dd_al = dd_al[dd_al['alcohol'] != -3]
    if not dd_al.empty:
        dd_al['alcohol_status'] = np.select([dd_al['alcohol'] == 1, dd_al['alcohol'] < 4], [11, 10], default=9)
        data_list.append(dd_al[['f.eid', 'assessment_date', 'alcohol_status']].astype(int).to_numpy())

data = np.vstack(data_list)
data = data[np.lexsort((data[:, 1], data[:, 2] == data[:, 2].max(), data[:, 0]))]
data = data[data[:, 1] >= 0]
data = pd.DataFrame(data).drop_duplicates([0, 2]).values
data = data.astype(np.uint32)
data.tofile(output_prefix + '.bin')

# Correctly split train/validation sets
unique_ids = np.unique(data[:, 0])
split_id = unique_ids[int(len(unique_ids) * train_proportion)]
train_val_split = data[:, 0] <= split_id

data[train_val_split].tofile(output_prefix + '_train.bin')
data[~train_val_split].tofile(output_prefix + '_val.bin')
train.py (new file, 170 lines)
@@ -0,0 +1,170 @@
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 256  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Main Training Script ---
def main():
    config = TrainConfig()

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)

    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0
    print("Starting training...")
    for epoch in range(config.max_epoch):
        # --- Learning Rate Scheduling ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        for event_seq, time_seq in pbar:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1
            pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        avg_train_loss_ce = train_loss_ce_acc / train_steps
        avg_train_loss_surv = train_loss_surv_acc / train_steps

        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
            for event_seq, time_seq in pbar_val:
                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1
                pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

        avg_val_loss_ce = val_loss_ce_acc / val_steps
        avg_val_loss_surv = val_loss_surv_acc / val_steps
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        print(f"Epoch {epoch+1} Summary: \n"
              f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
              f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
              f"  Learning Rate: {lr:.6f}")

        # --- Early Stopping Check ---
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            patience_counter = 0
            print(f"Validation loss improved to {best_val_loss:.4f}. Resetting patience.")
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")

        if patience_counter >= config.early_stopping_patience:
            print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

if __name__ == '__main__':
    main()
utils.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import torch
import numpy as np
import random
from collections import defaultdict


class PatientEventDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for handling temporal sequences of patient events.

    This class processes a raw NumPy array of patient records, groups them by
    patient ID, and prepares them for training by imputing gaps, padding, or
    truncating sequences to a fixed length.
    """

    def __init__(self, data: np.ndarray, block_length: int):
        """
        Initializes the dataset by pre-processing the patient event data.

        Args:
            data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32.
                The columns represent (patient_id, time_in_days, event_code).
            block_length (int): The fixed length for the output sequences.
        """
        self.block_length = block_length

        # Group (time_in_days, event_code) pairs by patient_id.
        # This pre-processing step allows for efficient lookups in __getitem__.
        patient_events = defaultdict(list)
        for patient_id, time, event in data:
            patient_events[patient_id].append((time, event))

        # Store a list of unique patient_ids to map indices to patients.
        self.patient_ids = list(patient_events.keys())
        self.patient_events = dict(patient_events)

    def __len__(self) -> int:
        """
        Returns the total number of unique patients in the dataset.
        """
        return len(self.patient_ids)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves, processes, and returns a single patient's event sequence.

        Args:
            idx (int): The index of the patient to retrieve.

        Returns:
            A tuple of two torch.long tensors: (event_sequence, time_sequence),
            both of shape (block_length,).
        """
        # 1. Retrieve and Sort
        patient_id = self.patient_ids[idx]
        records = sorted(self.patient_events[patient_id], key=lambda x: x[0])

        # 2. Impute "No Event" Gaps
        imputed_sequence = []
        if not records:
            # Handle cases with no records for a patient if necessary, though
            # the constructor logic would typically prevent this.
            pass
        else:
            imputed_sequence.append(records[0])
            for i in range(len(records) - 1):
                prev_time, _ = records[i]
                next_time, _ = records[i+1]
                time_gap = next_time - prev_time

                # If the gap is 5 years (1826 days) or more, insert "no event" records.
                if time_gap >= 1826:
                    num_no_event_intervals = time_gap // 1826
                    for j in range(1, num_no_event_intervals + 1):
                        no_event_time = prev_time + j * 1826
                        imputed_sequence.append((no_event_time, 1))  # event_code=1 for "no event"

                imputed_sequence.append(records[i+1])

        # 3. Adjust Sequence Length
        seq_len = len(imputed_sequence)

        if seq_len > self.block_length:
            # If longer, randomly select a contiguous sub-sequence.
            start_index = random.randint(0, seq_len - self.block_length)
            final_sequence = imputed_sequence[start_index : start_index + self.block_length]
        elif seq_len < self.block_length:
            # If shorter, pad the sequence at the end.
            padding_needed = self.block_length - seq_len
            # Use event_code=0 and time_in_days=36525 for padding.
            padding = [(36525, 0)] * padding_needed
            final_sequence = imputed_sequence + padding
        else:
            # If equal, use the sequence as is.
            final_sequence = imputed_sequence

        # 4. Return Tensors
        # Separate the sequence into event codes and time, then convert to tensors.
        event_codes = [item[1] for item in final_sequence]
        time_stamps = [item[0] for item in final_sequence]

        event_tensor = torch.tensor(event_codes, dtype=torch.long)
        time_tensor = torch.tensor(time_stamps, dtype=torch.long)

        return event_tensor, time_tensor
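A small hedged usage sketch of `PatientEventDataset`; the toy records below are made up, and the column order (patient_id, time_in_days, event_code) follows the docstring above:

```python
import numpy as np
from utils import PatientEventDataset

# Two toy patients; the >1826-day gap for patient 1 should trigger "no event" imputation.
records = np.array([
    [1, 0,    2],   # patient 1, day 0, event 2
    [1, 4000, 3],   # patient 1, day 4000, event 3
    [2, 100,  4],   # patient 2, day 100, event 4
], dtype=np.uint32)

dataset = PatientEventDataset(records, block_length=8)
events, times = dataset[0]
print(events.shape, times.shape)  # torch.Size([8]) torch.Size([8])
print(events.tolist())            # [2, 1, 1, 3, 0, 0, 0, 0] after imputation and padding
```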