Compare commits

...

32 Commits

Author SHA1 Message Date
e348086e52 update 2025-10-21 10:30:18 +08:00
a8aa5a2bd6 update 2025-10-21 09:20:43 +08:00
ddb7dbfc67 update 2025-10-20 16:22:15 +08:00
88cccdad2e feat: Optimize AUC evaluation with parallel processing 2025-10-20 16:16:50 +08:00
8f44018bae update evaluate 2025-10-20 13:47:50 +08:00
1c9e2a2fb3 feat: print model config and add evaluation notebook 2025-10-20 10:14:50 +08:00
6b782b86e1 feat: Add model checkpoints and configurations 2025-10-20 09:38:24 +08:00
9a9de170d1 delete 2025-10-18 22:35:42 +08:00
7e57e5d3b1 refactor: Update survival loss calculation in CombinedLoss 2025-10-18 15:21:10 +08:00
14865ac5b6 Refactor: Remove Jupyter Notebook cell markers 2025-10-18 13:32:26 +08:00
dbc3000192 add evaluation scripts. 2025-10-18 13:26:56 +08:00
082c719975 feat(models): Refactor generate function in TimeAwareGPT2 with competing risks sampling 2025-10-18 12:42:14 +08:00
a631ac6d59 feat: Add load_model function and update training script
Added a `load_model` function to `utils.py` to allow loading of trained models from configuration and state dictionary files.

The `train_iter.py` script was also modified, likely to incorporate or test this new functionality.
2025-10-18 11:07:59 +08:00
f7356b183c feat: Add command-line arguments to train.py 2025-10-18 10:23:12 +08:00
3390bc025e feat: Add iteration-based training scripts (single and multi-GPU) 2025-10-18 10:05:37 +08:00
a832a45c62 config: Tune hyperparameters for multi-GPU training
Increase model size (n_embd, n_layer, n_head) for the multi-GPU configuration.

Explicitly set AdamW betas to (0.9, 0.99).
2025-10-17 15:37:42 +08:00
d760c45baf feat: Add multi-GPU training and improve config/ignore
Add train_multigpu.py for distributed data parallel training.

Update train.py to save the training configuration to a JSON file.

Generalize .gitignore to exclude all *.pt checkpoint files.

Delete obsolete train_dpp.py file.
2025-10-17 14:09:34 +08:00
053f86f4da config: Add weight decay to training configuration
Adds a weight_decay parameter to the TrainConfig and applies it to the AdamW optimizer.
2025-10-17 13:47:37 +08:00
d4d25ac9c7 feat: Add covariate-aware model and piecewise encoder
Introduce PiecewiseLinearEncoder for continuous variable encoding.

Add CovariateAwareGPT2 to extend TimeAwareGPT2 with static and time-varying covariate processing.

The model combines piecewise linear and sinusoidal encodings for covariates and integrates them via concatenation before a final MLP head.

Reorganize models.py for better logical structure.
2025-10-17 12:04:50 +08:00
fe0304a96a feat: Save model with params in name and log losses 2025-10-17 10:44:17 +08:00
7e8d8d307b chore: Ignore small data files 2025-10-17 10:34:24 +08:00
fc0aef4e71 chore: Add .gitignore 2025-10-17 10:32:42 +08:00
02d84a7eca refactor: Use AdamW optimizer and increase early stopping patience 2025-10-17 10:31:12 +08:00
cb7575a229 feat: Update model and training parameters
In `models.py`:
- Change temporal attention mask to be strictly causal (`<` instead of `<=`).
- Add self-attention for the first token in a sequence to prevent NaNs.

In `train.py`:
- Update hyperparameters:
  - `block_length`: 24 -> 48
  - `n_embd`: 256 -> 120
  - `n_layer`: 8 -> 12
  - `n_head`: 8 -> 12
2025-10-16 18:50:15 +08:00
e2495f43b0 revert 6e0713048a
revert update attn mask
2025-10-16 18:37:55 +08:00
6e0713048a update attn mask 2025-10-16 18:29:48 +08:00
eec406d79f update ignored events. 2025-10-16 17:10:01 +08:00
e3e533c9ec update 2025-10-16 16:58:30 +08:00
b5172392cb update dpp 2025-10-16 16:46:33 +08:00
6b0b86d9d0 add Multi_GPU support. 2025-10-16 16:28:52 +08:00
c7296381b8 Revert "feat: adapt train.py to multi-GPU environment"
This reverts commit b7aad7a774.
2025-10-16 16:23:38 +08:00
2b20299e36 Revert "fix: average loss for multi-GPU training"
This reverts commit 85502561ee.
2025-10-16 16:23:35 +08:00
13 changed files with 3016 additions and 115 deletions

17
.gitignore vendored Normal file
View File

@@ -0,0 +1,17 @@
# IDE settings
.idea/
# Python cache
__pycache__/
# Model checkpoints
*.pt
# Large data files
ukb_delphi.txt
ukb_real.bin
# Small data files
fields.txt
icd10_codes_mod.tsv
labels.csv

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,18 @@
{
"n_layer": 12,
"n_embd": 120,
"n_head": 12,
"max_epoch": 200,
"batch_size": 128,
"lr_initial": 0.0006,
"lr_final": 6e-05,
"weight_decay": 0.2,
"warmup_epochs": 10,
"early_stopping_patience": 10,
"pdrop": 0.0,
"token_pdrop": 0.0,
"betas": [
0.9,
0.99
]
}

View File

@@ -0,0 +1,18 @@
{
"n_layer": 16,
"n_embd": 256,
"n_head": 16,
"max_epoch": 200,
"batch_size": 128,
"lr_initial": 0.0006,
"lr_final": 6e-05,
"weight_decay": 0.2,
"warmup_epochs": 10,
"early_stopping_patience": 10,
"pdrop": 0.0,
"token_pdrop": 0.0,
"betas": [
0.9,
0.99
]
}

File diff suppressed because it is too large Load Diff

499
evaluate_auc.py Normal file
View File

