Compare commits

..

47 Commits

Author SHA1 Message Date
02be03f784 chore(train): include model_name in saved config filename (config_{model}_n_embd_...json) 2025-10-22 17:56:23 +08:00
4d1fc63667 fix(conv): ensure tensors are contiguous before Conv1d (post-pad and post-transpose) to avoid cuDNN ptrDesc->finalize 2025-10-22 17:54:04 +08:00
dd58ced9b9 fix(attn): convert boolean attention mask to additive float (-1e9) to avoid cudnn ptrDesc->finalize on some GPUs 2025-10-22 17:51:12 +08:00
3bef72f50b feat(model): add TimeAwareGPT2TemporalConv using TemporalConvEncoder; wire into train.py and utils.load_model; add model_name to configs; translate CN comments and add math import 2025-10-22 17:34:06 +08:00
a81da36657 R: order ICD-10 chapters by chapter number (Roman numeral prefix) in boxplots; handle 'Death' and unknowns at end 2025-10-22 16:08:00 +08:00
b954b4b3e7 R: add orientation arg to boxplot script; toggle coord_flip for rotated (horizontal) vs vertical layouts and adjust axis labels 2025-10-22 15:58:15 +08:00
f8e0104d6b R: add plot_auc_boxplots_by_chapter.R to compare AUC distributions by ICD-10 chapter for 1-year and no-gap datasets; outputs individual and grid plots 2025-10-22 15:52:04 +08:00
262a7db0da R: add plot_model_comparison_no_gap.R mirroring 1-year script; generate per-plot and 2x2 cowplot grid from model_comparison_auc_no_gap.csv 2025-10-22 15:43:16 +08:00
9917b3ab63 R: extend 1-year AUC plot script with additional comparisons (256_l, 120_l vs Delphi; 120 vs 120_l; 256 vs 256_l) and generic xy helper 2025-10-22 15:32:02 +08:00
8316326d7e Add R script to plot 1-year AUC comparisons with ggplot2 (256 vs Delphi and 120 vs Delphi) 2025-10-22 15:25:58 +08:00
6dd5eb95c7 evaluate_models.ipynb: update load_model usage to new signature (infer checkpoint from config) 2025-10-22 13:27:28 +08:00
5b0642eb6e Add train_ddp.py: DistributedDataParallel multi-GPU training with distributed samplers, rank-0 checkpointing, and aggregated metrics 2025-10-22 11:54:48 +08:00
93cf2018d2 evaluate_auc: adapt to new utils.load_model signature (infer checkpoint from config) 2025-10-22 11:40:46 +08:00
6801e5bdbb utils.load_model: take only config; infer checkpoint name from config (with legacy fallback) and vocab from checkpoint 2025-10-22 11:39:10 +08:00
92a5bd4a83 evaluate_auc: use new utils.load_model (infer vocab, model variants) and dynamic device; remove unused imports 2025-10-22 11:35:52 +08:00
dfdf64da9a Rewrite load_model to match train.py: support model variants and infer vocab_size from checkpoint; load state dict robustly 2025-10-22 11:30:59 +08:00
bd88daa8c2 update models and training scripts 2025-10-22 08:36:55 +08:00
e348086e52 update 2025-10-21 10:30:18 +08:00
a8aa5a2bd6 update 2025-10-21 09:20:43 +08:00
ddb7dbfc67 update 2025-10-20 16:22:15 +08:00
88cccdad2e feat: Optimize AUC evaluation with parallel processing 2025-10-20 16:16:50 +08:00
8f44018bae update evaluate 2025-10-20 13:47:50 +08:00
1c9e2a2fb3 feat: print model config and add evaluation notebook 2025-10-20 10:14:50 +08:00
6b782b86e1 feat: Add model checkpoints and configurations 2025-10-20 09:38:24 +08:00
9a9de170d1 delete 2025-10-18 22:35:42 +08:00
7e57e5d3b1 refactor: Update survival loss calculation in CombinedLoss 2025-10-18 15:21:10 +08:00
14865ac5b6 Refactor: Remove Jupyter Notebook cell markers 2025-10-18 13:32:26 +08:00
dbc3000192 add evaluation scripts. 2025-10-18 13:26:56 +08:00
082c719975 feat(models): Refactor generate function in TimeAwareGPT2 with competing risks sampling 2025-10-18 12:42:14 +08:00
a631ac6d59 feat: Add load_model function and update training script
Added a `load_model` function to `utils.py` to allow loading of trained models from configuration and state dictionary files.

The `train_iter.py` script was also modified, likely to incorporate or test this new functionality.
2025-10-18 11:07:59 +08:00
f7356b183c feat: Add command-line arguments to train.py 2025-10-18 10:23:12 +08:00
3390bc025e feat: Add iteration-based training scripts (single and multi-GPU) 2025-10-18 10:05:37 +08:00
a832a45c62 config: Tune hyperparameters for multi-GPU training
Increase model size (n_embd, n_layer, n_head) for the multi-GPU configuration.

Explicitly set AdamW betas to (0.9, 0.99).
2025-10-17 15:37:42 +08:00
d760c45baf feat: Add multi-GPU training and improve config/ignore
Add train_multigpu.py for distributed data parallel training.

Update train.py to save the training configuration to a JSON file.

Generalize .gitignore to exclude all *.pt checkpoint files.

Delete obsolete train_dpp.py file.
2025-10-17 14:09:34 +08:00
053f86f4da config: Add weight decay to training configuration
Adds a weight_decay parameter to the TrainConfig and applies it to the AdamW optimizer.
2025-10-17 13:47:37 +08:00
d4d25ac9c7 feat: Add covariate-aware model and piecewise encoder
Introduce PiecewiseLinearEncoder for continuous variable encoding.

Add CovariateAwareGPT2 to extend TimeAwareGPT2 with static and time-varying covariate processing.

The model combines piecewise linear and sinusoidal encodings for covariates and integrates them via concatenation before a final MLP head.

Reorganize models.py for better logical structure.
2025-10-17 12:04:50 +08:00
fe0304a96a feat: Save model with params in name and log losses 2025-10-17 10:44:17 +08:00
7e8d8d307b chore: Ignore small data files 2025-10-17 10:34:24 +08:00
fc0aef4e71 chore: Add .gitignore 2025-10-17 10:32:42 +08:00
02d84a7eca refactor: Use AdamW optimizer and increase early stopping patience 2025-10-17 10:31:12 +08:00
cb7575a229 feat: Update model and training parameters
In `models.py`:
- Change temporal attention mask to be strictly causal (`<` instead of `<=`).
- Add self-attention for the first token in a sequence to prevent NaNs.

In `train.py`:
- Update hyperparameters:
  - `block_length`: 24 -> 48
  - `n_embd`: 256 -> 120
  - `n_layer`: 8 -> 12
  - `n_head`: 8 -> 12
2025-10-16 18:50:15 +08:00
e2495f43b0 revert 6e0713048a
revert update attn mask
2025-10-16 18:37:55 +08:00
6e0713048a update attn mask 2025-10-16 18:29:48 +08:00
eec406d79f update ignored events. 2025-10-16 17:10:01 +08:00
e3e533c9ec update 2025-10-16 16:58:30 +08:00
b5172392cb update dpp 2025-10-16 16:46:33 +08:00
6b0b86d9d0 add Multi_GPU support. 2025-10-16 16:28:52 +08:00
17 changed files with 4270 additions and 99 deletions

17
.gitignore vendored Normal file
View File

@@ -0,0 +1,17 @@
# IDE settings
.idea/
# Python cache
__pycache__/
# Model checkpoints
*.pt
# Large data files
ukb_delphi.txt
ukb_real.bin
# Small data files
fields.txt
icd10_codes_mod.tsv
labels.csv

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,19 @@
{
"model_name": "TimeAwareGPT2",
"n_layer": 12,
"n_embd": 120,
"n_head": 12,
"max_epoch": 200,
"batch_size": 128,
"lr_initial": 0.0006,
"lr_final": 6e-05,
"weight_decay": 0.2,
"warmup_epochs": 10,
"early_stopping_patience": 10,
"pdrop": 0.0,
"token_pdrop": 0.0,
"betas": [
0.9,
0.99
]
}

View File

@@ -0,0 +1,19 @@
{
"model_name": "TimeAwareGPT2",
"n_layer": 16,
"n_embd": 256,
"n_head": 16,
"max_epoch": 200,
"batch_size": 128,
"lr_initial": 0.0006,
"lr_final": 6e-05,
"weight_decay": 0.2,
"warmup_epochs": 10,
"early_stopping_patience": 10,
"pdrop": 0.0,
"token_pdrop": 0.0,
"betas": [
0.9,
0.99
]
}

File diff suppressed because it is too large Load Diff

496
evaluate_auc.py Normal file
View File

