feat: Implement time-aware GPT-2 for patient event prediction
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data. Key components include:

- `models.py`:
  - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings.
  - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing.
- `utils.py`:
  - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation.
- `train.py`:
  - A comprehensive training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss.
- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.
- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
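For orientation, a minimal sketch of how the pieces compose, mirroring what `train.py` does (paths and hyperparameters are illustrative):

    import numpy as np
    from models import TimeAwareGPT2, CombinedLoss
    from utils import PatientEventDataset

    data = np.memmap('ukb_real_train.bin', dtype=np.uint32, mode='r').reshape(-1, 3)
    dataset = PatientEventDataset(data, block_length=256)
    model = TimeAwareGPT2(vocab_size=int(data[:, 2].max()) + 1,
                          n_embd=256, n_layer=8, n_head=8, pdrop=0.1, token_pdrop=0.1)
    loss_fn = CombinedLoss(ignored_token_ids=[0, 1])

    events, times = dataset[0]  # two (block_length,) long tensors
    logits = model(events[None, :-1], times[None, :-1])
    loss_ce, loss_surv = loss_fn(logits, events[None, 1:],
                                 (times[None, 1:] - times[None, :-1]).float())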
59  GEMINI.md  Normal file
@@ -0,0 +1,59 @@
# DeepHealth Project

This is a deep learning project based on PyTorch. It adheres to specific code style and file structure conventions to ensure clarity, maintainability, and reproducibility.

## 1. Project Structure

To maintain a clean and modular project, we adopt the following file organization:

DeepHealth/
|-train.py
|-models.py
|-utils.py
|-data/
|-requirements.txt
|-README.md

### File Descriptions

* **`train.py`**:
    * **Core training script**. It contains the control flow for the entire training process.
    * Responsible for initializing the model, optimizer, DataLoader, etc.
    * Executes the training and validation loops.
    * Handles saving and loading checkpoints, logging, and other related tasks.

* **`models.py`**:
    * **Model and Loss Function Definitions**. This file stores the architecture of all neural network models.
    * All subclasses of `torch.nn.Module` should be defined in this file.
    * Custom loss functions should also be implemented here.

* **`utils.py`**:
    * **Utility Functions Module**. It contains reusable helper functions for the project.
    * Primarily responsible for data I/O, data preprocessing, performance metric calculations, logger configuration, and other logic that doesn't belong in the core model or training framework.

* **`data/`**:
    * **Data Storage Directory**. Used to store the datasets required for the project.
    * `data/raw/` stores the original, unprocessed data.
    * `data/processed/` stores the preprocessed data.

* **`requirements.txt`**:
    * **Project Dependencies**. Lists all the Python packages and their versions required to run this project.

* **`README.md`**:
    * **Project Documentation**. Provides a high-level overview of the project, setup instructions, and usage guidelines.

## 2. Core Framework

* **Deep Learning Framework**: `PyTorch`

## 3. Coding Style

This project uniformly adopts the **Google Python Style Guide**. All submitted code should adhere to this standard to ensure consistency and readability.

Key features include:
* Using `yapf` or `black` for automatic code formatting.
* Following detailed naming conventions (`module_name`, `package_name`, `ClassName`, `method_name`, `ExceptionName`, `function_name`, `GLOBAL_CONSTANT_NAME`).
* Using Google-style docstrings.

Please refer to the official documentation: [Google Python Style Guide](http://google.github.io/styleguide/pyguide.html)
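For illustration, a hypothetical function following this style (Google-style docstring, formatter-friendly layout) might look like:

    def count_events(event_codes: list[int], ignored: set[int]) -> int:
        """Counts the events for one patient, skipping ignored codes.

        Args:
            event_codes: Event codes for a single patient.
            ignored: Codes to exclude, e.g. padding or "no event".

        Returns:
            The number of events not in `ignored`.
        """
        return sum(1 for code in event_codes if code not in ignored)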
284  models.py  Normal file
@@ -0,0 +1,284 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
import math


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(pdrop)
        self.resid_dropout = nn.Dropout(pdrop)
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        B, L, D = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)
        q = q.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)
        v = v.view(B, L, self.n_head, D // self.n_head).transpose(1, 2)  # (B, nh, L, hs)

        # causal self-attention; self-attend: (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # Apply the time-based causal mask
        att = att.masked_fill(custom_mask.unsqueeze(1) == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs)
        y = y.transpose(1, 2).contiguous().view(B, L, D)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
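# Note on the mask convention (added comment, not in the original file): `custom_mask`
# is expected to be a boolean tensor of shape (B, L, L); `unsqueeze(1)` broadcasts it
# across heads to (B, 1, L, L), and key positions where the mask is False are filled
# with -inf before the softmax, so they receive zero attention weight.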
class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, n_embd: int, n_head: int, pdrop: float):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, pdrop)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc = nn.Linear(n_embd, 4 * n_embd),
            c_proj = nn.Linear(4 * n_embd, n_embd),
            act = nn.GELU(),
            dropout = nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))  # MLP forward

    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x), custom_mask=custom_mask)
        x = x + self.mlpf(self.ln_2(x))
        return x


class AgeSinusoidalEncoding(nn.Module):
    """
    Encodes age using sinusoidal functions, similar to positional encodings
    in Transformers. This module creates a fixed-size embedding for an age
    value given in days.
    """

    def __init__(self, embedding_dim: int):
        """
        Initializes the AgeSinusoidalEncoding module.

        Args:
            embedding_dim (int): The dimensionality of the output embedding.
                Must be an even number.

        Raises:
            ValueError: If embedding_dim is not an even number.
        """
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(f"Embedding dimension must be an even number, but got {embedding_dim}")

        self.embedding_dim = embedding_dim

        # Pre-calculate the divisor term for the sinusoidal formula.
        # The formula for the divisor is 10000^(2i/D), where D is the
        # embedding_dim and i is the index for each pair of dimensions.
        # i ranges from 0 to D/2 - 1.
        i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.embedding_dim)

        # Register the divisor as a non-trainable buffer. This ensures it is
        # moved to the correct device (e.g., GPU) along with the model.
        self.register_buffer('divisor', divisor)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the AgeSinusoidalEncoding.

        Args:
            t (torch.Tensor): A tensor of shape (batch_size, sequence_length)
                with dtype=torch.float32, representing age in days.

        Returns:
            torch.Tensor: The encoded age tensor of shape
                (batch_size, sequence_length, embedding_dim).
        """
        # 1. Unit Conversion: Convert age from days to years.
        # We use 365.25 to account for leap years.
        t_years = t / 365.25

        # 2. Argument Calculation: Calculate the arguments for the sin/cos functions.
        # The shapes are broadcast to (B, L, D/2).
        # Input t_years: (B, L) -> unsqueezed to (B, L, 1)
        # Divisor: (D/2) -> viewed as (1, 1, D/2)
        args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1)

        # 3. Sinusoidal Application: Create the final output tensor.
        # Initialize an empty tensor to store the embeddings.
        output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)

        # Assign cosine of the arguments to the even indices.
        output[:, :, 0::2] = torch.cos(args)

        # Assign sine of the arguments to the odd indices.
        output[:, :, 1::2] = torch.sin(args)

        return output
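# Usage sketch (added comment, not in the original file):
#   enc = AgeSinusoidalEncoding(embedding_dim=4)
#   t = torch.tensor([[0.0, 365.25, 3652.5]])  # ages of 0, 1, and 10 years in days
#   enc(t).shape                               # torch.Size([1, 3, 4])
# Even output channels hold cos(age_in_years * divisor), odd channels the matching sin terms.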
class TimeAwareGPT2(nn.Module):
    """
    A time-aware GPT-2 model with custom temporal features.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float):
        super().__init__()
        self.token_pdrop = token_pdrop

        # Token and positional embeddings
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)

        # Transformer blocks
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])

        # Final layer norm and linear head
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        self.n_embd = n_embd

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the TimeAwareGPT2 model.

        Args:
            event_seq (torch.Tensor): Token indices of shape (B, L).
            time_seq (torch.Tensor): Timestamps for each event of shape (B, L).

        Returns:
            torch.Tensor: Logits of shape (B, L, vocab_size).
        """
        B, L = event_seq.size()

        # 1. Get token embeddings
        token_embeddings = self.wte(event_seq)

        # 2. Apply token dropout (only during training)
        if self.training and self.token_pdrop > 0:
            # Create a mask to randomly zero out entire token embedding vectors
            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
            token_embeddings[drop_mask] = 0.0

        # 3. Get positional embeddings from time sequence
        pos_embeddings = self.age_encoder(time_seq.float())

        # 4. Combine embeddings and apply dropout
        x = self.drop(token_embeddings + pos_embeddings)

        # 5. Generate attention mask
        # The attention mask combines two conditions:
        # a) Time-based causality: a token i can attend to a token j only if time_seq[j] < time_seq[i].
        # b) Padding mask: do not attend to positions where the event token is 0.

        # a) Time-based causal mask
        t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
        t_j = time_seq.unsqueeze(1)   # (B, 1, L)
        time_mask = (t_j < t_i)

        # b) Padding mask (prevents attending to key positions that are padding)
        padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)

        # Combine the masks. A position (j) can be attended to by a query (i) only if
        # it's in the past (time_mask) AND it's not a padding token (padding_mask).
        combined_mask = time_mask & padding_mask

        # 6. Pass through transformer blocks
        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)

        # 7. Final layer norm and projection to vocab size
        x = self.ln_f(x)
        logits = self.head(x)

        return logits

    def get_num_params(self) -> float:
        """
        Returns the number of trainable parameters in the model in millions.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
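# Mask example (added comment, not in the original file): for one patient with
# time_seq = [0, 120, 120, 36525] and event_seq = [2, 5, 7, 0] (last position is padding),
# position 0 attends to nothing (no strictly earlier time), positions 1 and 2 attend only
# to position 0 (equal times are excluded by the strict `<`), and the padding position is
# never attended to because event_seq == 0 there.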
class CombinedLoss(nn.Module):
    """
    Computes a two-part loss: a standard cross-entropy loss for event type
    prediction and a survival analysis loss for event timing.
    """

    def __init__(self, ignored_token_ids: list[int]):
        """
        Initializes the CombinedLoss module.

        Args:
            ignored_token_ids (list[int]): A list of event type IDs to be
                excluded from all loss calculations.
        """
        super().__init__()
        self.ignored_token_ids = ignored_token_ids

    def forward(self, logits: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates the combined cross-entropy and survival loss.

        Args:
            logits (torch.Tensor): Raw model outputs of shape (B, L, N).
            x (torch.Tensor): Ground-truth event labels of shape (B, L).
            t (torch.Tensor): True time duration for each event, shape (B, L).

        Returns:
            A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
        """
        # 1. Create a mask to filter out ignored token IDs from the loss calculation.
        # An element is True if the corresponding label in x is NOT in the ignored list.
        mask = torch.ones_like(x, dtype=torch.bool)
        for token_id in self.ignored_token_ids:
            mask = mask & (x != token_id)

        # If the mask is all False (all tokens are ignored), return zero for both losses.
        if not mask.any():
            return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)

        # 2. Part 1: Cross-Entropy Loss (loss_ce)
        # Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy.
        logits_for_ce = logits.permute(0, 2, 1)

        # Calculate per-element loss without reduction.
        per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')

        # Apply the mask and compute the mean of valid elements.
        loss_ce = per_element_ce[mask].mean()

        # 3. Part 2: Survival Loss (loss_survival)
        # Calculate event intensity (lambda) as the sum of exponentiated logits.
        intensity = torch.sum(torch.exp(logits), dim=2)

        # Calculate per-element survival loss (negative log-likelihood of the exponential distribution).
        # We add a small epsilon for numerical stability with the log.
        per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)

        # Apply the mask and compute the mean of valid elements.
        loss_survival = per_element_survival[mask].mean()

        return loss_ce, loss_survival
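# Note on the survival term (added comment, not in the original file): with intensity
# lambda = sum_k exp(logit_k), the per-position term -(log(lambda) - lambda * t) is the
# negative log-likelihood of an exponential waiting time t, so the model is trained both
# on which event comes next (cross-entropy) and on how long until it occurs.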
133  prepare_data.py  Normal file
@@ -0,0 +1,133 @@
import pandas as pd
import tqdm
import numpy as np

label_files = 'labels.csv'
ukb_field_to_icd10_file = 'icd10_codes_mod.tsv'
ukb_basket_file = 'ukb_delphi.txt'
train_proportion = 0.8
output_prefix = 'ukb_real'

icdict = {}
icdcodes = []
with open(ukb_field_to_icd10_file) as f:
    for line in f:
        parts = line.strip().split()
        icdict[parts[0]] = parts[5]
        icdcodes.append(parts[5])

# Using enumerate for cleaner, safer label assignment starting from 0
label_dict = {}
with open(label_files) as f:
    for i, line in enumerate(f):
        label_dict[line.strip().split(' ')[0]] = i

icdict['f.31.0.0'] = "sex"
icdict['f.34.0.0'] = "YEAR"
icdict['f.52.0.0'] = "MONTH"
icdict['f.40000.0.0'] = "Death"

for j in range(17):
    icdict[f'f.40005.{j}.0'] = f'cancer_date_{j}'
    icdict[f'f.40006.{j}.0'] = f'cancer_type_{j}'

icdict['f.53.0.0'] = "assessment_date"
icdict['f.21001.0.0'] = "BMI"
icdict['f.1239.0.0'] = "smoking"
icdict['f.1558.0.0'] = "alcohol"

len_icd = len(icdcodes)

# Corrected typo 'aseessment_date' to 'assessment_date'
icdcodes.extend(['Death', 'assessment_date'] + [f'cancer_date_{j}' for j in range(17)])

data_list = []
ukb_iterator = pd.read_csv(ukb_basket_file, sep=',', chunksize=10000, index_col=0, low_memory=False)

for _, dd in tqdm.tqdm(enumerate(ukb_iterator)):
    dd = dd.rename(columns=icdict)
    dd.dropna(subset=['sex'], inplace=True)
    dd['sex'] += 1
    dd = dd[[col for col in dd.columns if not col.startswith('f.')]]
    dd['dob'] = pd.to_datetime(dd[['YEAR', 'MONTH']].assign(DAY=1))

    present_icdcodes = [c for c in icdcodes if c in dd.columns]
    if present_icdcodes:
        # Convert date columns to days from date of birth
        date_cols = dd[present_icdcodes].apply(pd.to_datetime, format="%Y-%m-%d", errors='coerce')
        date_cols_days = date_cols.sub(dd['dob'], axis=0)
        dd[present_icdcodes] = date_cols_days.apply(lambda x: x.dt.days)

    # Process ICD codes efficiently using melt
    cols_to_process = [col for col in icdcodes[:len_icd + 1] if col in dd.columns]
    if cols_to_process:
        melted_df = dd.reset_index().melt(
            id_vars=['f.eid'],
            value_vars=cols_to_process,
            var_name='event_code',
            value_name='days'
        )
        melted_df.dropna(subset=['days'], inplace=True)
        if not melted_df.empty:
            melted_df['label'] = melted_df['event_code'].map(label_dict)
            data_list.append(melted_df[['f.eid', 'days', 'label']].dropna().astype(int).to_numpy())

    # Process sex
    X = dd['sex'].reset_index().to_numpy().astype(int)
    data_list.append(np.c_[X[:, 0], np.zeros(X.shape[0]), X[:, 1]].astype(int))

    # Process cancer data efficiently using wide_to_long
    df_res = dd.reset_index()
    rename_dict = {f'cancer_date_{j}': f'cancerdate{j}' for j in range(17)}
    rename_dict.update({f'cancer_type_{j}': f'cancertype{j}' for j in range(17)})
    df_renamed = df_res.rename(columns=rename_dict)

    stubs_to_use = []
    if any('cancerdate' in col for col in df_renamed.columns): stubs_to_use.append('cancerdate')
    if any('cancertype' in col for col in df_renamed.columns): stubs_to_use.append('cancertype')

    if len(stubs_to_use) == 2:
        long_cancer = pd.wide_to_long(df_renamed,
                                      stubnames=stubs_to_use,
                                      i=['f.eid'],
                                      j='cancer_num'
                                      ).dropna()
        if not long_cancer.empty:
            long_cancer['cancer'] = long_cancer['cancertype'].str.slice(0, 3)
            long_cancer['cancer_label'] = long_cancer['cancer'].map(label_dict)
            cancer_array = long_cancer.reset_index()[['f.eid', 'cancerdate', 'cancer_label']].dropna().astype(int).to_numpy()
            if cancer_array.size > 0:
                data_list.append(cancer_array)

    # Process BMI, smoking, and alcohol
    dd_bmi = dd[['assessment_date', 'BMI']].dropna().reset_index()
    if not dd_bmi.empty:
        dd_bmi['bmi_status'] = np.select([dd_bmi['BMI'] > 28, dd_bmi['BMI'] > 22], [5, 4], default=3)
        data_list.append(dd_bmi[['f.eid', 'assessment_date', 'bmi_status']].astype(int).to_numpy())

    dd_sm = dd[['assessment_date', 'smoking']].dropna().reset_index()
    dd_sm = dd_sm[dd_sm['smoking'] != -3]
    if not dd_sm.empty:
        dd_sm['smoking_status'] = np.select([dd_sm['smoking'] == 1, dd_sm['smoking'] == 2], [8, 7], default=6)
        data_list.append(dd_sm[['f.eid', 'assessment_date', 'smoking_status']].astype(int).to_numpy())

    dd_al = dd[['assessment_date', 'alcohol']].dropna().reset_index()
    dd_al = dd_al[dd_al['alcohol'] != -3]
    if not dd_al.empty:
        dd_al['alcohol_status'] = np.select([dd_al['alcohol'] == 1, dd_al['alcohol'] < 4], [11, 10], default=9)
        data_list.append(dd_al[['f.eid', 'assessment_date', 'alcohol_status']].astype(int).to_numpy())

data = np.vstack(data_list)
data = data[np.lexsort((data[:, 1], data[:, 2] == data[:, 2].max(), data[:, 0]))]
data = data[data[:, 1] >= 0]
data = pd.DataFrame(data).drop_duplicates([0, 2]).values
data = data.astype(np.uint32)
data.tofile(output_prefix + '.bin')

# Correctly split train/validation sets
unique_ids = np.unique(data[:, 0])
split_id = unique_ids[int(len(unique_ids) * train_proportion)]
train_val_split = data[:, 0] <= split_id

data[train_val_split].tofile(output_prefix + '_train.bin')
data[~train_val_split].tofile(output_prefix + '_val.bin')
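# Output format note (added comment, not in the original file): the .bin files written
# above are flat uint32 arrays with three values per record, (patient_id,
# days_since_birth, event_label); train.py therefore reads them back with
# np.memmap(path, dtype=np.uint32, mode='r').reshape(-1, 3).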
170  train.py  Normal file
@@ -0,0 +1,170 @@
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 256  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
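# Learning-rate schedule note (added comment, not in the original file): the training
# loop below holds lr at lr_initial (6e-4) for the first warmup_epochs epochs, then
# follows a cosine decay towards lr_final (6e-5); e.g. halfway through the decay
# (progress = 0.5) the rate is lr_final + 0.5 * (lr_initial - lr_final) = 3.3e-4.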
# --- Main Training Script ---
def main():
    config = TrainConfig()

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)

    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0
    print("Starting training...")
    for epoch in range(config.max_epoch):
        # --- Learning Rate Scheduling ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        for event_seq, time_seq in pbar:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1
            pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        avg_train_loss_ce = train_loss_ce_acc / train_steps
        avg_train_loss_surv = train_loss_surv_acc / train_steps

        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
            for event_seq, time_seq in pbar_val:
                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1
                pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

        avg_val_loss_ce = val_loss_ce_acc / val_steps
        avg_val_loss_surv = val_loss_surv_acc / val_steps
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        print(f"Epoch {epoch+1} Summary: \n"
              f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
              f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
              f"  Learning Rate: {lr:.6f}")

        # --- Early Stopping Check ---
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            patience_counter = 0
            print(f"Validation loss improved to {best_val_loss:.4f}. Resetting patience.")
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")

        if patience_counter >= config.early_stopping_patience:
            print("\nEarly stopping triggered due to no improvement in validation loss.")
            break


if __name__ == '__main__':
    main()
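# Input/target note (added comment, not in the original file): each batch is shifted by one
# position, so the model conditions on events and times up to step i and is scored on the
# event at step i+1 (cross-entropy) and on the waiting time time[i+1] - time[i] (survival
# term); padding (0) and "no event" (1) targets are excluded via ignored_token_ids.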
104  utils.py  Normal file
@@ -0,0 +1,104 @@
import torch
import numpy as np
import random
from collections import defaultdict


class PatientEventDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for handling temporal sequences of patient events.

    This class processes a raw NumPy array of patient records, groups them by
    patient ID, and prepares them for training by imputing gaps, padding, or
    truncating sequences to a fixed length.
    """

    def __init__(self, data: np.ndarray, block_length: int):
        """
        Initializes the dataset by pre-processing the patient event data.

        Args:
            data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32.
                The columns represent (patient_id, time_in_days, event_code).
            block_length (int): The fixed length for the output sequences.
        """
        self.block_length = block_length

        # Group (time_in_days, event_code) pairs by patient_id.
        # This pre-processing step allows for efficient lookups in __getitem__.
        patient_events = defaultdict(list)
        for patient_id, time, event in data:
            patient_events[patient_id].append((time, event))

        # Store a list of unique patient_ids to map indices to patients.
        self.patient_ids = list(patient_events.keys())
        self.patient_events = dict(patient_events)

    def __len__(self) -> int:
        """
        Returns the total number of unique patients in the dataset.
        """
        return len(self.patient_ids)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves, processes, and returns a single patient's event sequence.

        Args:
            idx (int): The index of the patient to retrieve.

        Returns:
            A tuple of two torch.long tensors: (event_sequence, time_sequence),
            both of shape (block_length,).
        """
        # 1. Retrieve and Sort
        patient_id = self.patient_ids[idx]
        records = sorted(self.patient_events[patient_id], key=lambda x: x[0])

        # 2. Impute "No Event" Gaps
        imputed_sequence = []
        if not records:
            # Handle cases with no records for a patient if necessary, though
            # the constructor logic would typically prevent this.
            pass
        else:
            imputed_sequence.append(records[0])
            for i in range(len(records) - 1):
                prev_time, _ = records[i]
                next_time, _ = records[i+1]
                time_gap = next_time - prev_time

                # If the gap is 5 years (1826 days) or more, insert "no event" records.
                if time_gap >= 1826:
                    num_no_event_intervals = time_gap // 1826
                    for j in range(1, num_no_event_intervals + 1):
                        no_event_time = prev_time + j * 1826
                        imputed_sequence.append((no_event_time, 1))  # event_code=1 for "no event"

                imputed_sequence.append(records[i+1])

        # 3. Adjust Sequence Length
        seq_len = len(imputed_sequence)

        if seq_len > self.block_length:
            # If longer, randomly select a contiguous sub-sequence.
            start_index = random.randint(0, seq_len - self.block_length)
            final_sequence = imputed_sequence[start_index : start_index + self.block_length]
        elif seq_len < self.block_length:
            # If shorter, pad the sequence at the end.
            padding_needed = self.block_length - seq_len
            # Use event_code=0 and time_in_days=36525 for padding.
            padding = [(36525, 0)] * padding_needed
            final_sequence = imputed_sequence + padding
        else:
            # If equal, use the sequence as is.
            final_sequence = imputed_sequence

        # 4. Return Tensors
        # Separate the sequence into event codes and time, then convert to tensors.
        event_codes = [item[1] for item in final_sequence]
        time_stamps = [item[0] for item in final_sequence]

        event_tensor = torch.tensor(event_codes, dtype=torch.long)
        time_tensor = torch.tensor(time_stamps, dtype=torch.long)

        return event_tensor, time_tensor
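# Imputation example (added comment, not in the original file): for records at days
# [0, 4000], the 4000-day gap (>= 1826) inserts "no event" entries (code 1) at days
# 1826 and 3652, giving [(0, e0), (1826, 1), (3652, 1), (4000, e1)] before the sequence
# is padded or truncated to block_length.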