@@ -0,0 +1,499 @@
import scipy.stats
import scipy
import warnings
import torch
from models import TimeAwareGPT2
from tqdm import tqdm
import pandas as pd
import numpy as np
import argparse
from utils import load_model, get_batch, PatientEventDataset
from pathlib import Path
from joblib import Parallel, delayed
def auc(x1, x2):
n1 = len(x1)
n2 = len(x2)
R1 = np.concatenate([x1, x2]).argsort().argsort()[:n1].sum() + n1
U1 = n1 * n2 + 0.5 * n1 * (n1 + 1) - R1
if n1 == 0 or n2 == 0:
return np.nan
return U1 / n1 / n2
def get_common_diseases(delphi_labels, filter_min_total=100):
chapters_of_interest = [
"I. Infectious Diseases",
"II. Neoplasms",
"III. Blood & Immune Disorders",
"IV. Metabolic Diseases",
"V. Mental Disorders",
"VI. Nervous System Diseases",
"VII. Eye Diseases",
"VIII. Ear Diseases",
"IX. Circulatory Diseases",
"X. Respiratory Diseases",
"XI. Digestive Diseases",
"XII. Skin Diseases",
"XIII. Musculoskeletal Diseases",
"XIV. Genitourinary Diseases",
"XV. Pregnancy & Childbirth",
"XVI. Perinatal Conditions",
"XVII. Congenital Abnormalities",
"Death",
]
labels_df = delphi_labels[
delphi_labels["ICD-10 Chapter (short)"].isin(chapters_of_interest) * (delphi_labels["count"] > filter_min_total)
]
return labels_df["index"].tolist()
def optimized_bootstrapped_auc_gpu(case, control, n_bootstrap=1000):
"""
Computes bootstrapped AUC estimates using PyTorch on CUDA.
Parameters:
case: 1D tensor of scores for positive cases
control: 1D tensor of scores for controls
n_bootstrap: Number of bootstrap replicates
Returns:
Tensor of shape (n_bootstrap,) containing AUC for each bootstrap replicate
"""
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available. This function requires a GPU.")
# Convert inputs to CUDA tensors
if not torch.is_tensor(case):
case = torch.tensor(case, device="cuda", dtype=torch.float32)
else:
case = case.to("cuda", dtype=torch.float32)
if not torch.is_tensor(control):
control = torch.tensor(control, device="cuda", dtype=torch.float32)
else:
control = control.to("cuda", dtype=torch.float32)
n_case = case.size(0)
n_control = control.size(0)
total = n_case + n_control
# Generate bootstrap samples
boot_idx_case = torch.randint(0, n_case, (n_bootstrap, n_case), device="cuda")
boot_idx_control = torch.randint(0, n_control, (n_bootstrap, n_control), device="cuda")
boot_case = case[boot_idx_case]
boot_control = control[boot_idx_control]
combined = torch.cat([boot_case, boot_control], dim=1)
# Mask to identify case entries
mask = torch.zeros((n_bootstrap, total), dtype=torch.bool, device="cuda")
mask[:, :n_case] = True
# Compute ranks and AUC
ranks = combined.argsort(dim=1).argsort(dim=1)
case_ranks_sum = torch.sum(ranks.float() * mask.float(), dim=1)
min_case_rank_sum = n_case * (n_case - 1) / 2.0
U = case_ranks_sum - min_case_rank_sum
aucs = U / (n_case * n_control)
return aucs.cpu().tolist()
# AUC comparison adapted from
# https://github.com/Netflix/vmaf/
def compute_midrank(x):
"""Computes midranks.
Args:
x - a 1D numpy array
Returns:
array of midranks
"""
J = np.argsort(x)
Z = x[J]
N = len(x)
T = np.zeros(N, dtype=np.float32)
i = 0
while i < N:
j = i
while j < N and Z[j] == Z[i]:
j += 1
T[i:j] = 0.5 * (i + j - 1)
i = j
T2 = np.empty(N, dtype=np.float32)
# Note(kazeevn) +1 is due to Python using 0-based indexing
# instead of 1-based in the AUC formula in the paper
T2[J] = T + 1
return T2
def fastDeLong(predictions_sorted_transposed, label_1_count):
"""
The fast version of DeLong's method for computing the covariance of
unadjusted AUC.
Args:
predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
sorted such as the examples with label "1" are first
Returns:
(AUC value, DeLong covariance)
Reference:
@article{sun2014fast,
title={Fast Implementation of DeLong's Algorithm for
Comparing the Areas Under Correlated Receiver Operating Characteristic Curves},
author={Xu Sun and Weichao Xu},
journal={IEEE Signal Processing Letters},
volume={21},
number={11},
pages={1389--1393},
year={2014},
publisher={IEEE}
}
"""
# Short variables are named as they are in the paper
m = label_1_count
n = predictions_sorted_transposed.shape[1] - m
positive_examples = predictions_sorted_transposed[:, :m]
negative_examples = predictions_sorted_transposed[:, m:]
k = predictions_sorted_transposed.shape[0]
tx = np.empty([k, m], dtype=np.float32)
ty = np.empty([k, n], dtype=np.float32)
tz = np.empty([k, m + n], dtype=np.float32)
for r in range(k):
tx[r, :] = compute_midrank(positive_examples[r, :])
ty[r, :] = compute_midrank(negative_examples[r, :])
tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :])
aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n
v01 = (tz[:, :m] - tx[:, :]) / n
v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m
sx = np.cov(v01)
sy = np.cov(v10)
delongcov = sx / m + sy / n
return aucs, delongcov
def compute_ground_truth_statistics(ground_truth):
assert np.array_equal(np.unique(ground_truth), [0, 1])
order = (-ground_truth).argsort()
label_1_count = int(ground_truth.sum())
return order, label_1_count
def get_auc_delong_var(healthy_scores, diseased_scores):
"""
Computes ROC AUC value and variance using DeLong's method
Args:
healthy_scores: Values for class 0 (healthy/controls)
diseased_scores: Values for class 1 (diseased/cases)
Returns:
AUC value and variance
"""
# Create ground truth labels (1 for diseased, 0 for healthy)
ground_truth = np.array([1] * len(diseased_scores) + [0] * len(healthy_scores))
predictions = np.concatenate([diseased_scores, healthy_scores])
# Compute statistics needed for DeLong method
order, label_1_count = compute_ground_truth_statistics(ground_truth)
predictions_sorted_transposed = predictions[np.newaxis, order]
# Calculate AUC and covariance
aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count)
assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers"
return aucs[0], delongcov
def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5), precomputed_idx=None, n_bootstrap=1, use_delong=False):
age_step = age_groups[1] - age_groups[0]
# Indexes of cases with disease k
wk = np.where(d[2] == k)
if len(wk[0]) < 2:
return None
# For controls, we need to exclude cases with disease k
wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None])
wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]])) # All cases and controls
# We need to take into account the offset t and use the tokens for prediction that are at least t before the event
if precomputed_idx is None:
pred_idx = (d[1][wall[0]] <= d[3][wall].reshape(-1, 1) - offset).sum(1) - 1
else:
pred_idx = precomputed_idx[wall] # It's actually much faster to precompute this
valid_indices = pred_idx != -1
pred_idx = pred_idx[valid_indices]
wall = (wall[0][valid_indices], wall[1][valid_indices])
z = d[1][(wall[0], pred_idx)] # Times of the tokens for prediction
zk = d[3][wall] # Target times
x = p[..., j][(wall[0], pred_idx)]
p_idx = wall[0]
out = []
for i, aa in enumerate(age_groups):
a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step)
if not np.any(a):
continue
selected_groups = p_idx[a]
_, unique_indices = np.unique(selected_groups, return_index=True)
a_filtered = a[a]
a_filtered[:] = False
a_filtered[unique_indices] = True
a[a] = a_filtered
is_case = np.zeros_like(x, dtype=bool)
is_case[:len(wk[0])] = True
control = x[~is_case & a]
case = x[is_case & a]
if len(control) == 0 or len(case) == 0:
continue
if use_delong:
auc_value_delong, auc_variance_delong = get_auc_delong_var(control, case)
auc_delong_dict = {"auc_delong": auc_value_delong, "auc_variance_delong": auc_variance_delong}
else:
auc_delong_dict = {}
if n_bootstrap > 1:
aucs_bootstrapped = optimized_bootstrapped_auc_gpu(case, control, n_bootstrap)
for bootstrap_idx in range(n_bootstrap):
y = auc_value_delong if n_bootstrap == 1 else aucs_bootstrapped[bootstrap_idx]
out_item = {
"token": k,
"auc": y,
"age": aa,
"n_healthy": len(control),
"n_diseased": len(case),
}
out.append(out_item | auc_delong_dict)
if n_bootstrap > 1:
out_item["bootstrap_idx"] = bootstrap_idx
return out
def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap):
all_aucs = []
for sex, sex_idx in [("female", 2), ("male", 3)]:
sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
p_sex = p100k[sex_mask]
d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k]
precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
for j, k in tqdm(
list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
):
out = get_calibration_auc(
j,
k,
d100k_sex,
p_sex,
age_groups=age_groups,
offset=offset,
precomputed_idx=precomputed_idx_subset,
n_bootstrap=n_bootstrap,
use_delong=True,
)
if out is None:
continue
for out_item in out:
out_item["sex"] = sex
all_aucs.append(out_item)
return all_aucs
# New internal function that performs the AUC evaluation pipeline.
def evaluate_auc_pipeline(
model,
d100k,
output_path,
delphi_labels,
diseases_of_interest=None,
filter_min_total=100,
disease_chunk_size=200,
age_groups=np.arange(40, 80, 5),
offset=0.1,
batch_size=256,
device="cpu",
seed=1337,
n_bootstrap=1,
meta_info={},
n_jobs=-1,
):
"""
Runs the AUC evaluation pipeline.
Args:
model (torch.nn.Module): The loaded model set to eval().
d100k (tuple): Data batch from get_batch.
delphi_labels (pd.DataFrame): DataFrame with label info (token names, etc. "delphi_labels_chapters_colours_icd.csv").
output_path (str | None): Directory where CSV files will be written. If None, files will not be saved.
diseases_of_interest (np.ndarray or list, optional): If provided, these disease indices are used.
filter_min_total (int): Minimum total token count to include a token.
disease_chunk_size (int): Maximum chunk size for processing diseases.
age_groups (np.ndarray): Age groups to use in calibration.
offset (float): Offset used in get_calibration_auc.
batch_size (int): Batch size for model forwarding.
device (str): Device identifier.
seed (int): Random seed for reproducibility.
n_bootstrap (int): Number of bootstrap samples. (1 for no bootstrap)
n_jobs (int): Number of parallel jobs to run.
Returns:
tuple: (df_auc_unpooled, df_auc, df_both) DataFrames.
"""
assert n_bootstrap > 0, "n_bootstrap must be greater than 0"
# Set random seeds
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Get common diseases
if diseases_of_interest is None:
diseases_of_interest = get_common_diseases(delphi_labels, filter_min_total)
# Split diseases into chunks for processing
num_chunks = (len(diseases_of_interest) + disease_chunk_size - 1) // disease_chunk_size
diseases_chunks = np.array_split(diseases_of_interest, num_chunks)
# Precompute prediction indices for calibration
pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1
p100k = []
model.to(device)
with torch.no_grad():
for dd in tqdm(
zip(*[torch.split(x, batch_size) for x in d100k]),
desc=f"Model inference",
total=d100k[0].shape[0] // batch_size + 1,
):
dd = [x.to(device) for x in dd]
outputs = model(dd[0], dd[1]).cpu().detach().numpy()
p100k.append(outputs.astype("float16"))
p100k = np.vstack(p100k)
results = Parallel(n_jobs=n_jobs)(
delayed(process_chunk)(
disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk], pred_idx_precompute, age_groups, offset, n_bootstrap
)
for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks)
)
all_aucs = [item for sublist in results for item in sublist]
df_auc_unpooled = pd.DataFrame(all_aucs)
for key, value in meta_info.items():
df_auc_unpooled[key] = value
delphi_labels_subset = delphi_labels[['index', 'ICD-10 Chapter (short)', 'name', 'color', 'count']]
df_auc_unpooled_merged = df_auc_unpooled.merge(delphi_labels_subset, left_on="token", right_on="index", how="inner")
def aggregate_age_brackets_delong(group):
# For normal distributions, when averaging n of them:
# The variance of the sum is the sum of variances
# The variance of the average is the sum of variances divided by n^2
n = len(group)
mean = group['auc_delong'].mean()
# Since we're taking the average, divide combined variance by n^2
var = group['auc_variance_delong'].sum() / (n**2)
return pd.Series({
'auc': mean,
'auc_variance_delong': var,
'n_samples': n,
'n_diseased': group['n_diseased'].sum(),
'n_healthy': group['n_healthy'].sum(),
})
print('Using DeLong method to calculate AUC confidence intervals..')
df_auc = df_auc_unpooled.groupby(["token"]).apply(aggregate_age_brackets_delong).reset_index()
df_auc_merged = df_auc.merge(delphi_labels, left_on="token", right_on="index", how="inner")
if output_path is not None:
Path(output_path).mkdir(exist_ok=True, parents=True)
df_auc_merged.to_csv(f"{output_path}/df_both.csv", index=False)
df_auc_unpooled_merged.to_csv(f"{output_path}/df_auc_unpooled.csv", index=False)
return df_auc_unpooled_merged, df_auc_merged
def main():
parser = argparse.ArgumentParser(description="Evaluate AUC")
parser.add_argument("--model_name", type=str, default="n_embd_256_n_layer_16_n_head_16", help="Model checkpoint name")
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
parser.add_argument("--offset", type=float, default=365.25, help="Offset in days for prediction")
# Optional filtering/chunking parameters:
parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens")
parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
args = parser.parse_args()
model_name = args.model_name
output_path = f'auc_evaluation_{model_name}'
dataset_subset_size = args.dataset_subset_size
offset = args.offset
# Create output folder if it doesn't exist.
Path(output_path).mkdir(exist_ok=True, parents=True)
device = "cuda"
seed = 1337
# Load model checkpoint and initialize model.
model = load_model(f'config_{model_name}.json',
f'best_model_{model_name}.pt',
1270)
model.eval()
model = model.to(device)
# Load training and validation data.
val_data_path = 'ukb_real_val.bin'
val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
block_length = 128
val_dataset = PatientEventDataset(val_data_arr, block_length)
if dataset_subset_size == -1:
dataset_subset_size = len(val_dataset)
# Get a subset batch for evaluation.
d100k = get_batch(val_dataset, slice(dataset_subset_size))
# Load labels (external) to be passed in.
delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")
# Call the internal evaluation function.
df_auc_unpooled, df_auc_merged = evaluate_auc_pipeline(
model,
d100k,
output_path,
delphi_labels,
diseases_of_interest=None,
filter_min_total=args.filter_min_total,
disease_chunk_size=args.disease_chunk_size,
device=device,
seed=seed,
offset=offset,
n_bootstrap=args.n_bootstrap,
n_jobs=args.n_jobs,
)
if __name__ == "__main__":
main()