@@ -0,0 +1,496 @@
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import argparse
from utils import load_model, get_batch, PatientEventDataset
from pathlib import Path
from joblib import Parallel, delayed
def auc(x1, x2):
n1 = len(x1)
n2 = len(x2)
R1 = np.concatenate([x1, x2]).argsort().argsort()[:n1].sum() + n1
U1 = n1 * n2 + 0.5 * n1 * (n1 + 1) - R1
if n1 == 0 or n2 == 0:
return np.nan
return U1 / n1 / n2
def get_common_diseases(delphi_labels, filter_min_total=100):
chapters_of_interest = [
"I. Infectious Diseases",
"II. Neoplasms",
"III. Blood & Immune Disorders",
"IV. Metabolic Diseases",
"V. Mental Disorders",
"VI. Nervous System Diseases",
"VII. Eye Diseases",
"VIII. Ear Diseases",
"IX. Circulatory Diseases",
"X. Respiratory Diseases",
"XI. Digestive Diseases",
"XII. Skin Diseases",
"XIII. Musculoskeletal Diseases",
"XIV. Genitourinary Diseases",
"XV. Pregnancy & Childbirth",
"XVI. Perinatal Conditions",
"XVII. Congenital Abnormalities",
"Death",
]
labels_df = delphi_labels[
delphi_labels["ICD-10 Chapter (short)"].isin(chapters_of_interest) * (delphi_labels["count"] > filter_min_total)
]
return labels_df["index"].tolist()
def optimized_bootstrapped_auc_gpu(case, control, n_bootstrap=1000):
"""
Computes bootstrapped AUC estimates using PyTorch on CUDA.
Parameters:
case: 1D tensor of scores for positive cases
control: 1D tensor of scores for controls
n_bootstrap: Number of bootstrap replicates
Returns:
Tensor of shape (n_bootstrap,) containing AUC for each bootstrap replicate
"""
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available. This function requires a GPU.")
# Convert inputs to CUDA tensors
if not torch.is_tensor(case):
case = torch.tensor(case, device="cuda", dtype=torch.float32)
else:
case = case.to("cuda", dtype=torch.float32)
if not torch.is_tensor(control):
control = torch.tensor(control, device="cuda", dtype=torch.float32)
else:
control = control.to("cuda", dtype=torch.float32)
n_case = case.size(0)
n_control = control.size(0)
total = n_case + n_control
# Generate bootstrap samples
boot_idx_case = torch.randint(0, n_case, (n_bootstrap, n_case), device="cuda")
boot_idx_control = torch.randint(0, n_control, (n_bootstrap, n_control), device="cuda")
boot_case = case[boot_idx_case]
boot_control = control[boot_idx_control]
combined = torch.cat([boot_case, boot_control], dim=1)
# Mask to identify case entries
mask = torch.zeros((n_bootstrap, total), dtype=torch.bool, device="cuda")
mask[:, :n_case] = True
# Compute ranks and AUC
ranks = combined.argsort(dim=1).argsort(dim=1)
case_ranks_sum = torch.sum(ranks.float() * mask.float(), dim=1)
min_case_rank_sum = n_case * (n_case - 1) / 2.0
U = case_ranks_sum - min_case_rank_sum
aucs = U / (n_case * n_control)
return aucs.cpu().tolist()
# AUC comparison adapted from
# https://github.com/Netflix/vmaf/
def compute_midrank(x):
"""Computes midranks.
Args:
x - a 1D numpy array
Returns:
array of midranks
"""
J = np.argsort(x)
Z = x[J]
N = len(x)
T = np.zeros(N, dtype=np.float32)
i = 0
while i < N:
j = i
while j < N and Z[j] == Z[i]:
j += 1
T[i:j] = 0.5 * (i + j - 1)
i = j
T2 = np.empty(N, dtype=np.float32)
# Note(kazeevn) +1 is due to Python using 0-based indexing
# instead of 1-based in the AUC formula in the paper
T2[J] = T + 1
return T2
def fastDeLong(predictions_sorted_transposed, label_1_count):
"""
The fast version of DeLong's method for computing the covariance of
unadjusted AUC.
Args:
predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
sorted such as the examples with label "1" are first
Returns:
(AUC value, DeLong covariance)
Reference:
@article{sun2014fast,
title={Fast Implementation of DeLong's Algorithm for
Comparing the Areas Under Correlated Receiver Operating Characteristic Curves},
author={Xu Sun and Weichao Xu},
journal={IEEE Signal Processing Letters},
volume={21},
number={11},
pages={1389--1393},
year={2014},
publisher={IEEE}
}
"""
# Short variables are named as they are in the paper
m = label_1_count
n = predictions_sorted_transposed.shape[1] - m
positive_examples = predictions_sorted_transposed[:, :m]
negative_examples = predictions_sorted_transposed[:, m:]
k = predictions_sorted_transposed.shape[0]
tx = np.empty([k, m], dtype=np.float32)
ty = np.empty([k, n], dtype=np.float32)
tz = np.empty([k, m + n], dtype=np.float32)
for r in range(k):
tx[r, :] = compute_midrank(positive_examples[r, :])
ty[r, :] = compute_midrank(negative_examples[r, :])
tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :])
aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n
v01 = (tz[:, :m] - tx[:, :]) / n
v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m
sx = np.cov(v01)
sy = np.cov(v10)
delongcov = sx / m + sy / n
return aucs, delongcov
def compute_ground_truth_statistics(ground_truth):
assert np.array_equal(np.unique(ground_truth), [0, 1])
order = (-ground_truth).argsort()
label_1_count = int(ground_truth.sum())
return order, label_1_count
def get_auc_delong_var(healthy_scores, diseased_scores):
"""
Computes ROC AUC value and variance using DeLong's method
Args:
healthy_scores: Values for class 0 (healthy/controls)
diseased_scores: Values for class 1 (diseased/cases)
Returns:
AUC value and variance
"""
# Create ground truth labels (1 for diseased, 0 for healthy)
ground_truth = np.array([1] * len(diseased_scores) + [0] * len(healthy_scores))
predictions = np.concatenate([diseased_scores, healthy_scores])
# Compute statistics needed for DeLong method
order, label_1_count = compute_ground_truth_statistics(ground_truth)
predictions_sorted_transposed = predictions[np.newaxis, order]
# Calculate AUC and covariance
aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count)
assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers"
return aucs[0], delongcov
def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5), precomputed_idx=None, n_bootstrap=1, use_delong=False):
age_step = age_groups[1] - age_groups[0]
# Indexes of cases with disease k
wk = np.where(d[2] == k)
if len(wk[0]) < 2:
return None
# For controls, we need to exclude cases with disease k
wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None])
wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]])) # All cases and controls
# We need to take into account the offset t and use the tokens for prediction that are at least t before the event
if precomputed_idx is None:
pred_idx = (d[1][wall[0]] <= d[3][wall].reshape(-1, 1) - offset).sum(1) - 1
else:
pred_idx = precomputed_idx[wall] # It's actually much faster to precompute this
valid_indices = pred_idx != -1
pred_idx = pred_idx[valid_indices]
wall = (wall[0][valid_indices], wall[1][valid_indices])
z = d[1][(wall[0], pred_idx)] # Times of the tokens for prediction
zk = d[3][wall] # Target times
x = p[..., j][(wall[0], pred_idx)]
p_idx = wall[0]
out = []
for i, aa in enumerate(age_groups):
a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step)
if not np.any(a):
continue
selected_groups = p_idx[a]
_, unique_indices = np.unique(selected_groups, return_index=True)
a_filtered = a[a]
a_filtered[:] = False
a_filtered[unique_indices] = True
a[a] = a_filtered
is_case = np.zeros_like(x, dtype=bool)
is_case[:len(wk[0])] = True
control = x[~is_case & a]
case = x[is_case & a]
if len(control) == 0 or len(case) == 0:
continue
if use_delong:
auc_value_delong, auc_variance_delong = get_auc_delong_var(control, case)
auc_delong_dict = {"auc_delong": auc_value_delong, "auc_variance_delong": auc_variance_delong}
else:
auc_delong_dict = {}
if n_bootstrap > 1:
aucs_bootstrapped = optimized_bootstrapped_auc_gpu(case, control, n_bootstrap)
for bootstrap_idx in range(n_bootstrap):
y = auc_value_delong if n_bootstrap == 1 else aucs_bootstrapped[bootstrap_idx]
out_item = {
"token": k,
"auc": y,
"age": aa,
"n_healthy": len(control),
"n_diseased": len(case),
}
out.append(out_item | auc_delong_dict)
if n_bootstrap > 1:
out_item["bootstrap_idx"] = bootstrap_idx
return out
def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap):
all_aucs = []
for sex, sex_idx in [("female", 2), ("male", 3)]:
sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
p_sex = p100k[sex_mask]
d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k]
precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
for j, k in tqdm(
list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
):
out = get_calibration_auc(
j,
k,
d100k_sex,
p_sex,
age_groups=age_groups,
offset=offset,
precomputed_idx=precomputed_idx_subset,
n_bootstrap=n_bootstrap,
use_delong=True,
)
if out is None:
continue
for out_item in out:
out_item["sex"] = sex
all_aucs.append(out_item)
return all_aucs
# New internal function that performs the AUC evaluation pipeline.
def evaluate_auc_pipeline(
model,
d100k,
output_path,
delphi_labels,
diseases_of_interest=None,
filter_min_total=100,
disease_chunk_size=200,
age_groups=np.arange(40, 80, 5),
offset=0.1,
batch_size=256,
device="cpu",
seed=1337,
n_bootstrap=1,
meta_info={},
n_jobs=-1,
):
"""
Runs the AUC evaluation pipeline.
Args:
model (torch.nn.Module): The loaded model set to eval().
d100k (tuple): Data batch from get_batch.
delphi_labels (pd.DataFrame): DataFrame with label info (token names, etc. "delphi_labels_chapters_colours_icd.csv").
output_path (str | None): Directory where CSV files will be written. If None, files will not be saved.
diseases_of_interest (np.ndarray or list, optional): If provided, these disease indices are used.
filter_min_total (int): Minimum total token count to include a token.
disease_chunk_size (int): Maximum chunk size for processing diseases.
age_groups (np.ndarray): Age groups to use in calibration.
offset (float): Offset used in get_calibration_auc.
batch_size (int): Batch size for model forwarding.
device (str): Device identifier.
seed (int): Random seed for reproducibility.
n_bootstrap (int): Number of bootstrap samples. (1 for no bootstrap)
n_jobs (int): Number of parallel jobs to run.
Returns:
tuple: (df_auc_unpooled, df_auc, df_both) DataFrames.
"""
assert n_bootstrap > 0, "n_bootstrap must be greater than 0"
# Set random seeds
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Get common diseases
if diseases_of_interest is None:
diseases_of_interest = get_common_diseases(delphi_labels, filter_min_total)
# Split diseases into chunks for processing
num_chunks = (len(diseases_of_interest) + disease_chunk_size - 1) // disease_chunk_size
diseases_chunks = np.array_split(diseases_of_interest, num_chunks)
# Precompute prediction indices for calibration
pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1
p100k = []
model.to(device)
with torch.no_grad():
for dd in tqdm(
zip(*[torch.split(x, batch_size) for x in d100k]),
desc=f"Model inference",
total=d100k[0].shape[0] // batch_size + 1,
):
dd = [x.to(device) for x in dd]
outputs = model(dd[0], dd[1]).cpu().detach().numpy()
p100k.append(outputs.astype("float16"))
p100k = np.vstack(p100k)
results = Parallel(n_jobs=n_jobs)(
delayed(process_chunk)(
disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk], pred_idx_precompute, age_groups, offset, n_bootstrap
)
for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks)
)
all_aucs = [item for sublist in results for item in sublist]
df_auc_unpooled = pd.DataFrame(all_aucs)
for key, value in meta_info.items():
df_auc_unpooled[key] = value
delphi_labels_subset = delphi_labels[['index', 'ICD-10 Chapter (short)', 'name', 'color', 'count']]
df_auc_unpooled_merged = df_auc_unpooled.merge(delphi_labels_subset, left_on="token", right_on="index", how="inner")
def aggregate_age_brackets_delong(group):
# For normal distributions, when averaging n of them:
# The variance of the sum is the sum of variances
# The variance of the average is the sum of variances divided by n^2
n = len(group)
mean = group['auc_delong'].mean()
# Since we're taking the average, divide combined variance by n^2
var = group['auc_variance_delong'].sum() / (n**2)
return pd.Series({
'auc': mean,
'auc_variance_delong': var,
'n_samples': n,
'n_diseased': group['n_diseased'].sum(),
'n_healthy': group['n_healthy'].sum(),
})
print('Using DeLong method to calculate AUC confidence intervals..')
df_auc = df_auc_unpooled.groupby(["token"]).apply(aggregate_age_brackets_delong).reset_index()
df_auc_merged = df_auc.merge(delphi_labels, left_on="token", right_on="index", how="inner")
if output_path is not None:
Path(output_path).mkdir(exist_ok=True, parents=True)
df_auc_merged.to_csv(f"{output_path}/df_both.csv", index=False)
df_auc_unpooled_merged.to_csv(f"{output_path}/df_auc_unpooled.csv", index=False)
return df_auc_unpooled_merged, df_auc_merged
def main():
parser = argparse.ArgumentParser(description="Evaluate AUC")
parser.add_argument("--model_name", type=str, default="n_embd_256_n_layer_16_n_head_16", help="Model checkpoint name")
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
parser.add_argument("--offset", type=float, default=365.25, help="Offset in days for prediction")
# Optional filtering/chunking parameters:
parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens")
parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
args = parser.parse_args()
model_name = args.model_name
output_path = f'auc_evaluation_{model_name}'
dataset_subset_size = args.dataset_subset_size
offset = args.offset
# Create output folder if it doesn't exist.
Path(output_path).mkdir(exist_ok=True, parents=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 1337
# Load model checkpoint and initialize model.
model = load_model(
config_path=f'config_{model_name}.json',
device=device,
)
model.eval()
model = model.to(device)
# Load training and validation data.
val_data_path = 'ukb_real_val.bin'
val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
block_length = 128
val_dataset = PatientEventDataset(val_data_arr, block_length)
if dataset_subset_size == -1:
dataset_subset_size = len(val_dataset)
# Get a subset batch for evaluation.
d100k = get_batch(val_dataset, slice(dataset_subset_size))
# Load labels (external) to be passed in.
delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")
# Call the internal evaluation function.
df_auc_unpooled, df_auc_merged = evaluate_auc_pipeline(
model,
d100k,
output_path,
delphi_labels,
diseases_of_interest=None,
filter_min_total=args.filter_min_total,
disease_chunk_size=args.disease_chunk_size,
device=device,
seed=seed,
offset=offset,
n_bootstrap=args.n_bootstrap,
n_jobs=args.n_jobs,
)
if __name__ == "__main__":
main()

718
evaluate_models.ipynb Normal file

File diff suppressed because one or more lines are too long

565
models.py
View File

@@ -1,7 +1,83 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from typing import Tuple from typing import Tuple, Optional
import math
# =============================================================================
# 1. Component Modules (Building Blocks)
# =============================================================================
class CausalConv1d(nn.Module):
def __init__(self, channels, kernel_size, groups=1):
super().__init__()
self.pad = kernel_size - 1
self.conv = nn.Conv1d(
channels, channels, kernel_size,
padding=0, groups=groups
)
def forward(self, x): # x: (B, C, L)
x = F.pad(x, (self.pad, 0)) # pad only on the left to ensure causality
x = x.contiguous()
return self.conv(x)
class DepthwiseSeparableCausalConvBlock(nn.Module):
def __init__(self, d_model, kernel_size=5, dropout=0.1):
super().__init__()
self.dw = CausalConv1d(d_model, kernel_size, groups=d_model) # depthwise
self.pw = nn.Conv1d(d_model, d_model, 1) # pointwise
self.act = nn.GELU()
self.ln = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x): # x: (B, L, D)
y = x.transpose(1, 2).contiguous() # (B, D, L)
y = self.dw(y) # (B, D, L)
y = self.pw(y.contiguous()) # (B, D, L)
y = y.transpose(1, 2).contiguous() # (B, L, D)
y = self.act(y)
y = self.dropout(y)
return self.ln(x + y) # residual connection + layer norm (LN)
class TimeFeatureProjector(nn.Module):
"""
Projects scalar time t and its increment Δt into d_model dimensions.
Combines: linear-scale features + fixed-frequency sin/cos (Fourier time features).
"""
def __init__(self, d_model, fourier_dim=32, dt_clip=1e6):
super().__init__()
self.dt_clip = dt_clip
self.scalar_proj = nn.Linear(2, d_model, bias=False) # [t_scaled, dt_scaled] -> D
# Predefine a set of logarithmically spaced frequencies (tune for your time units if needed)
k = fourier_dim // 2
freqs = torch.logspace(-4, 2, steps=k) * 2 * math.pi # frequency coverage ~1e-4 to 1e2
self.register_buffer("freqs", freqs, persistent=False)
self.fourier_proj = nn.Linear(2*k, d_model, bias=False) # [sin, cos] -> D
self.gate = nn.Parameter(torch.zeros(1)) # learnable gate to smoothly introduce Fourier features
self.ln = nn.LayerNorm(d_model)
def forward(self, t): # t: (B, L) continuous timestamps/steps
# compute increments Δt and stabilize
dt = t - F.pad(t, (1, 0), value=0.)[:, :-1]
dt = torch.clamp(dt, min=0.) # ensure non-negative
# normalize/stabilize with log compression
t_scaled = torch.log1p(torch.clamp(torch.abs(t), max=self.dt_clip))
dt_scaled = torch.log1p(torch.clamp(dt, max=self.dt_clip))
scal = torch.stack([t_scaled, dt_scaled], dim=-1) # (B, L, 2)
scal_feat = self.scalar_proj(scal) # (B, L, D)
# Fixed-frequency sin/cos to capture absolute/relative periodicity
# If t is in steps, use directly; if in seconds, ensure units are consistent (e.g., divide by a time constant)
# (B, L, K)
wt = t[..., None] * self.freqs
sincos = torch.cat([torch.sin(wt), torch.cos(wt)], dim=-1) # (B, L, 2K)
fourier_feat = self.fourier_proj(sincos) # (B, L, D)
# gated fusion + layer norm
h = scal_feat + torch.tanh(self.gate) * fourier_feat
return self.ln(h) # (B, L, D)
class Block(nn.Module): class Block(nn.Module):
""" an unassuming Transformer block """ """ an unassuming Transformer block """
@@ -25,8 +101,10 @@ class Block(nn.Module):
def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
normed_x = self.ln_1(x) normed_x = self.ln_1(x)
attn_mask = ~custom_mask # Build an additive attention mask to avoid backend issues with boolean masks on some GPUs
attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0) # custom_mask: True means allowed, False means masked. We convert to 0 for allowed and -large for masked.
mask_bool = (~custom_mask).repeat_interleave(self.n_head, dim=0) # True where we want to mask
attn_mask = mask_bool.to(dtype=normed_x.dtype) * (-1e9)
attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False) attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
x = x + self.resid_dropout(attn_output) x = x + self.resid_dropout(attn_output)
@@ -58,14 +136,8 @@ class AgeSinusoidalEncoding(nn.Module):
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
# Pre-calculate the divisor term for the sinusoidal formula. # Pre-calculate the divisor term for the sinusoidal formula.
# The formula for the divisor is 10000^(2i/D), where D is the
# embedding_dim and i is the index for each pair of dimensions.
# i ranges from 0 to D/2 - 1.
i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
divisor = torch.pow(10000, i / self.embedding_dim) divisor = torch.pow(10000, i / self.embedding_dim)
# Register the divisor as a non-trainable buffer. This ensures it is
# moved to the correct device (e.g., GPU) along with the model.
self.register_buffer('divisor', divisor) self.register_buffer('divisor', divisor)
def forward(self, t: torch.Tensor) -> torch.Tensor: def forward(self, t: torch.Tensor) -> torch.Tensor:
@@ -80,49 +152,204 @@ class AgeSinusoidalEncoding(nn.Module):
torch.Tensor: The encoded age tensor of shape torch.Tensor: The encoded age tensor of shape
(batch_size, sequence_length, embedding_dim). (batch_size, sequence_length, embedding_dim).
""" """
# 1. Unit Conversion: Convert age from days to years.
# We use 365.25 to account for leap years.
t_years = t / 365.25 t_years = t / 365.25
# 2. Argument Calculation: Calculate the arguments for the sin/cos functions.
# The shapes are broadcast to (B, L, D/2).
# Input t_years: (B, L) -> unsqueezed to (B, L, 1)
# Divisor: (D/2) -> viewed as (1, 1, D/2)
args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1) args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1)
# 3. Sinusoidal Application: Create the final output tensor.
# Initialize an empty tensor to store the embeddings.
output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device) output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)
# Assign cosine of the arguments to the even indices.
output[:, :, 0::2] = torch.cos(args) output[:, :, 0::2] = torch.cos(args)
# Assign sine of the arguments to the odd indices.
output[:, :, 1::2] = torch.sin(args) output[:, :, 1::2] = torch.sin(args)
return output return output
class LearnableAgeEncoding(nn.Module):
"""Combines fixed sinusoidal age encodings with a learnable MLP projection."""
def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0):
super().__init__()
self.base_dim = base_dim
self.final_dim = final_dim or base_dim
hidden_dim = hidden_dim or base_dim
if hidden_dim <= 0:
raise ValueError("hidden_dim must be a positive integer.")
if self.final_dim <= 0:
raise ValueError("final_dim must be a positive integer.")
self.sinusoidal = AgeSinusoidalEncoding(base_dim)
mlp_layers = [
nn.Linear(base_dim, hidden_dim),
nn.GELU(),
]
if dropout > 0.0:
mlp_layers.append(nn.Dropout(dropout))
mlp_layers.append(nn.Linear(hidden_dim, self.final_dim))
self.mlp = nn.Sequential(*mlp_layers)
def forward(self, t: torch.Tensor) -> torch.Tensor:
sin_embed = self.sinusoidal(t)
flat_embed = sin_embed.reshape(-1, self.base_dim)
projected = self.mlp(flat_embed)
return projected.reshape(*sin_embed.shape[:-1], self.final_dim)
class PiecewiseLinearEncoder(nn.Module):
"""
Encodes continuous variables using piecewise linear encoding.
This module defines bins based on standard normal distribution quantiles,
encodes an input by finding its bin, and calculates its position as a
linear interpolation between boundaries. The result is projected to the
final embedding dimension by a shared linear layer.
"""
def __init__(self, num_bins: int, embedding_dim: int):
"""
Initializes the PiecewiseLinearEncoder module.
Args:
num_bins (int): The number of bins for the encoding.
embedding_dim (int): The dimensionality of the output embedding (D).
"""
super().__init__()
if num_bins <= 0:
raise ValueError("num_bins must be a positive integer.")
self.num_bins = num_bins
self.embedding_dim = embedding_dim
if num_bins > 1:
quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1)
normal_dist = torch.distributions.normal.Normal(0, 1)
boundaries = normal_dist.icdf(quantiles)
else:
boundaries = torch.tensor([])
boundaries = torch.cat([
torch.tensor([float('-inf')]),
boundaries,
torch.tensor([float('inf')])
])
self.register_buffer('boundaries', boundaries)
self.linear = nn.Linear(num_bins, embedding_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the piecewise linear encoding.
Args:
x (torch.Tensor): Input tensor of shape (*, N), where * is any
number of batch dimensions and N is the number of continuous
features. Assumed to be pre-scaled.
Returns:
torch.Tensor: Encoded tensor of shape (*, N, D).
"""
original_shape = x.shape
x = x.reshape(-1, original_shape[-1])
bin_indices = torch.searchsorted(self.boundaries, x, right=True) - 1
bin_indices = bin_indices.clamp(0, self.num_bins - 1)
lower_bounds = self.boundaries[bin_indices]
upper_bounds = self.boundaries[bin_indices + 1]
delta = upper_bounds - lower_bounds + 1e-8
weight_upper = (x - lower_bounds) / delta
weight_lower = 1.0 - weight_upper
is_first_bin = (bin_indices == 0)
is_last_bin = (bin_indices == self.num_bins - 1)
weight_lower[is_first_bin] = 1.0
weight_upper[is_first_bin] = 0.0
weight_lower[is_last_bin] = 0.0
weight_upper[is_last_bin] = 1.0
encoded = torch.zeros(*x.shape, self.num_bins, device=x.device, dtype=x.dtype)
encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1))
upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1)
encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1))
encoded = encoded.view(*original_shape, self.num_bins)
output = self.linear(encoded)
return output
class TemporalConvEncoder(nn.Module):
"""
Inputs:
x: (B, L) - event/token ids
t: (B, L) - timestamps (real-valued) or step indices
Output:
h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds
"""
def __init__(
self,
vocab_size: int,
d_model: int = 768,
n_layers: int = 2,
kernel_size: int = 5,
dropout: float = 0.1,
fourier_dim: int = 32,
pad_id: int = 0
):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim)
self.fuse = nn.Linear(2*d_model, d_model, bias=False) # fuse token and time features
self.ln_in = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
blocks = []
for _ in range(n_layers):
blocks.append(DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout))
self.blocks = nn.ModuleList(blocks)
def forward(self, x, t, attention_mask=None):
"""
attention_mask: (B, L) 1=keep, 0=padding
"""
tok = self.token_emb(x) # (B, L, D)
tim = self.time_proj(t) # (B, L, D)
h = torch.cat([tok, tim], dim=-1) # (B, L, 2D)
h = self.fuse(h) # (B, L, D)
h = self.ln_in(h)
h = self.dropout(h)
# Optional: zero-out padding positions before convolutions to avoid leakage
if attention_mask is not None:
h = h * attention_mask.unsqueeze(-1).type_as(h)
# Multi-layer causal temporal convolutions (no look-ahead) to form relative position-aware context
for blk in self.blocks:
h = blk(h) # (B, L, D)
if attention_mask is not None:
h = h * attention_mask.unsqueeze(-1).type_as(h)
return h # (B, L, D), directly usable as attention layer input
# =============================================================================
# 2. Main Model Architectures
# =============================================================================
class TimeAwareGPT2(nn.Module): class TimeAwareGPT2(nn.Module):
""" """
A time-aware GPT-2 model with custom temporal features. A time-aware GPT-2 model with custom temporal features.
""" """
def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float): def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None):
super().__init__() super().__init__()
self.token_pdrop = token_pdrop self.token_pdrop = token_pdrop
self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
# Token and positional embeddings
self.wte = nn.Embedding(vocab_size, n_embd) self.wte = nn.Embedding(vocab_size, n_embd)
self.age_encoder = AgeSinusoidalEncoding(n_embd) self.age_encoder = AgeSinusoidalEncoding(n_embd)
self.drop = nn.Dropout(pdrop) self.drop = nn.Dropout(pdrop)
# Transformer blocks
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)]) self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
# Final layer norm and linear head
self.ln_f = nn.LayerNorm(n_embd) self.ln_f = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False) self.head = nn.Linear(n_embd, vocab_size, bias=False)
self.n_embd = n_embd self.n_embd = n_embd
def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor: def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
@@ -138,46 +365,30 @@ class TimeAwareGPT2(nn.Module):
""" """
B, L = event_seq.size() B, L = event_seq.size()
# 1. Get token embeddings
token_embeddings = self.wte(event_seq) token_embeddings = self.wte(event_seq)
# 2. Apply token dropout (only during training)
if self.training and self.token_pdrop > 0: if self.training and self.token_pdrop > 0:
# Create a mask to randomly zero out entire token embedding vectors
drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
token_embeddings[drop_mask] = 0.0 token_embeddings[drop_mask] = 0.0
# 3. Get positional embeddings from time sequence
pos_embeddings = self.age_encoder(time_seq.float()) pos_embeddings = self.age_encoder(time_seq.float())
# 4. Combine embeddings and apply dropout
x = self.drop(token_embeddings + pos_embeddings) x = self.drop(token_embeddings + pos_embeddings)
# 5. Generate attention mask t_i = time_seq.unsqueeze(-1)
# The attention mask combines two conditions: t_j = time_seq.unsqueeze(1)
# a) Time-based causality: A token i can attend to a token j only if time_seq[j] <= time_seq[i]. time_mask = (t_j < t_i)
# b) Padding mask: Do not attend to positions where the event token is 0. padding_mask = (event_seq != 0).unsqueeze(1)
# a) Time-based causal mask
t_i = time_seq.unsqueeze(-1) # (B, L, 1)
t_j = time_seq.unsqueeze(1) # (B, 1, L)
time_mask = (t_j <= t_i)
# b) Padding mask (prevents attending to key positions that are padding)
padding_mask = (event_seq != 0).unsqueeze(1) # Shape: (B, 1, L)
# Combine the masks. A position (j) can be attended to by a query (i) only if
# it's in the past (time_mask) AND it's not a padding token (padding_mask).
combined_mask = time_mask & padding_mask combined_mask = time_mask & padding_mask
# 6. Pass through transformer blocks is_row_all_zero = ~combined_mask.any(dim=-1)
is_not_padding = (event_seq != 0)
force_self_attention = is_row_all_zero & is_not_padding
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
for block in self.blocks: for block in self.blocks:
x = block(x, custom_mask=combined_mask) x = block(x, custom_mask=combined_mask)
# 7. Final layer norm and projection to vocab size
x = self.ln_f(x) x = self.ln_f(x)
logits = self.head(x) logits = self.head(x)
return logits return logits
def get_num_params(self) -> float: def get_num_params(self) -> float:
@@ -186,6 +397,222 @@ class TimeAwareGPT2(nn.Module):
""" """
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
@torch.no_grad()
def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
"""
Take a conditioning sequence of indices x (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
self.eval()
if termination_tokens is None:
termination_tokens = [1269]
termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
mask_time = -10000
for _ in range(max_new_tokens):
logits = self(x, t)
logits = logits[:, -1, :]
if self.ignore_tokens:
logits[:, self.ignore_tokens] = -torch.inf
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
logits = logits.scatter(1, fill, -torch.inf)
t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
t_next_val, idx_next = t_next_dist.min(1)
idx_next = idx_next.unsqueeze(1)
age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
x = torch.cat((x, idx_next), dim=1)
t = torch.cat((t, age_next), dim=1)
if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
break
pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
final_logits = self(x, t)
x[pad] = 0
t[pad] = mask_time
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
final_logits = torch.stack([final_logits[:,j].scatter(1, fill[:,:j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0,1)
return x, t, final_logits
class TimeAwareGPT2Learnable(TimeAwareGPT2):
"""Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.age_encoder = LearnableAgeEncoding(
base_dim=self.n_embd,
hidden_dim=2 * self.n_embd,
final_dim=self.n_embd,
)
# =============================================================================
# 3. Loss Function
# =============================================================================
class TimeAwareGPT2TemporalConv(nn.Module):
"""
A TimeAware GPT-2 variant that uses TemporalConvEncoder to encode
event and time sequences before Transformer attention blocks.
Inputs:
- event_seq: (B, L) token ids (0 treated as padding)
- time_seq: (B, L) timestamps or step indices (float)
Output:
- logits: (B, L, vocab_size)
"""
def __init__(
self,
vocab_size: int,
n_embd: int,
n_layer: int,
n_head: int,
pdrop: float,
token_pdrop: float,
ignore_tokens: Optional[list[int]] = None,
*,
conv_layers: int = 2,
kernel_size: int = 5,
conv_dropout: float = 0.1,
fourier_dim: int = 32,
pad_id: int = 0,
):
super().__init__()
self.token_pdrop = token_pdrop
self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
self.n_embd = n_embd
# Temporal convolutional encoder to build inputs_embeds
self.temporal_encoder = TemporalConvEncoder(
vocab_size=vocab_size,
d_model=n_embd,
n_layers=conv_layers,
kernel_size=kernel_size,
dropout=conv_dropout,
fourier_dim=fourier_dim,
pad_id=pad_id,
)
# Transformer stack on top of temporal features
self.drop = nn.Dropout(pdrop)
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
B, L = event_seq.size()
# Encoder features as inputs_embeds
attention_mask = (event_seq != 0)
x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask)
x = self.drop(x)
# Time-aware causal mask as before
t_i = time_seq.unsqueeze(-1)
t_j = time_seq.unsqueeze(1)
time_mask = (t_j < t_i)
padding_mask = (event_seq != 0).unsqueeze(1)
combined_mask = time_mask & padding_mask
# Ensure at least self-attention on non-padding rows
is_row_all_zero = ~combined_mask.any(dim=-1)
is_not_padding = (event_seq != 0)
force_self_attention = is_row_all_zero & is_not_padding
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
for block in self.blocks:
x = block(x, custom_mask=combined_mask)
x = self.ln_f(x)
logits = self.head(x)
return logits
def get_num_params(self) -> float:
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
@torch.no_grad()
def generate(
self,
x: torch.Tensor,
t: torch.Tensor,
max_new_tokens: int = 100,
max_age: float = 85 * 365.25,
no_repeat: bool = True,
termination_tokens: Optional[list[int]] = None,
top_k: Optional[int] = None,
):
"""Greedy-like generation with optional no-repeat and termination tokens."""
self.eval()
if termination_tokens is None:
termination_tokens = [1269]
termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
mask_time = -10000
for _ in range(max_new_tokens):
logits = self(x, t)
logits = logits[:, -1, :]
if self.ignore_tokens:
logits[:, self.ignore_tokens] = -torch.inf
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
logits = logits.scatter(1, fill, -torch.inf)
# Sample a time increment proxy as in original implementation
t_next_dist = torch.clamp(
-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(),
min=0,
max=365 * 80,
)
t_next_val, idx_next = t_next_dist.min(1)
idx_next = idx_next.unsqueeze(1)
age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
x = torch.cat((x, idx_next), dim=1)
t = torch.cat((t, age_next), dim=1)
if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
break
pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
final_logits = self(x, t)
x[pad] = 0
t[pad] = mask_time
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
final_logits = torch.stack(
[final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])]
).transpose(0, 1)
return x, t, final_logits
class CombinedLoss(nn.Module): class CombinedLoss(nn.Module):
""" """
Computes a two-part loss: a standard cross-entropy loss for event type Computes a two-part loss: a standard cross-entropy loss for event type
@@ -215,35 +642,23 @@ class CombinedLoss(nn.Module):
Returns: Returns:
A tuple containing the two scalar loss tensors: (loss_ce, loss_survival). A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
""" """
# 1. Create a mask to filter out ignored token IDs from loss calculation.
# An element is True if the corresponding label in x is NOT in the ignored list.
mask = torch.ones_like(x, dtype=torch.bool) mask = torch.ones_like(x, dtype=torch.bool)
for token_id in self.ignored_token_ids: for token_id in self.ignored_token_ids:
mask = mask & (x != token_id) mask = mask & (x != token_id)
# If the mask is all False (all tokens are ignored), return zero for both losses.
if not mask.any(): if not mask.any():
return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device) return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)
# 2. Part 1: Cross-Entropy Loss (loss_ce)
# Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy.
logits_for_ce = logits.permute(0, 2, 1) logits_for_ce = logits.permute(0, 2, 1)
# Calculate per-element loss without reduction.
per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
# Apply the mask and compute the mean of valid elements.
loss_ce = per_element_ce[mask].mean() loss_ce = per_element_ce[mask].mean()
# 3. Part 2: Survival Loss (loss_survival) # Survival loss based on exponential log-likelihood
# Calculate event intensity (lambda) as the sum of exponentiated logits. t_min = 0.1
intensity = torch.sum(torch.exp(logits), dim=2) lse = torch.logsumexp(logits, dim=-1)
lse = -torch.log(torch.exp(-lse) + t_min)
# Calculate per-element survival loss (negative log-likelihood of exponential dist). ldt = -torch.log(t + t_min)
# We add a small epsilon for numerical stability with the log. loss_dt = -(lse - torch.exp(lse - ldt))
per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t) loss_survival = loss_dt[mask].mean()
# Apply the mask and compute the mean of valid elements.
loss_survival = per_element_survival[mask].mean()
return loss_ce, loss_survival return loss_ce, loss_survival

View File

@@ -0,0 +1,160 @@
# Compare AUC distributions between models by ICD-10 chapter (1-year and no-gap)
# Usage:
# Rscript plot_auc_boxplots_by_chapter.R [one_year_csv] [no_gap_csv] [output_dir]
# Defaults:
# one_year_csv = "model_comparison_auc_1year.csv"
# no_gap_csv = "model_comparison_auc_no_gap.csv"
# output_dir = current working directory (".")
suppressPackageStartupMessages({
library(ggplot2)
library(cowplot)
})
args <- commandArgs(trailingOnly = TRUE)
one_year_csv <- if (length(args) >= 1) args[1] else "model_comparison_auc_1year.csv"
no_gap_csv <- if (length(args) >= 2) args[2] else "model_comparison_auc_no_gap.csv"
out_dir <- if (length(args) >= 3) args[3] else "."
orientation <- if (length(args) >= 4) tolower(args[4]) else "vertical" # "horizontal" (flipped) or "vertical"
if (!dir.exists(out_dir)) {
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
}
read_csv_safe <- function(path) {
tryCatch({
read.csv(path, check.names = FALSE)
}, error = function(e) {
stop(sprintf("Failed to read CSV at '%s': %s", path, e$message))
})
}
# Determine a chapter column name robustly
get_chapter_col <- function(df) {
candidates <- c("ICD-10 Chapter (short)", "ICD-10 Chapter", "ICD10_chapter", "chapter", "ICD_chapter")
for (c in candidates) {
if (c %in% names(df)) return(c)
}
return(NA_character_)
}
# Compute a deterministic chapter ordering using the ICD-10 chapter numeral prefix
# e.g., "I. Infectious Diseases", "II. Neoplasms", ..., "XVII. ...", with a fallback for "Death" and unknowns
compute_chapter_levels <- function(chapters) {
ch <- as.character(chapters)
roman_levels <- c(
"I","II","III","IV","V","VI","VII","VIII","IX","X",
"XI","XII","XIII","XIV","XV","XVI","XVII","XVIII","XIX","XX"
)
roman_map <- setNames(seq_along(roman_levels), roman_levels)
# Extract leading Roman numeral before a dot, like "XVI." -> "XVI"
roman <- toupper(gsub("^\\s*([IVXLCDM]+)\\..*$", "\\1", ch))
idx <- rep(NA_integer_, length(ch))
hit <- roman %in% names(roman_map)
idx[hit] <- roman_map[roman[hit]]
# Special-case Death at the end
idx[grepl("^\\s*Death\\b", ch, ignore.case = TRUE)] <- 99L
# Unknowns to the very end
idx[is.na(idx)] <- 100L
# Order chapters by idx, stable within same idx by appearance
o <- order(idx, match(ch, unique(ch)))
unique(ch[o])
}
# Build long-format data.frame with columns: chapter, model, auc
# It will include any of the known model columns that exist in the input df
build_long_df <- function(df) {
model_cols <- c(
auc_120 = "auc_120",
auc_120_l = "auc_120_l",
auc_256 = "auc_256",
auc_256_l = "auc_256_l",
auc_delphi = "auc_delphi"
)
pretty_names <- c(
auc_120 = "GPT-2 120",
auc_120_l = "GPT-2 120_L",
auc_256 = "GPT-2 256",
auc_256_l = "GPT-2 256_L",
auc_delphi = "Delphi"
)
present <- model_cols[names(model_cols) %in% names(df)]
if (length(present) == 0) stop("No known AUC columns found in input data.")
chap_col <- get_chapter_col(df)
if (is.na(chap_col)) {
warning("No chapter column found; using a single 'All' group.")
chapters <- rep("All", nrow(df))
} else {
chapters <- df[[chap_col]]
}
out_list <- list()
for (key in names(model_cols)) {
col <- model_cols[[key]]
if (col %in% names(df)) {
out_list[[length(out_list) + 1]] <- data.frame(
chapter = chapters,
model = pretty_names[[key]],
auc = as.numeric(df[[col]]),
stringsAsFactors = FALSE
)
}
}
long_df <- do.call(rbind, out_list)
# Filter out-of-range or NA
long_df <- long_df[is.finite(long_df$auc) & long_df$auc >= 0 & long_df$auc <= 1, ]
long_df$model <- factor(long_df$model, levels = c("GPT-2 120", "GPT-2 120_L", "GPT-2 256", "GPT-2 256_L", "Delphi"))
return(long_df)
}
# Make the boxplot grouped by chapter
make_boxplot <- function(long_df, title_text, flip = TRUE) {
# Order chapters by their ICD-10 chapter number prefix (Roman numerals)
chap_levels <- compute_chapter_levels(long_df$chapter)
long_df$chapter <- factor(long_df$chapter, levels = chap_levels)
p <- ggplot(long_df, aes(x = chapter, y = auc, fill = model)) +
geom_boxplot(outlier.shape = 19, outlier.size = 0.7, width = 0.75, alpha = 0.95) +
scale_y_continuous(limits = c(0.3, 1.0), breaks = seq(0.3, 1.0, by = 0.1)) +
labs(title = title_text, x = "ICD-10 Chapter", y = "AUC") +
theme_minimal(base_size = 11) +
theme(
plot.title = element_text(hjust = 0.5),
panel.grid.minor = element_blank(),
legend.position = "bottom"
) +
guides(fill = guide_legend(nrow = 1))
if (flip) {
p <- p + coord_flip()
} else {
# For vertical plots, angle x-axis labels for readability
p <- p + theme(axis.text.x = element_text(angle = 45, hjust = 1))
}
p
}
# Build plots for 1-year and no-gap
one_year_df <- read_csv_safe(one_year_csv)
no_gap_df <- read_csv_safe(no_gap_csv)
one_year_long <- build_long_df(one_year_df)
no_gap_long <- build_long_df(no_gap_df)
flip_flag <- ifelse(orientation %in% c("horizontal", "flip", "flipped"), TRUE, FALSE)
p1 <- make_boxplot(one_year_long, "AUC by ICD-10 Chapter (1-year gap)", flip = flip_flag)
p2 <- make_boxplot(no_gap_long, "AUC by ICD-10 Chapter (no gap)", flip = flip_flag)
# Save individual plots
out_1year <- file.path(out_dir, "auc_boxplot_by_chapter_1year.png")
ggsave(out_1year, p1, width = 12, height = 10, dpi = 300, bg = "white")
cat(sprintf("Saved: %s\n", out_1year))
out_nogap <- file.path(out_dir, "auc_boxplot_by_chapter_no_gap.png")
ggsave(out_nogap, p2, width = 12, height = 10, dpi = 300, bg = "white")
cat(sprintf("Saved: %s\n", out_nogap))
# Save a side-by-side grid for quick comparison
grid <- plot_grid(p1, p2, labels = c("A", "B"), ncol = 2, align = "hv")
out_grid <- file.path(out_dir, "auc_boxplot_by_chapter_grid.png")
ggsave(out_grid, grid, width = 18, height = 10, dpi = 250, bg = "white")
cat(sprintf("Saved grid: %s\n", out_grid))

View File

@@ -0,0 +1,125 @@
# Plot AUC comparisons (1-year gap) between models and Delphi using ggplot2
# Usage:
# Rscript plot_model_comparison_1year.R [path_to_csv] [output_dir]
# Defaults:
# path_to_csv = "model_comparison_auc_1year.csv"
# output_dir = current working directory (".")
suppressPackageStartupMessages({
library(ggplot2)
library(cowplot)
})
args <- commandArgs(trailingOnly = TRUE)
csv_path <- if (length(args) >= 1) args[1] else "model_comparison_auc_1year.csv"
out_dir <- if (length(args) >= 2) args[2] else "."
if (!dir.exists(out_dir)) {
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
}
# Read data
# Expect columns including: auc_delphi, auc_256, auc_120, Colour (hex color), name, etc.
df <- tryCatch({
read.csv(csv_path, check.names = FALSE)
}, error = function(e) {
stop(sprintf("Failed to read CSV at '%s': %s", csv_path, e$message))
})
# Helper to compare any two AUC columns (x vs y)
make_xy_plot <- function(data, x_col, y_col, title_text, x_label, y_label) {
ggplot(data, aes(x = .data[[x_col]], y = .data[[y_col]])) +
geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) +
geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) +
scale_fill_identity() +
coord_cartesian(xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) +
coord_fixed(ratio = 1) +
labs(title = title_text, x = x_label, y = y_label) +
theme_minimal(base_size = 10) +
theme(
plot.title = element_text(hjust = 0.5),
panel.grid.minor = element_blank()
)
}
# Helper to compare model AUC vs Delphi AUC (x = auc_delphi)
make_delphi_plot <- function(data, y_col, title_text, y_label) {
ggplot(data, aes(x = auc_delphi, y = .data[[y_col]])) +
geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) +
geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) +
scale_fill_identity() +
coord_cartesian(xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) +
coord_fixed(ratio = 1) +
labs(title = title_text, x = "AUC_Delphi", y = y_label) +
theme_minimal(base_size = 10) +
theme(
plot.title = element_text(hjust = 0.5),
panel.grid.minor = element_blank()
)
}
# Placeholder empty plot if a required column is missing
empty_plot <- function(msg) {
ggplot() + theme_void() + ggtitle(msg) + theme(plot.title = element_text(hjust = 0.5))
}
# Plot: AUC_120 vs AUC_120_L (1 year gap)
if (!all(c("auc_120", "auc_120_l") %in% names(df))) {
warning("Columns 'auc_120' and/or 'auc_120_l' not found in CSV; skipping AUC_120 vs AUC_120_L plot.")
} else {
p120_vs_120l <- make_xy_plot(
data = df,
x_col = "auc_120",
y_col = "auc_120_l",
title_text = "AUC_120 vs AUC_120_L 1 year gap",
x_label = "AUC_120",
y_label = "AUC_120_L"
)
out_120_vs_120l <- file.path(out_dir, "model_comparison_auc_120_vs_120_l_1year.png")
ggsave(filename = out_120_vs_120l, plot = p120_vs_120l, width = 7, height = 4, dpi = 600, bg = "white")
cat(sprintf("Saved: %s\n", out_120_vs_120l))
}
# Plot: AUC_256 vs AUC_256_L (1 year gap)
if (!all(c("auc_256", "auc_256_l") %in% names(df))) {
warning("Columns 'auc_256' and/or 'auc_256_l' not found in CSV; skipping AUC_256 vs AUC_256_L plot.")
} else {
p256_vs_256l <- make_xy_plot(
data = df,
x_col = "auc_256",
y_col = "auc_256_l",
title_text = "AUC_256 vs AUC_256_L 1 year gap",
x_label = "AUC_256",
y_label = "AUC_256_L"
)
out_256_vs_256l <- file.path(out_dir, "model_comparison_auc_256_vs_256_l_1year.png")
ggsave(filename = out_256_vs_256l, plot = p256_vs_256l, width = 7, height = 4, dpi = 600, bg = "white")
cat(sprintf("Saved: %s\n", out_256_vs_256l))
}
# ---- Combined 2x2 grid: (auc_120 vs delphi), (auc_256 vs delphi), (auc_120_l vs delphi), (auc_256_l vs delphi) ----
has_cols <- function(cols) all(cols %in% names(df))
p_120_vs_delphi <- if (has_cols(c("auc_delphi", "auc_120"))) make_delphi_plot(df, "auc_120", "AUC_120 vs Delphi (1 year)", "AUC_120") else empty_plot("Missing auc_120 or auc_delphi")
p_256_vs_delphi <- if (has_cols(c("auc_delphi", "auc_256"))) make_delphi_plot(df, "auc_256", "AUC_256 vs Delphi (1 year)", "AUC_256") else empty_plot("Missing auc_256 or auc_delphi")
p_120l_vs_delphi <- if (has_cols(c("auc_delphi", "auc_120_l"))) make_delphi_plot(df, "auc_120_l", "AUC_120_L vs Delphi (1 year)", "AUC_120_L") else empty_plot("Missing auc_120_l or auc_delphi")
p_256l_vs_delphi <- if (has_cols(c("auc_delphi", "auc_256_l"))) make_delphi_plot(df, "auc_256_l", "AUC_256_L vs Delphi (1 year)", "AUC_256_L") else empty_plot("Missing auc_256_l or auc_delphi")
grid_plot <- plot_grid(
p_120_vs_delphi, p_256_vs_delphi,
p_120l_vs_delphi, p_256l_vs_delphi,
labels = c("A", "B", "C", "D"),
ncol = 2, align = "hv"
)
out_grid <- file.path(out_dir, "model_comparison_auc_vs_delphi_1year_grid.png")
ggsave(filename = out_grid, plot = grid_plot, width = 12, height = 8, dpi = 300, bg = "white")
cat(sprintf("Saved grid: %s\n", out_grid))

View File

@@ -0,0 +1,128 @@
# Plot AUC comparisons (no gap) between models and Delphi using ggplot2
# Usage:
# Rscript plot_model_comparison_no_gap.R [path_to_csv] [output_dir]
# Defaults:
# path_to_csv = "model_comparison_auc_no_gap.csv"
# output_dir = current working directory (".")
suppressPackageStartupMessages({
library(ggplot2)
library(cowplot)
})
args <- commandArgs(trailingOnly = TRUE)
csv_path <- if (length(args) >= 1) args[1] else "model_comparison_auc_no_gap.csv"
out_dir <- if (length(args) >= 2) args[2] else "."
if (!dir.exists(out_dir)) {
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
}
# Read data
# Expect columns including: auc_delphi, auc_256, auc_120, auc_256_l, auc_120_l, Colour (hex color), name, etc.
df <- tryCatch({
read.csv(csv_path, check.names = FALSE)
}, error = function(e) {
stop(sprintf("Failed to read CSV at '%s': %s", csv_path, e$message))
})
# Helper to compare any two AUC columns (x vs y)
make_xy_plot <- function(data, x_col, y_col, title_text, x_label, y_label) {
ggplot(data, aes(x = .data[[x_col]], y = .data[[y_col]])) +
geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) +
geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) +
scale_fill_identity() +
coord_cartesian(xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) +
coord_fixed(ratio = 1) +
labs(title = title_text, x = x_label, y = y_label) +
theme_minimal(base_size = 10) +
theme(
plot.title = element_text(hjust = 0.5),
panel.grid.minor = element_blank()
)
}
# Helper to compare model AUC vs Delphi AUC (x = auc_delphi)
make_delphi_plot <- function(data, y_col, title_text, y_label) {
ggplot(data, aes(x = auc_delphi, y = .data[[y_col]])) +
geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) +
geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) +
scale_fill_identity() +
coord_cartesian(xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) +
coord_fixed(ratio = 1) +
labs(title = title_text, x = "AUC_Delphi", y = y_label) +
theme_minimal(base_size = 10) +
theme(
plot.title = element_text(hjust = 0.5),
panel.grid.minor = element_blank()
)
}
# Placeholder empty plot if a required column is missing
empty_plot <- function(msg) {
ggplot() + theme_void() + ggtitle(msg) + theme(plot.title = element_text(hjust = 0.5))
}
# Individual Delphi comparison plots
has_cols <- function(cols) all(cols %in% names(df))
# AUC_120 vs AUC_Delphi (no gap)
if (has_cols(c("auc_delphi", "auc_120"))) {
p120 <- make_delphi_plot(df, "auc_120", "AUC_120 vs AUC_Delphi (no gap)", "AUC_120")
out_120 <- file.path(out_dir, "fig_auc_120_vs_delphi_no_gap.png")
ggsave(filename = out_120, plot = p120, width = 7, height = 4, dpi = 600, bg = "white")
cat(sprintf("Saved: %s\n", out_120))
} else {
warning("Missing columns for AUC_120 vs Delphi plot.")
}
# AUC_256 vs AUC_Delphi (no gap)
if (has_cols(c("auc_delphi", "auc_256"))) {
p256 <- make_delphi_plot(df, "auc_256", "AUC_256 vs AUC_Delphi (no gap)", "AUC_256")
out_256 <- file.path(out_dir, "model_comparison_auc_256_vs_delphi_no_gap.png")
ggsave(filename = out_256, plot = p256, width = 7, height = 4, dpi = 600, bg = "white")
cat(sprintf("Saved: %s\n", out_256))
} else {
warning("Missing columns for AUC_256 vs Delphi plot.")
}
# AUC_120_L vs AUC_Delphi (no gap)
if (has_cols(c("auc_delphi", "auc_120_l"))) {
p120l <- make_delphi_plot(df, "auc_120_l", "AUC_120_L vs AUC_Delphi (no gap)", "AUC_120_L")
out_120l <- file.path(out_dir, "fig_auc_120_l_vs_delphi_no_gap.png")
ggsave(filename = out_120l, plot = p120l, width = 7, height = 4, dpi = 600, bg = "white")
cat(sprintf("Saved: %s\n", out_120l))
} else {
warning("Missing columns for AUC_120_L vs Delphi plot.")
}
# AUC_256_L vs AUC_Delphi (no gap)
if (has_cols(c("auc_delphi", "auc_256_l"))) {
p256l <- make_delphi_plot(df, "auc_256_l", "AUC_256_L vs AUC_Delphi (no gap)", "AUC_256_L")
out_256l <- file.path(out_dir, "model_comparison_auc_256_l_vs_delphi_no_gap.png")
ggsave(filename = out_256l, plot = p256l, width = 7, height = 4, dpi = 600, bg = "white")
cat(sprintf("Saved: %s\n", out_256l))
} else {
warning("Missing columns for AUC_256_L vs Delphi plot.")
}
# 2x2 grid of Delphi comparisons
p_120_vs_delphi <- if (has_cols(c("auc_delphi", "auc_120"))) make_delphi_plot(df, "auc_120", "AUC_120 vs Delphi (no gap)", "AUC_120") else empty_plot("Missing auc_120 or auc_delphi")
p_256_vs_delphi <- if (has_cols(c("auc_delphi", "auc_256"))) make_delphi_plot(df, "auc_256", "AUC_256 vs Delphi (no gap)", "AUC_256") else empty_plot("Missing auc_256 or auc_delphi")
p_120l_vs_delphi <- if (has_cols(c("auc_delphi", "auc_120_l"))) make_delphi_plot(df, "auc_120_l", "AUC_120_L vs Delphi (no gap)", "AUC_120_L") else empty_plot("Missing auc_120_l or auc_delphi")
p_256l_vs_delphi <- if (has_cols(c("auc_delphi", "auc_256_l"))) make_delphi_plot(df, "auc_256_l", "AUC_256_L vs Delphi (no gap)", "AUC_256_L") else empty_plot("Missing auc_256_l or auc_delphi")
grid_plot <- plot_grid(
p_120_vs_delphi, p_256_vs_delphi,
p_120l_vs_delphi, p_256l_vs_delphi,
labels = c("A", "B", "C", "D"),
ncol = 2, align = "hv"
)
out_grid <- file.path(out_dir, "model_comparison_auc_vs_delphi_no_gap_grid.png")
ggsave(filename = out_grid, plot = grid_plot, width = 12, height = 8, dpi = 300, bg = "white")
cat(sprintf("Saved grid: %s\n", out_grid))

5
requirements.txt Normal file
View File

@@ -0,0 +1,5 @@
torch
numpy
tqdm
matplotlib
joblib

View File

@@ -1,13 +1,15 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.optim import Adam from torch.optim import AdamW
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import numpy as np import numpy as np
import math import math
import tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import json
import argparse
from models import TimeAwareGPT2, CombinedLoss from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv, CombinedLoss
from utils import PatientEventDataset from utils import PatientEventDataset
# --- Configuration --- # --- Configuration ---
@@ -15,33 +17,80 @@ class TrainConfig:
# Data parameters # Data parameters
train_data_path = 'ukb_real_train.bin' train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin' val_data_path = 'ukb_real_val.bin'
block_length = 24 # Sequence length block_length = 48 # Sequence length
# Model parameters # Model parameters
n_embd = 256 n_embd = 120
n_layer = 8 n_layer = 12
n_head = 8 n_head = 12
pdrop = 0.1 pdrop = 0.1
token_pdrop = 0.1 token_pdrop = 0.1
model_name = 'TimeAwareGPT2'
# Training parameters # Training parameters
max_epoch = 200 max_epoch = 200
batch_size = 128 batch_size = 128
lr_initial = 6e-4 lr_initial = 6e-4
lr_final = 6e-5 lr_final = 6e-5
weight_decay = 2e-1
warmup_epochs = 10 warmup_epochs = 10
early_stopping_patience = 5 early_stopping_patience = 10
betas = (0.9, 0.99)
# Loss parameters # Loss parameters
# 0 = padding, 1 = "no event" # 0 = padding, 1 = "no event"
ignored_token_ids = [0, 1] ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs
# System parameters # System parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Main Training Script --- # --- Main Training Script ---
def main(): def main():
parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable', 'TimeAwareGPT2TemporalConv'], default='TimeAwareGPT2', help='Model architecture to train.')
args = parser.parse_args()
config = TrainConfig() config = TrainConfig()
config.n_layer = args.n_layer
config.n_embd = args.n_embd
config.n_head = args.n_head
config.max_epoch = args.max_epoch
config.batch_size = args.batch_size
config.lr_initial = args.lr_initial
config.lr_final = args.lr_final
config.weight_decay = args.weight_decay
config.warmup_epochs = args.warmup_epochs
config.early_stopping_patience = args.early_stopping_patience
config.pdrop = args.pdrop
config.token_pdrop = args.token_pdrop
config.betas = tuple(args.betas)
config.model_name = args.model
model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
model_filename = f"best_model_{model_suffix}.pt"
checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
# --- 0. Save Configuration ---
# Include model class in config filename for clarity/distinction across architectures
config_filename = f"config_{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
with open(config_filename, 'w') as f:
json.dump(config_dict, f, indent=4)
print(f"Configuration saved to {config_filename}")
# --- 1. Data Loading --- # --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
@@ -60,7 +109,13 @@ def main():
# --- 2. Model, Optimizer, and Loss Initialization --- # --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {config.device}...") print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2( model_cls = {
'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
}[config.model_name]
model = model_cls(
vocab_size=vocab_size, vocab_size=vocab_size,
n_embd=config.n_embd, n_embd=config.n_embd,
n_layer=config.n_layer, n_layer=config.n_layer,
@@ -72,7 +127,7 @@ def main():
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.") print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids) loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = Adam(model.parameters(), lr=config.lr_initial) optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)
# --- 3. Training Loop --- # --- 3. Training Loop ---
best_val_loss = float('inf') best_val_loss = float('inf')
@@ -170,7 +225,7 @@ def main():
best_val_loss = total_val_loss best_val_loss = total_val_loss
patience_counter = 0 patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...") print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
torch.save(model.state_dict(), 'best_model_checkpoint.pt') torch.save(model.state_dict(), checkpoint_filename)
else: else:
if epoch >= config.warmup_epochs: if epoch >= config.warmup_epochs:
patience_counter += 1 patience_counter += 1
@@ -183,12 +238,20 @@ def main():
# --- Save Best Model at the End --- # --- Save Best Model at the End ---
if best_val_loss != float('inf'): if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
model.load_state_dict(torch.load('best_model_checkpoint.pt')) model.load_state_dict(torch.load(checkpoint_filename))
print("Saving final best model to best_model.pt") print(f"Saving final best model to {model_filename}")
torch.save(model.state_dict(), 'best_model.pt') torch.save(model.state_dict(), model_filename)
else: else:
print("\nTraining finished. No best model to save as validation loss never improved.") print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Save losses to a txt file ---
losses_filename = f"losses_{model_suffix}.txt"
with open(losses_filename, 'w') as f:
f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# --- Plot and Save Loss Curves --- # --- Plot and Save Loss Curves ---
num_epochs = len(train_losses_total) num_epochs = len(train_losses_total)
epochs = range(1, num_epochs + 1) epochs = range(1, num_epochs + 1)

364
train_ddp.py Normal file
View File

@@ -0,0 +1,364 @@
import os
import json
import math
import argparse
from typing import Tuple
import torch
import torch.distributed as dist
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
from utils import PatientEventDataset
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 48
# Model parameters
n_embd = 120
n_layer = 12
n_head = 12
pdrop = 0.1
token_pdrop = 0.1
model_name = 'TimeAwareGPT2'
# Training parameters
max_epoch = 200
batch_size = 128
lr_initial = 6e-4
lr_final = 6e-5
weight_decay = 2e-1
warmup_epochs = 10
early_stopping_patience = 10
betas = (0.9, 0.99)
# Loss parameters (ignored tokens)
ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
def setup_ddp(backend: str | None = None):
"""Initialize torch.distributed from environment variables set by torchrun."""
if backend is None:
if torch.cuda.is_available() and os.name != 'nt':
backend = 'nccl'
else:
backend = 'gloo'
dist.init_process_group(backend=backend)
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
device = torch.device(f'cuda:{local_rank}')
else:
device = torch.device('cpu')
return rank, world_size, local_rank, device
def cleanup_ddp():
if dist.is_initialized():
dist.destroy_process_group()
def cosine_lr(epoch: int, cfg: TrainConfig) -> float:
if epoch < cfg.warmup_epochs:
return cfg.lr_initial
progress = (epoch - cfg.warmup_epochs) / max(1, (cfg.max_epoch - cfg.warmup_epochs))
return cfg.lr_final + 0.5 * (cfg.lr_initial - cfg.lr_final) * (1 + math.cos(math.pi * progress))
def allreduce_avg(value: torch.Tensor, world_size: int) -> torch.Tensor:
"""All-reduce sum then divide by world_size."""
value = value.clone().to(torch.float64)
dist.all_reduce(value, op=dist.ReduceOp.SUM)
value /= world_size
return value.to(torch.float32)
def main():
parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model (DDP). Use torchrun to launch.')
parser.add_argument('--n_layer', type=int, default=12)
parser.add_argument('--n_embd', type=int, default=120)
parser.add_argument('--n_head', type=int, default=12)
parser.add_argument('--max_epoch', type=int, default=200)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--lr_initial', type=float, default=6e-4)
parser.add_argument('--lr_final', type=float, default=6e-5)
parser.add_argument('--weight_decay', type=float, default=2e-1)
parser.add_argument('--warmup_epochs', type=int, default=10)
parser.add_argument('--early_stopping_patience', type=int, default=10)
parser.add_argument('--pdrop', type=float, default=0.1)
parser.add_argument('--token_pdrop', type=float, default=0.1)
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99])
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2')
parser.add_argument('--backend', type=str, default=None, help='DDP backend (nccl/gloo). Default auto-selects.')
args = parser.parse_args()
rank, world_size, local_rank, device = setup_ddp(args.backend)
# Build config
cfg = TrainConfig()
cfg.n_layer = args.n_layer
cfg.n_embd = args.n_embd
cfg.n_head = args.n_head
cfg.max_epoch = args.max_epoch
cfg.batch_size = args.batch_size
cfg.lr_initial = args.lr_initial
cfg.lr_final = args.lr_final
cfg.weight_decay = args.weight_decay
cfg.warmup_epochs = args.warmup_epochs
cfg.early_stopping_patience = args.early_stopping_patience
cfg.pdrop = args.pdrop
cfg.token_pdrop = args.token_pdrop
cfg.betas = tuple(args.betas)
cfg.model_name = args.model
# Filenames (shared across ranks)
model_suffix = f"{cfg.model_name}_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}"
model_filename = f"best_model_{model_suffix}.pt"
checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
config_filename = f"config_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}.json"
# Save config only on rank 0
if rank == 0:
with open(config_filename, 'w') as f:
json.dump({k: v for k, v in vars(cfg).items() if not k.startswith('__')}, f, indent=4)
print(f"[rank 0] Configuration saved to {config_filename}")
# Load data (all ranks)
if rank == 0:
print(f"Loading data from {cfg.train_data_path} and {cfg.val_data_path}...")
train_data_arr = np.memmap(cfg.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(cfg.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
if rank == 0:
print(f"Inferred vocabulary size: {vocab_size}")
train_dataset = PatientEventDataset(train_data_arr, cfg.block_length)
val_dataset = PatientEventDataset(val_data_arr, cfg.block_length)
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=4, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, sampler=val_sampler, num_workers=4, pin_memory=torch.cuda.is_available())
# Model, loss, optimizer
model_cls = {
'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
}[cfg.model_name]
model = model_cls(
vocab_size=vocab_size,
n_embd=cfg.n_embd,
n_layer=cfg.n_layer,
n_head=cfg.n_head,
pdrop=cfg.pdrop,
token_pdrop=cfg.token_pdrop,
).to(device)
ddp_model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None, output_device=local_rank if torch.cuda.is_available() else None)
loss_fn = CombinedLoss(cfg.ignored_token_ids)
optimizer = AdamW(ddp_model.parameters(), lr=cfg.lr_initial, weight_decay=cfg.weight_decay, betas=cfg.betas)
best_val_loss = float('inf')
patience_counter = 0
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
val_losses_ce, val_losses_surv, val_losses_total = [], [], []
if rank == 0:
print("Starting DDP training...")
for epoch in range(cfg.max_epoch):
# Update sampler epoch for shuffling
train_sampler.set_epoch(epoch)
val_sampler.set_epoch(epoch)
# Set LR
lr = cosine_lr(epoch, cfg)
for pg in optimizer.param_groups:
pg['lr'] = lr
# Train
ddp_model.train()
train_loss_ce_acc = torch.zeros(1, device=device)
train_loss_surv_acc = torch.zeros(1, device=device)
train_steps = 0
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Train]", disable=(rank != 0))
for event_seq, time_seq in pbar:
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = ddp_model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
train_loss_ce_acc += loss_ce.detach()
train_loss_surv_acc += loss_survival.detach()
train_steps += 1
if rank == 0:
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
# Aggregate train losses across ranks
if train_steps == 0:
train_steps = 1
steps_tensor = torch.tensor([train_steps], device=device, dtype=torch.float64)
dist.all_reduce(steps_tensor, op=dist.ReduceOp.SUM)
train_loss_ce_mean = allreduce_avg(train_loss_ce_acc, world_size) / (steps_tensor.item() / world_size)
train_loss_surv_mean = allreduce_avg(train_loss_surv_acc, world_size) / (steps_tensor.item() / world_size)
if rank == 0:
train_losses_ce.append(train_loss_ce_mean.item())
train_losses_surv.append(train_loss_surv_mean.item())
train_losses_total.append(train_loss_ce_mean.item() + train_loss_surv_mean.item())
# Validation
ddp_model.eval()
val_loss_ce_acc = torch.zeros(1, device=device)
val_loss_surv_acc = torch.zeros(1, device=device)
val_steps = 0
with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Val]", disable=(rank != 0))
for event_seq, time_seq in pbar_val:
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = ddp_model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_loss_ce_acc += loss_ce.detach()
val_loss_surv_acc += loss_survival.detach()
val_steps += 1
if val_steps == 0:
val_steps = 1
vsteps_tensor = torch.tensor([val_steps], device=device, dtype=torch.float64)
dist.all_reduce(vsteps_tensor, op=dist.ReduceOp.SUM)
val_loss_ce_mean = allreduce_avg(val_loss_ce_acc, world_size) / (vsteps_tensor.item() / world_size)
val_loss_surv_mean = allreduce_avg(val_loss_surv_acc, world_size) / (vsteps_tensor.item() / world_size)
total_val_loss = (val_loss_ce_mean + val_loss_surv_mean).item()
if rank == 0:
val_losses_ce.append(val_loss_ce_mean.item())
val_losses_surv.append(val_loss_surv_mean.item())
val_losses_total.append(total_val_loss)
print(
f"Epoch {epoch+1} Summary:\n"
f" Train Loss: {train_losses_total[-1]:.4f} (CE: {train_losses_ce[-1]:.4f}, Surv: {train_losses_surv[-1]:.4f})\n"
f" Val Loss: {total_val_loss:.4f} (CE: {val_losses_ce[-1]:.4f}, Surv: {val_losses_surv[-1]:.4f})\n"
f" Learning Rate: {lr:.6f}"
)
# Early stopping on rank 0; broadcast decision
improved = total_val_loss < best_val_loss
if improved:
best_val_loss = total_val_loss
patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
torch.save(ddp_model.module.state_dict(), checkpoint_filename)
else:
if epoch >= cfg.warmup_epochs:
patience_counter += 1
print(f"Validation loss did not improve. Patience: {patience_counter}/{cfg.early_stopping_patience}")
stop_flag = torch.tensor([1 if patience_counter >= cfg.early_stopping_patience else 0], device=device)
else:
stop_flag = torch.zeros(1, device=device)
# Broadcast stop flag and best loss to all ranks
dist.broadcast(stop_flag, src=0)
if stop_flag.item() > 0:
if rank == 0:
print("\nEarly stopping triggered due to no improvement in validation loss.")
break
# Save best model at the end (rank 0)
if rank == 0 and best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
state = torch.load(checkpoint_filename, map_location='cpu')
ddp_model.module.load_state_dict(state)
print(f"Saving final best model to {model_filename}")
torch.save(ddp_model.module.state_dict(), model_filename)
# Save losses to file
losses_filename = f"losses_{model_suffix}.txt"
with open(losses_filename, 'w') as f:
f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# Plot curves
num_epochs = len(train_losses_total)
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(18, 5))
plt.subplot(1, 3, 1)
plt.plot(epochs, train_losses_ce, label='Train CE')
plt.plot(epochs, val_losses_ce, label='Val CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(); plt.grid(True)
plt.subplot(1, 3, 2)
plt.plot(epochs, train_losses_surv, label='Train Survival')
plt.plot(epochs, val_losses_surv, label='Val Survival')
plt.title('Survival Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(); plt.grid(True)
plt.subplot(1, 3, 3)
plt.plot(epochs, train_losses_total, label='Train Total')
plt.plot(epochs, val_losses_total, label='Val Total')
plt.title('Total Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(); plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves.png')
print("\nLoss curves saved to loss_curves.png")
cleanup_ddp()
if __name__ == '__main__':
main()

218
train_iter.py Normal file
View File

@@ -0,0 +1,218 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 48 # Sequence length
# Model parameters
n_embd = 120
n_layer = 12
n_head = 12
pdrop = 0.0
token_pdrop = 0.0
# Training parameters
max_iter = 200000
batch_size = 128
lr_initial = 6e-4
lr_final = 6e-5
weight_decay = 2e-1
warmup_iter = 1000
# Loss parameters
# 0 = padding, 1 = "no event"
ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs
# System parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Main Training Script ---
def main():
config = TrainConfig()
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.pt"
# --- 0. Save Configuration ---
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.json"
config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
with open(config_filename, 'w') as f:
json.dump(config_dict, f, indent=4)
print(f"Configuration saved to {config_filename}")
# --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
# Infer vocab_size from the data (max label + 1)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
print(f"Inferred vocabulary size: {vocab_size}")
train_dataset = PatientEventDataset(train_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
train_iter_loader = iter(itertools.cycle(train_loader))
# --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(config.device)
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))
# --- 3. Training Loop ---
# Lists to store losses
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
print("Starting training...")
pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training")
for iter_num in pbar:
# --- Learning Rate Scheduling ---
if iter_num < config.warmup_iter:
lr = config.lr_initial
else:
progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# --- Training Step ---
model.train()
event_seq, time_seq = next(train_iter_loader)
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
# Prepare inputs and targets
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
# Forward pass
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_losses_ce.append(loss_ce.item())
train_losses_surv.append(loss_survival.item())
train_losses_total.append(loss.item())
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
print("\nTraining finished.")
# --- 4. Final Validation ---
print("Running final validation...")
model.eval()
val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
val_steps = 0
with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc="Final Validation")
for event_seq, time_seq in pbar_val:
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_loss_ce_acc += loss_ce.item()
val_loss_surv_acc += loss_survival.item()
val_steps += 1
pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
avg_val_loss_ce = val_loss_ce_acc / val_steps
avg_val_loss_surv = val_loss_surv_acc / val_steps
total_val_loss = avg_val_loss_ce + avg_val_loss_surv
print(f"Final Validation Summary: \n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")
# --- 5. Save Model ---
print(f"Saving final model to {model_filename}")
torch.save(model.state_dict(), model_filename)
# --- 6. Save and Plot Losses ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.txt"
with open(losses_filename, 'w') as f:
f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# Plot and Save Loss Curves
iterations = range(1, len(train_losses_total) + 1)
plt.figure(figsize=(18, 5))
# Plot CE Loss
plt.subplot(1, 3, 1)
plt.plot(iterations, train_losses_ce, label='Train CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Survival Loss
plt.subplot(1, 3, 2)
plt.plot(iterations, train_losses_surv, label='Train Survival')
plt.title('Survival Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Total Loss
plt.subplot(1, 3, 3)
plt.plot(iterations, train_losses_total, label='Train Total')
plt.title('Total Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves_iter.png')
print("\nLoss curves saved to loss_curves_iter.png")
if __name__ == '__main__':
main()

163
utils.py
View File

@@ -1,7 +1,11 @@
import os
import torch import torch
import numpy as np import numpy as np
import random import random
from collections import defaultdict from collections import defaultdict
import json
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv
class PatientEventDataset(torch.utils.data.Dataset): class PatientEventDataset(torch.utils.data.Dataset):
""" """
@@ -39,17 +43,22 @@ class PatientEventDataset(torch.utils.data.Dataset):
""" """
return len(self.patient_ids) return len(self.patient_ids)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: def __getitem__(self, idx):
""" """
Retrieves, processes, and returns a single patient's event sequence. Retrieves, processes, and returns a single patient's event sequence,
or a list of sequences if a slice is provided.
Args: Args:
idx (int): The index of the patient to retrieve. idx (int or slice): The index or slice of the patient(s) to retrieve.
Returns: Returns:
A tuple of two torch.long tensors: (event_sequence, time_sequence), If idx is an int, a tuple of two torch.long tensors:
both of shape (block_length,). (event_sequence, time_sequence), both of shape (block_length,).
If idx is a slice, a list of such tuples.
""" """
if isinstance(idx, slice):
return [self[i] for i in range(*idx.indices(len(self)))]
# 1. Retrieve and Sort # 1. Retrieve and Sort
patient_id = self.patient_ids[idx] patient_id = self.patient_ids[idx]
records = sorted(self.patient_events[patient_id], key=lambda x: x[0]) records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
@@ -102,3 +111,147 @@ class PatientEventDataset(torch.utils.data.Dataset):
time_tensor = torch.tensor(time_stamps, dtype=torch.long) time_tensor = torch.tensor(time_stamps, dtype=torch.long)
return event_tensor, time_tensor return event_tensor, time_tensor
def load_model(config_path: str, device: str = 'cpu'):
"""
Load a trained model based on the training configuration, inferring the
checkpoint filename from the configuration.
According to train.py, models may be either 'TimeAwareGPT2' or
'TimeAwareGPT2Learnable'. This function:
- Reads the config JSON to get architecture hyperparameters
- Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
- Infers the checkpoint path from the config values
- Infers vocab_size from the checkpoint
- Loads weights and returns the model in eval mode on the requested device
Args:
config_path: Path to the JSON configuration file saved during training.
device: 'cpu' or 'cuda'.
Returns:
torch.nn.Module: Loaded model ready for inference.
"""
# 1) Read config
with open(config_path, 'r') as f:
config_dict = json.load(f)
# Access config entries with attribute-style access while staying tolerant to missing keys
class AttrDict(dict):
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(item)
config = AttrDict(config_dict)
# 2) Decide model class (train.py supports two variants)
model_name = getattr(config, 'model_name', 'TimeAwareGPT2')
model_cls = {
'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
}.get(model_name, TimeAwareGPT2)
# 3) Infer checkpoint filename from config
n_embd = getattr(config, 'n_embd')
n_layer = getattr(config, 'n_layer')
n_head = getattr(config, 'n_head')
# Newer naming (includes model_name) used by train.py when model_name is present
suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
ckpt_with_model = f"best_model_{suffix_with_model}.pt"
# Older naming (without model_name) matches existing repo files
suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
ckpt_legacy = f"best_model_{suffix_legacy}.pt"
# Prefer file that exists on disk
if os.path.exists(ckpt_with_model):
model_path = ckpt_with_model
elif os.path.exists(ckpt_legacy):
model_path = ckpt_legacy
else:
# Fall back to including model_name; if not present in config earlier, user may still have saved this way
model_path = ckpt_with_model
print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}")
# 4) Infer vocab_size from checkpoint
state_preview = torch.load(model_path, map_location='cpu')
if 'wte.weight' in state_preview:
vocab_size = state_preview['wte.weight'].shape[0]
elif 'head.weight' in state_preview:
vocab_size = state_preview['head.weight'].shape[0]
else:
candidate = None
for k, v in state_preview.items():
if isinstance(v, torch.Tensor) and v.ndim == 2:
V = max(v.shape)
if candidate is None or V > candidate:
candidate = V
if candidate is None:
raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
vocab_size = candidate
# 5) Build model from config (tolerant to missing fields)
pdrop = getattr(config, 'pdrop', 0.1)
token_pdrop = getattr(config, 'token_pdrop', 0.1)
model = model_cls(
vocab_size=vocab_size,
n_embd=n_embd,
n_layer=n_layer,
n_head=n_head,
pdrop=pdrop,
token_pdrop=token_pdrop,
).to(device)
# 6) Load weights
state_dict = torch.load(model_path, map_location=device)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
print(f"Warning: Missing keys when loading state_dict: {missing}")
if unexpected:
print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")
model.eval()
try:
num_params_m = model.get_num_params()
print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.")
except Exception:
pass
return model
def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Retrieves a batch of data from a PatientEventDataset and prepares it for model training.
Args:
dataset (PatientEventDataset): The dataset to retrieve data from.
batch_slice (slice): The slice defining the batch of patients to retrieve.
ignore_tokens (list, optional): A list of token IDs to be ignored in the target events.
These tokens will be replaced with -100. Defaults to None.
Returns:
A tuple containing four tensors:
- input_events: (batch_size, sequence_length - 1)
- input_tims: (batch_size, sequence_length - 1)
- target_events: (batch_size, sequence_length - 1)
- target_times: (batch_size, sequence_length - 1)
"""
batch_data = dataset[batch_slice]
input_events = [item[0][:-1] for item in batch_data]
input_tims = [item[1][:-1] for item in batch_data]
target_events = [item[0][1:] for item in batch_data]
target_times = [item[1][1:] for item in batch_data]
input_events = torch.stack(input_events)
input_tims = torch.stack(input_tims)
target_events = torch.stack(target_events)
target_times = torch.stack(target_times)
return input_events, input_tims, target_events, target_times