552
evaluate_models.ipynb Normal file

File diff suppressed because one or more lines are too long

326
models.py
View File

@@ -3,6 +3,10 @@ import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from typing import Tuple from typing import Tuple
# =============================================================================
# 1. Component Modules (Building Blocks)
# =============================================================================
class Block(nn.Module): class Block(nn.Module):
""" an unassuming Transformer block """ """ an unassuming Transformer block """
@@ -58,14 +62,8 @@ class AgeSinusoidalEncoding(nn.Module):
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
# Pre-calculate the divisor term for the sinusoidal formula. # Pre-calculate the divisor term for the sinusoidal formula.
# The formula for the divisor is 10000^(2i/D), where D is the
# embedding_dim and i is the index for each pair of dimensions.
# i ranges from 0 to D/2 - 1.
i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
divisor = torch.pow(10000, i / self.embedding_dim) divisor = torch.pow(10000, i / self.embedding_dim)
# Register the divisor as a non-trainable buffer. This ensures it is
# moved to the correct device (e.g., GPU) along with the model.
self.register_buffer('divisor', divisor) self.register_buffer('divisor', divisor)
def forward(self, t: torch.Tensor) -> torch.Tensor: def forward(self, t: torch.Tensor) -> torch.Tensor:
@@ -80,49 +78,116 @@ class AgeSinusoidalEncoding(nn.Module):
torch.Tensor: The encoded age tensor of shape torch.Tensor: The encoded age tensor of shape
(batch_size, sequence_length, embedding_dim). (batch_size, sequence_length, embedding_dim).
""" """
# 1. Unit Conversion: Convert age from days to years.
# We use 365.25 to account for leap years.
t_years = t / 365.25 t_years = t / 365.25
# 2. Argument Calculation: Calculate the arguments for the sin/cos functions.
# The shapes are broadcast to (B, L, D/2).
# Input t_years: (B, L) -> unsqueezed to (B, L, 1)
# Divisor: (D/2) -> viewed as (1, 1, D/2)
args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1) args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1)
# 3. Sinusoidal Application: Create the final output tensor.
# Initialize an empty tensor to store the embeddings.
output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device) output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)
# Assign cosine of the arguments to the even indices.
output[:, :, 0::2] = torch.cos(args) output[:, :, 0::2] = torch.cos(args)
# Assign sine of the arguments to the odd indices.
output[:, :, 1::2] = torch.sin(args) output[:, :, 1::2] = torch.sin(args)
return output return output
class PiecewiseLinearEncoder(nn.Module):
"""
Encodes continuous variables using piecewise linear encoding.
This module defines bins based on standard normal distribution quantiles,
encodes an input by finding its bin, and calculates its position as a
linear interpolation between boundaries. The result is projected to the
final embedding dimension by a shared linear layer.
"""
def __init__(self, num_bins: int, embedding_dim: int):
"""
Initializes the PiecewiseLinearEncoder module.
Args:
num_bins (int): The number of bins for the encoding.
embedding_dim (int): The dimensionality of the output embedding (D).
"""
super().__init__()
if num_bins <= 0:
raise ValueError("num_bins must be a positive integer.")
self.num_bins = num_bins
self.embedding_dim = embedding_dim
if num_bins > 1:
quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1)
normal_dist = torch.distributions.normal.Normal(0, 1)
boundaries = normal_dist.icdf(quantiles)
else:
boundaries = torch.tensor([])
boundaries = torch.cat([
torch.tensor([float('-inf')]),
boundaries,
torch.tensor([float('inf')])
])
self.register_buffer('boundaries', boundaries)
self.linear = nn.Linear(num_bins, embedding_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the piecewise linear encoding.
Args:
x (torch.Tensor): Input tensor of shape (*, N), where * is any
number of batch dimensions and N is the number of continuous
features. Assumed to be pre-scaled.
Returns:
torch.Tensor: Encoded tensor of shape (*, N, D).
"""
original_shape = x.shape
x = x.reshape(-1, original_shape[-1])
bin_indices = torch.searchsorted(self.boundaries, x, right=True) - 1
bin_indices = bin_indices.clamp(0, self.num_bins - 1)
lower_bounds = self.boundaries[bin_indices]
upper_bounds = self.boundaries[bin_indices + 1]
delta = upper_bounds - lower_bounds + 1e-8
weight_upper = (x - lower_bounds) / delta
weight_lower = 1.0 - weight_upper
is_first_bin = (bin_indices == 0)
is_last_bin = (bin_indices == self.num_bins - 1)
weight_lower[is_first_bin] = 1.0
weight_upper[is_first_bin] = 0.0
weight_lower[is_last_bin] = 0.0
weight_upper[is_last_bin] = 1.0
encoded = torch.zeros(*x.shape, self.num_bins, device=x.device, dtype=x.dtype)
encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1))
upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1)
encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1))
encoded = encoded.view(*original_shape, self.num_bins)
output = self.linear(encoded)
return output
# =============================================================================
# 2. Main Model Architectures
# =============================================================================
class TimeAwareGPT2(nn.Module): class TimeAwareGPT2(nn.Module):
""" """
A time-aware GPT-2 model with custom temporal features. A time-aware GPT-2 model with custom temporal features.
""" """
def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float): def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None):
super().__init__() super().__init__()
self.token_pdrop = token_pdrop self.token_pdrop = token_pdrop
self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
# Token and positional embeddings
self.wte = nn.Embedding(vocab_size, n_embd) self.wte = nn.Embedding(vocab_size, n_embd)
self.age_encoder = AgeSinusoidalEncoding(n_embd) self.age_encoder = AgeSinusoidalEncoding(n_embd)
self.drop = nn.Dropout(pdrop) self.drop = nn.Dropout(pdrop)
# Transformer blocks
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)]) self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
# Final layer norm and linear head
self.ln_f = nn.LayerNorm(n_embd) self.ln_f = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False) self.head = nn.Linear(n_embd, vocab_size, bias=False)
self.n_embd = n_embd self.n_embd = n_embd
def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor: def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
@@ -138,46 +203,30 @@ class TimeAwareGPT2(nn.Module):
""" """
B, L = event_seq.size() B, L = event_seq.size()
# 1. Get token embeddings
token_embeddings = self.wte(event_seq) token_embeddings = self.wte(event_seq)
# 2. Apply token dropout (only during training)
if self.training and self.token_pdrop > 0: if self.training and self.token_pdrop > 0:
# Create a mask to randomly zero out entire token embedding vectors
drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
token_embeddings[drop_mask] = 0.0 token_embeddings[drop_mask] = 0.0
# 3. Get positional embeddings from time sequence
pos_embeddings = self.age_encoder(time_seq.float()) pos_embeddings = self.age_encoder(time_seq.float())
# 4. Combine embeddings and apply dropout
x = self.drop(token_embeddings + pos_embeddings) x = self.drop(token_embeddings + pos_embeddings)
# 5. Generate attention mask t_i = time_seq.unsqueeze(-1)
# The attention mask combines two conditions: t_j = time_seq.unsqueeze(1)
# a) Time-based causality: A token i can attend to a token j only if time_seq[j] <= time_seq[i]. time_mask = (t_j < t_i)
# b) Padding mask: Do not attend to positions where the event token is 0. padding_mask = (event_seq != 0).unsqueeze(1)
# a) Time-based causal mask
t_i = time_seq.unsqueeze(-1) # (B, L, 1)
t_j = time_seq.unsqueeze(1) # (B, 1, L)
time_mask = (t_j <= t_i)
# b) Padding mask (prevents attending to key positions that are padding)
padding_mask = (event_seq != 0).unsqueeze(1) # Shape: (B, 1, L)
# Combine the masks. A position (j) can be attended to by a query (i) only if
# it's in the past (time_mask) AND it's not a padding token (padding_mask).
combined_mask = time_mask & padding_mask combined_mask = time_mask & padding_mask
# 6. Pass through transformer blocks is_row_all_zero = ~combined_mask.any(dim=-1)
is_not_padding = (event_seq != 0)
force_self_attention = is_row_all_zero & is_not_padding
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
for block in self.blocks: for block in self.blocks:
x = block(x, custom_mask=combined_mask) x = block(x, custom_mask=combined_mask)
# 7. Final layer norm and projection to vocab size
x = self.ln_f(x) x = self.ln_f(x)
logits = self.head(x) logits = self.head(x)
return logits return logits
def get_num_params(self) -> float: def get_num_params(self) -> float:
@@ -186,6 +235,151 @@ class TimeAwareGPT2(nn.Module):
""" """
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
@torch.no_grad()
def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
"""
Take a conditioning sequence of indices x (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
self.eval()
if termination_tokens is None:
termination_tokens = [1269]
termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
mask_time = -10000
for _ in range(max_new_tokens):
logits = self(x, t)
logits = logits[:, -1, :]
if self.ignore_tokens:
logits[:, self.ignore_tokens] = -torch.inf
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
logits = logits.scatter(1, fill, -torch.inf)
t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
t_next_val, idx_next = t_next_dist.min(1)
idx_next = idx_next.unsqueeze(1)
age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
x = torch.cat((x, idx_next), dim=1)
t = torch.cat((t, age_next), dim=1)
if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
break
pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
final_logits = self(x, t)
x[pad] = 0
t[pad] = mask_time
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
final_logits = torch.stack([final_logits[:,j].scatter(1, fill[:,:j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0,1)
return x, t, final_logits
class CovariateAwareGPT2(nn.Module):
"""
Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
"""
def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int,
pdrop: float, token_pdrop: float, num_bins: int):
"""
Initializes the CovariateAwareGPT2 model.
Args:
vocab_size (int): Size of the event vocabulary.
n_embd (int): Embedding dimensionality.
n_layer (int): Number of transformer layers.
n_head (int): Number of attention heads.
pdrop (float): Dropout probability for layers.
token_pdrop (float): Dropout probability for input token embeddings.
num_bins (int): Number of bins for the PiecewiseLinearEncoder.
"""
super().__init__()
self.token_pdrop = token_pdrop
self.wte = nn.Embedding(vocab_size, n_embd)
self.age_encoder = AgeSinusoidalEncoding(n_embd)
self.drop = nn.Dropout(pdrop)
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
self.n_embd = n_embd
self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)
self.ln_f = nn.LayerNorm(2 * n_embd)
self.head = nn.Sequential(
nn.Linear(2 * n_embd, n_embd),
nn.GELU(),
nn.Linear(n_embd, vocab_size)
)
def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the CovariateAwareGPT2 model.
Args:
x (torch.Tensor): Event sequence tensor of shape (B, L).
t (torch.Tensor): Time sequence tensor of shape (B, L).
cov (torch.Tensor): Covariate tensor of shape (B, N).
cov_t (torch.Tensor): Covariate time tensor of shape (B).
Returns:
torch.Tensor: Logits of shape (B, L, vocab_size).
"""
B, L = x.size()
cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1))
cov_embed = cov_encoded + cov_t_encoded
token_embeddings = self.wte(x)
if self.training and self.token_pdrop > 0:
drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
token_embeddings[drop_mask] = 0.0
pos_embeddings = self.age_encoder(t.float())
seq_embed = self.drop(token_embeddings + pos_embeddings)
t_i = t.unsqueeze(-1)
t_j = t.unsqueeze(1)
time_mask = (t_j < t_i)
padding_mask = (x != 0).unsqueeze(1)
combined_mask = time_mask & padding_mask
is_row_all_zero = ~combined_mask.any(dim=-1)
is_not_padding = (x != 0)
force_self_attention = is_row_all_zero & is_not_padding
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
block_output = seq_embed
for block in self.blocks:
block_output = block(block_output, custom_mask=combined_mask)
integrated_embed = torch.cat([block_output, cov_embed], dim=-1)
final_output = self.ln_f(integrated_embed)
logits = self.head(final_output)
return logits
def get_num_params(self) -> float:
"""
Returns the number of trainable parameters in the model in millions.
"""
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
# =============================================================================
# 3. Loss Function
# =============================================================================
class CombinedLoss(nn.Module): class CombinedLoss(nn.Module):
""" """
Computes a two-part loss: a standard cross-entropy loss for event type Computes a two-part loss: a standard cross-entropy loss for event type
@@ -215,35 +409,23 @@ class CombinedLoss(nn.Module):
Returns: Returns:
A tuple containing the two scalar loss tensors: (loss_ce, loss_survival). A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
""" """
# 1. Create a mask to filter out ignored token IDs from loss calculation.
# An element is True if the corresponding label in x is NOT in the ignored list.
mask = torch.ones_like(x, dtype=torch.bool) mask = torch.ones_like(x, dtype=torch.bool)
for token_id in self.ignored_token_ids: for token_id in self.ignored_token_ids:
mask = mask & (x != token_id) mask = mask & (x != token_id)
# If the mask is all False (all tokens are ignored), return zero for both losses.
if not mask.any(): if not mask.any():
return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device) return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)
# 2. Part 1: Cross-Entropy Loss (loss_ce)
# Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy.
logits_for_ce = logits.permute(0, 2, 1) logits_for_ce = logits.permute(0, 2, 1)
# Calculate per-element loss without reduction.
per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
# Apply the mask and compute the mean of valid elements.
loss_ce = per_element_ce[mask].mean() loss_ce = per_element_ce[mask].mean()
# 3. Part 2: Survival Loss (loss_survival) # Survival loss based on exponential log-likelihood
# Calculate event intensity (lambda) as the sum of exponentiated logits. t_min = 0.1
intensity = torch.sum(torch.exp(logits), dim=2) lse = torch.logsumexp(logits, dim=-1)
lse = -torch.log(torch.exp(-lse) + t_min)
# Calculate per-element survival loss (negative log-likelihood of exponential dist). ldt = -torch.log(t + t_min)
# We add a small epsilon for numerical stability with the log. loss_dt = -(lse - torch.exp(lse - ldt))
per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t) loss_survival = loss_dt[mask].mean()
# Apply the mask and compute the mean of valid elements.
loss_survival = per_element_survival[mask].mean()
return loss_ce, loss_survival return loss_ce, loss_survival

5
requirements.txt Normal file
View File

@@ -0,0 +1,5 @@
torch
numpy
tqdm
matplotlib
joblib

112
train.py
View File

@@ -1,11 +1,13 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.optim import Adam from torch.optim import AdamW
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import numpy as np import numpy as np
import math import math
import tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import json
import argparse
from models import TimeAwareGPT2, CombinedLoss from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset from utils import PatientEventDataset
@@ -15,12 +17,12 @@ class TrainConfig:
# Data parameters # Data parameters
train_data_path = 'ukb_real_train.bin' train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin' val_data_path = 'ukb_real_val.bin'
block_length = 24 # Sequence length block_length = 48 # Sequence length
# Model parameters # Model parameters
n_embd = 256 n_embd = 120
n_layer = 8 n_layer = 12
n_head = 8 n_head = 12
pdrop = 0.1 pdrop = 0.1
token_pdrop = 0.1 token_pdrop = 0.1
@@ -29,20 +31,62 @@ class TrainConfig:
batch_size = 128 batch_size = 128
lr_initial = 6e-4 lr_initial = 6e-4
lr_final = 6e-5 lr_final = 6e-5
weight_decay = 2e-1
warmup_epochs = 10 warmup_epochs = 10
early_stopping_patience = 5 early_stopping_patience = 10
betas = (0.9, 0.99)
# Loss parameters # Loss parameters
# 0 = padding, 1 = "no event" # 0 = padding, 1 = "no event"
ignored_token_ids = [0, 1] ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs
# System parameters # System parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Main Training Script --- # --- Main Training Script ---
def main(): def main():
parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
args = parser.parse_args()
config = TrainConfig() config = TrainConfig()
device = torch.device(config.device) config.n_layer = args.n_layer
config.n_embd = args.n_embd
config.n_head = args.n_head
config.max_epoch = args.max_epoch
config.batch_size = args.batch_size
config.lr_initial = args.lr_initial
config.lr_final = args.lr_final
config.weight_decay = args.weight_decay
config.warmup_epochs = args.warmup_epochs
config.early_stopping_patience = args.early_stopping_patience
config.pdrop = args.pdrop
config.token_pdrop = args.token_pdrop
config.betas = tuple(args.betas)
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
# --- 0. Save Configuration ---
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
with open(config_filename, 'w') as f:
json.dump(config_dict, f, indent=4)
print(f"Configuration saved to {config_filename}")
# --- 1. Data Loading --- # --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
@@ -60,7 +104,7 @@ def main():
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True) val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
# --- 2. Model, Optimizer, and Loss Initialization --- # --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {device}...") print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2( model = TimeAwareGPT2(
vocab_size=vocab_size, vocab_size=vocab_size,
n_embd=config.n_embd, n_embd=config.n_embd,
@@ -68,19 +112,12 @@ def main():
n_head=config.n_head, n_head=config.n_head,
pdrop=config.pdrop, pdrop=config.pdrop,
token_pdrop=config.token_pdrop token_pdrop=config.token_pdrop
) ).to(config.device)
# --- Multi-GPU Support --- print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
if torch.cuda.device_count() > 1:
print(f"Using {torch.cuda.device_count()} GPUs!")
model = nn.DataParallel(model)
model.to(device)
print(f"Model initialized with {model.module.get_num_params() if isinstance(model, nn.DataParallel) else model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids) loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = Adam(model.parameters(), lr=config.lr_initial) optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)
# --- 3. Training Loop --- # --- 3. Training Loop ---
best_val_loss = float('inf') best_val_loss = float('inf')
@@ -109,7 +146,7 @@ def main():
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]") pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
for event_seq, time_seq in pbar: for event_seq, time_seq in pbar:
event_seq, time_seq = event_seq.to(device), time_seq.to(device) event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
# Prepare inputs and targets # Prepare inputs and targets
input_events = event_seq[:, :-1] input_events = event_seq[:, :-1]
@@ -122,11 +159,6 @@ def main():
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival loss = loss_ce + loss_survival
# When using DataParallel, loss is a vector of losses from each GPU.
# We need to average them to get a single scalar loss.
if isinstance(model, nn.DataParallel):
loss = loss.mean()
# Backward pass and optimization # Backward pass and optimization
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
@@ -151,7 +183,7 @@ def main():
with torch.no_grad(): with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]") pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
for event_seq, time_seq in pbar_val: for event_seq, time_seq in pbar_val:
event_seq, time_seq = event_seq.to(device), time_seq.to(device) event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1] input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1] input_times = time_seq[:, :-1]
@@ -173,9 +205,9 @@ def main():
val_losses_surv.append(avg_val_loss_surv) val_losses_surv.append(avg_val_loss_surv)
val_losses_total.append(total_val_loss) val_losses_total.append(total_val_loss)
print(f"Epoch {epoch+1} Summary: \n" print(f"Epoch {epoch+1} Summary: \n"
f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n" f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n" f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
f" Learning Rate: {lr:.6f}") f" Learning Rate: {lr:.6f}")
# --- Early Stopping Check --- # --- Early Stopping Check ---
@@ -183,9 +215,7 @@ def main():
best_val_loss = total_val_loss best_val_loss = total_val_loss
patience_counter = 0 patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...") print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
# Save the underlying model state_dict when using DataParallel torch.save(model.state_dict(), checkpoint_filename)
model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
torch.save(model_state, 'best_model_checkpoint.pt')
else: else:
if epoch >= config.warmup_epochs: if epoch >= config.warmup_epochs:
patience_counter += 1 patience_counter += 1
@@ -198,14 +228,20 @@ def main():
# --- Save Best Model at the End --- # --- Save Best Model at the End ---
if best_val_loss != float('inf'): if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
# Load the state dict into the base model, not the DataParallel wrapper model.load_state_dict(torch.load(checkpoint_filename))
base_model = model.module if isinstance(model, nn.DataParallel) else model print(f"Saving final best model to {model_filename}")
base_model.load_state_dict(torch.load('best_model_checkpoint.pt')) torch.save(model.state_dict(), model_filename)
print("Saving final best model to best_model.pt")
torch.save(base_model.state_dict(), 'best_model.pt')
else: else:
print("\nTraining finished. No best model to save as validation loss never improved.") print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Save losses to a txt file ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
with open(losses_filename, 'w') as f:
f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# --- Plot and Save Loss Curves --- # --- Plot and Save Loss Curves ---
num_epochs = len(train_losses_total) num_epochs = len(train_losses_total)
epochs = range(1, num_epochs + 1) epochs = range(1, num_epochs + 1)
@@ -248,4 +284,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()

218
train_iter.py Normal file
View File

@@ -0,0 +1,218 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 48 # Sequence length
# Model parameters
n_embd = 120
n_layer = 12
n_head = 12
pdrop = 0.0
token_pdrop = 0.0
# Training parameters
max_iter = 200000
batch_size = 128
lr_initial = 6e-4
lr_final = 6e-5
weight_decay = 2e-1
warmup_iter = 1000
# Loss parameters
# 0 = padding, 1 = "no event"
ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs
# System parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Main Training Script ---
def main():
config = TrainConfig()
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.pt"
# --- 0. Save Configuration ---
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.json"
config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
with open(config_filename, 'w') as f:
json.dump(config_dict, f, indent=4)
print(f"Configuration saved to {config_filename}")
# --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
# Infer vocab_size from the data (max label + 1)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
print(f"Inferred vocabulary size: {vocab_size}")
train_dataset = PatientEventDataset(train_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
train_iter_loader = iter(itertools.cycle(train_loader))
# --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(config.device)
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))
# --- 3. Training Loop ---
# Lists to store losses
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
print("Starting training...")
pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training")
for iter_num in pbar:
# --- Learning Rate Scheduling ---
if iter_num < config.warmup_iter:
lr = config.lr_initial
else:
progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# --- Training Step ---
model.train()
event_seq, time_seq = next(train_iter_loader)
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
# Prepare inputs and targets
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
# Forward pass
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_losses_ce.append(loss_ce.item())
train_losses_surv.append(loss_survival.item())
train_losses_total.append(loss.item())
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
print("\nTraining finished.")
# --- 4. Final Validation ---
print("Running final validation...")
model.eval()
val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
val_steps = 0
with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc="Final Validation")
for event_seq, time_seq in pbar_val:
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_loss_ce_acc += loss_ce.item()
val_loss_surv_acc += loss_survival.item()
val_steps += 1
pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
avg_val_loss_ce = val_loss_ce_acc / val_steps
avg_val_loss_surv = val_loss_surv_acc / val_steps
total_val_loss = avg_val_loss_ce + avg_val_loss_surv
print(f"Final Validation Summary: \n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")
# --- 5. Save Model ---
print(f"Saving final model to {model_filename}")
torch.save(model.state_dict(), model_filename)
# --- 6. Save and Plot Losses ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.txt"
with open(losses_filename, 'w') as f:
f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# Plot and Save Loss Curves
iterations = range(1, len(train_losses_total) + 1)
plt.figure(figsize=(18, 5))
# Plot CE Loss
plt.subplot(1, 3, 1)
plt.plot(iterations, train_losses_ce, label='Train CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Survival Loss
plt.subplot(1, 3, 2)
plt.plot(iterations, train_losses_surv, label='Train Survival')
plt.title('Survival Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Total Loss
plt.subplot(1, 3, 3)
plt.plot(iterations, train_losses_total, label='Train Total')
plt.title('Total Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves_iter.png')
print("\nLoss curves saved to loss_curves_iter.png")
if __name__ == '__main__':
main()

View File

@@ -2,6 +2,9 @@ import torch
import numpy as np import numpy as np
import random import random
from collections import defaultdict from collections import defaultdict
import json
from models import TimeAwareGPT2
class PatientEventDataset(torch.utils.data.Dataset): class PatientEventDataset(torch.utils.data.Dataset):
""" """
@@ -39,17 +42,22 @@ class PatientEventDataset(torch.utils.data.Dataset):
""" """
return len(self.patient_ids) return len(self.patient_ids)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: def __getitem__(self, idx):
""" """
Retrieves, processes, and returns a single patient's event sequence. Retrieves, processes, and returns a single patient's event sequence,
or a list of sequences if a slice is provided.
Args: Args:
idx (int): The index of the patient to retrieve. idx (int or slice): The index or slice of the patient(s) to retrieve.
Returns: Returns:
A tuple of two torch.long tensors: (event_sequence, time_sequence), If idx is an int, a tuple of two torch.long tensors:
both of shape (block_length,). (event_sequence, time_sequence), both of shape (block_length,).
If idx is a slice, a list of such tuples.
""" """
if isinstance(idx, slice):
return [self[i] for i in range(*idx.indices(len(self)))]
# 1. Retrieve and Sort # 1. Retrieve and Sort
patient_id = self.patient_ids[idx] patient_id = self.patient_ids[idx]
records = sorted(self.patient_events[patient_id], key=lambda x: x[0]) records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
@@ -102,3 +110,80 @@ class PatientEventDataset(torch.utils.data.Dataset):
time_tensor = torch.tensor(time_stamps, dtype=torch.long) time_tensor = torch.tensor(time_stamps, dtype=torch.long)
return event_tensor, time_tensor return event_tensor, time_tensor
def load_model(config_path, model_path, vocab_size, device='cpu'):
"""
Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary.
Args:
config_path (str): Path to the JSON configuration file.
model_path (str): Path to the saved model state dictionary (.pt file).
vocab_size (int): The vocabulary size used during training.
device (str): The device to load the model onto ('cpu' or 'cuda').
Returns:
(TimeAwareGPT2): The loaded and initialized model.
"""
with open(config_path, 'r') as f:
config_dict = json.load(f)
print(f"Model config: {config_dict}")
# Create a config object from the dictionary
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
config = AttrDict(config_dict)
# Initialize the model with parameters from the config
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(device)
# Load the saved state dictionary
model.load_state_dict(torch.load(model_path, map_location=device))
# Set the model to evaluation mode
model.eval()
print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
return model
def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Retrieves a batch of data from a PatientEventDataset and prepares it for model training.
Args:
dataset (PatientEventDataset): The dataset to retrieve data from.
batch_slice (slice): The slice defining the batch of patients to retrieve.
ignore_tokens (list, optional): A list of token IDs to be ignored in the target events.
These tokens will be replaced with -100. Defaults to None.
Returns:
A tuple containing four tensors:
- input_events: (batch_size, sequence_length - 1)
- input_tims: (batch_size, sequence_length - 1)
- target_events: (batch_size, sequence_length - 1)
- target_times: (batch_size, sequence_length - 1)
"""
batch_data = dataset[batch_slice]
input_events = [item[0][:-1] for item in batch_data]
input_tims = [item[1][:-1] for item in batch_data]
target_events = [item[0][1:] for item in batch_data]
target_times = [item[1][1:] for item in batch_data]
input_events = torch.stack(input_events)
input_tims = torch.stack(input_tims)
target_events = torch.stack(target_events)
target_times = torch.stack(target_times)
return input_events, input_tims, target_events, target_times