Compare commits: c7296381b8 ... main (47 commits)
.gitignore (vendored, new file, +17)
@@ -0,0 +1,17 @@
# IDE settings
.idea/

# Python cache
__pycache__/

# Model checkpoints
*.pt

# Large data files
ukb_delphi.txt
ukb_real.bin

# Small data files
fields.txt
icd10_codes_mod.tsv
labels.csv
best_model_n_embd_120_n_layer_12_n_head_12.pt (BIN, new file). Binary file not shown.
best_model_n_embd_256_n_layer_16_n_head_16.pt (BIN, new file). Binary file not shown.
config_n_embd_120_n_layer_12_n_head_12.json (new file, +19)
@@ -0,0 +1,19 @@
{
    "model_name": "TimeAwareGPT2",
    "n_layer": 12,
    "n_embd": 120,
    "n_head": 12,
    "max_epoch": 200,
    "batch_size": 128,
    "lr_initial": 0.0006,
    "lr_final": 6e-05,
    "weight_decay": 0.2,
    "warmup_epochs": 10,
    "early_stopping_patience": 10,
    "pdrop": 0.0,
    "token_pdrop": 0.0,
    "betas": [
        0.9,
        0.99
    ]
}
config_n_embd_256_n_layer_16_n_head_16.json (new file, +19)
@@ -0,0 +1,19 @@
{
    "model_name": "TimeAwareGPT2",
    "n_layer": 16,
    "n_embd": 256,
    "n_head": 16,
    "max_epoch": 200,
    "batch_size": 128,
    "lr_initial": 0.0006,
    "lr_final": 6e-05,
    "weight_decay": 0.2,
    "warmup_epochs": 10,
    "early_stopping_patience": 10,
    "pdrop": 0.0,
    "token_pdrop": 0.0,
    "betas": [
        0.9,
        0.99
    ]
}
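A minimal sketch of how such a config might be consumed, assuming the TimeAwareGPT2 constructor from models.py below; the vocab_size value is illustrative, since it is not stored in the config, and the training keys (max_epoch, lr_*, betas, ...) would be read by the training loop instead:

import json
from models import TimeAwareGPT2

with open("config_n_embd_120_n_layer_12_n_head_12.json") as f:
    cfg = json.load(f)

# vocab_size is hypothetical here; the real value comes from the tokenized dataset.
model = TimeAwareGPT2(
    vocab_size=1270,
    n_embd=cfg["n_embd"],
    n_layer=cfg["n_layer"],
    n_head=cfg["n_head"],
    pdrop=cfg["pdrop"],
    token_pdrop=cfg["token_pdrop"],
)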
delphi_labels_chapters_colours_icd.csv (new file, +1271)
File diff suppressed because it is too large.
evaluate_auc.py (new file, +496)
@@ -0,0 +1,496 @@
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import argparse
from utils import load_model, get_batch, PatientEventDataset
from pathlib import Path
from joblib import Parallel, delayed


def auc(x1, x2):
    n1 = len(x1)
    n2 = len(x2)
    R1 = np.concatenate([x1, x2]).argsort().argsort()[:n1].sum() + n1
    U1 = n1 * n2 + 0.5 * n1 * (n1 + 1) - R1
    if n1 == 0 or n2 == 0:
        return np.nan
    return U1 / n1 / n2
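# Sanity check (hypothetical values): the first argument plays the control role,
# so auc([1, 2], [3, 4]) == 1.0 and auc([3, 4], [1, 2]) == 0.0. The argsort-based
# ranks break ties arbitrarily rather than using midranks; heavily tied scores are
# better served by compute_midrank/fastDeLong below.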
def get_common_diseases(delphi_labels, filter_min_total=100):
    chapters_of_interest = [
        "I. Infectious Diseases",
        "II. Neoplasms",
        "III. Blood & Immune Disorders",
        "IV. Metabolic Diseases",
        "V. Mental Disorders",
        "VI. Nervous System Diseases",
        "VII. Eye Diseases",
        "VIII. Ear Diseases",
        "IX. Circulatory Diseases",
        "X. Respiratory Diseases",
        "XI. Digestive Diseases",
        "XII. Skin Diseases",
        "XIII. Musculoskeletal Diseases",
        "XIV. Genitourinary Diseases",
        "XV. Pregnancy & Childbirth",
        "XVI. Perinatal Conditions",
        "XVII. Congenital Abnormalities",
        "Death",
    ]
    labels_df = delphi_labels[
        delphi_labels["ICD-10 Chapter (short)"].isin(chapters_of_interest)
        & (delphi_labels["count"] > filter_min_total)
    ]
    return labels_df["index"].tolist()


def optimized_bootstrapped_auc_gpu(case, control, n_bootstrap=1000):
    """
    Computes bootstrapped AUC estimates using PyTorch on CUDA.

    Parameters:
        case: 1D tensor of scores for positive cases
        control: 1D tensor of scores for controls
        n_bootstrap: Number of bootstrap replicates

    Returns:
        List of length n_bootstrap containing the AUC of each bootstrap replicate
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. This function requires a GPU.")

    # Convert inputs to CUDA tensors
    if not torch.is_tensor(case):
        case = torch.tensor(case, device="cuda", dtype=torch.float32)
    else:
        case = case.to("cuda", dtype=torch.float32)

    if not torch.is_tensor(control):
        control = torch.tensor(control, device="cuda", dtype=torch.float32)
    else:
        control = control.to("cuda", dtype=torch.float32)

    n_case = case.size(0)
    n_control = control.size(0)
    total = n_case + n_control

    # Generate bootstrap samples
    boot_idx_case = torch.randint(0, n_case, (n_bootstrap, n_case), device="cuda")
    boot_idx_control = torch.randint(0, n_control, (n_bootstrap, n_control), device="cuda")

    boot_case = case[boot_idx_case]
    boot_control = control[boot_idx_control]

    combined = torch.cat([boot_case, boot_control], dim=1)

    # Mask to identify case entries
    mask = torch.zeros((n_bootstrap, total), dtype=torch.bool, device="cuda")
    mask[:, :n_case] = True

    # Compute ranks and AUC
    ranks = combined.argsort(dim=1).argsort(dim=1)
    case_ranks_sum = torch.sum(ranks.float() * mask.float(), dim=1)
    min_case_rank_sum = n_case * (n_case - 1) / 2.0
    U = case_ranks_sum - min_case_rank_sum
    aucs = U / (n_case * n_control)
    return aucs.cpu().tolist()


# AUC comparison adapted from
# https://github.com/Netflix/vmaf/
def compute_midrank(x):
    """Computes midranks.
    Args:
        x - a 1D numpy array
    Returns:
        array of midranks
    """
    J = np.argsort(x)
    Z = x[J]
    N = len(x)
    T = np.zeros(N, dtype=np.float32)
    i = 0
    while i < N:
        j = i
        while j < N and Z[j] == Z[i]:
            j += 1
        T[i:j] = 0.5 * (i + j - 1)
        i = j
    T2 = np.empty(N, dtype=np.float32)
    # Note(kazeevn) +1 is due to Python using 0-based indexing
    # instead of 1-based in the AUC formula in the paper
    T2[J] = T + 1
    return T2


def fastDeLong(predictions_sorted_transposed, label_1_count):
    """
    The fast version of DeLong's method for computing the covariance of
    unadjusted AUC.
    Args:
        predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
            sorted such that the examples with label "1" are first
    Returns:
        (AUC value, DeLong covariance)
    Reference:
        @article{sun2014fast,
            title={Fast Implementation of DeLong's Algorithm for
                Comparing the Areas Under Correlated Receiver Operating Characteristic Curves},
            author={Xu Sun and Weichao Xu},
            journal={IEEE Signal Processing Letters},
            volume={21},
            number={11},
            pages={1389--1393},
            year={2014},
            publisher={IEEE}
        }
    """
    # Short variables are named as they are in the paper
    m = label_1_count
    n = predictions_sorted_transposed.shape[1] - m
    positive_examples = predictions_sorted_transposed[:, :m]
    negative_examples = predictions_sorted_transposed[:, m:]
    k = predictions_sorted_transposed.shape[0]

    tx = np.empty([k, m], dtype=np.float32)
    ty = np.empty([k, n], dtype=np.float32)
    tz = np.empty([k, m + n], dtype=np.float32)
    for r in range(k):
        tx[r, :] = compute_midrank(positive_examples[r, :])
        ty[r, :] = compute_midrank(negative_examples[r, :])
        tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :])
    aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n
    v01 = (tz[:, :m] - tx[:, :]) / n
    v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m
    sx = np.cov(v01)
    sy = np.cov(v10)
    delongcov = sx / m + sy / n
    return aucs, delongcov


def compute_ground_truth_statistics(ground_truth):
    assert np.array_equal(np.unique(ground_truth), [0, 1])
    order = (-ground_truth).argsort()
    label_1_count = int(ground_truth.sum())
    return order, label_1_count


def get_auc_delong_var(healthy_scores, diseased_scores):
    """
    Computes ROC AUC value and variance using DeLong's method.

    Args:
        healthy_scores: Values for class 0 (healthy/controls)
        diseased_scores: Values for class 1 (diseased/cases)
    Returns:
        AUC value and variance
    """
    # Create ground truth labels (1 for diseased, 0 for healthy)
    ground_truth = np.array([1] * len(diseased_scores) + [0] * len(healthy_scores))
    predictions = np.concatenate([diseased_scores, healthy_scores])

    # Compute statistics needed for DeLong method
    order, label_1_count = compute_ground_truth_statistics(ground_truth)
    predictions_sorted_transposed = predictions[np.newaxis, order]

    # Calculate AUC and covariance
    aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count)
    assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers"

    return aucs[0], delongcov
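# Usage sketch (scores invented for illustration): the returned variance supports a
# normal-approximation confidence interval, e.g.
#   auc_val, auc_var = get_auc_delong_var(healthy_scores, diseased_scores)
#   ci95 = (auc_val - 1.96 * np.sqrt(auc_var), auc_val + 1.96 * np.sqrt(auc_var))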
def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
                        precomputed_idx=None, n_bootstrap=1, use_delong=False):
    age_step = age_groups[1] - age_groups[0]

    # Indexes of cases with disease k
    wk = np.where(d[2] == k)

    if len(wk[0]) < 2:
        return None

    # For controls, we need to exclude cases with disease k
    wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None])

    wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]]))  # All cases and controls

    # Take the offset t into account: use only tokens that are at least t before the event
    if precomputed_idx is None:
        pred_idx = (d[1][wall[0]] <= d[3][wall].reshape(-1, 1) - offset).sum(1) - 1
    else:
        pred_idx = precomputed_idx[wall]  # It's actually much faster to precompute this

    valid_indices = pred_idx != -1
    pred_idx = pred_idx[valid_indices]
    wall = (wall[0][valid_indices], wall[1][valid_indices])

    z = d[1][(wall[0], pred_idx)]  # Times of the tokens for prediction
    zk = d[3][wall]  # Target times

    x = p[..., j][(wall[0], pred_idx)]

    p_idx = wall[0]

    out = []

    for i, aa in enumerate(age_groups):
        a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step)

        if not np.any(a):
            continue

        selected_groups = p_idx[a]
        _, unique_indices = np.unique(selected_groups, return_index=True)

        a_filtered = a[a]
        a_filtered[:] = False
        a_filtered[unique_indices] = True
        a[a] = a_filtered

        is_case = np.zeros_like(x, dtype=bool)
        is_case[:len(wk[0])] = True

        control = x[~is_case & a]
        case = x[is_case & a]

        if len(control) == 0 or len(case) == 0:
            continue

        if use_delong:
            auc_value_delong, auc_variance_delong = get_auc_delong_var(control, case)
            auc_delong_dict = {"auc_delong": auc_value_delong, "auc_variance_delong": auc_variance_delong}
        else:
            auc_delong_dict = {}

        if n_bootstrap > 1:
            aucs_bootstrapped = optimized_bootstrapped_auc_gpu(case, control, n_bootstrap)

        for bootstrap_idx in range(n_bootstrap):
            if n_bootstrap > 1:
                y = aucs_bootstrapped[bootstrap_idx]
            elif use_delong:
                y = auc_value_delong
            else:
                y = auc(control, case)  # plain rank-based AUC when DeLong is not requested
            out_item = {
                "token": k,
                "auc": y,
                "age": aa,
                "n_healthy": len(control),
                "n_diseased": len(case),
            }
            if n_bootstrap > 1:
                out_item["bootstrap_idx"] = bootstrap_idx
            out.append(out_item | auc_delong_dict)

    return out


def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap):
    all_aucs = []
    for sex, sex_idx in [("female", 2), ("male", 3)]:
        sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
        p_sex = p100k[sex_mask]
        d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k]
        precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
        for j, k in tqdm(
            list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
        ):
            out = get_calibration_auc(
                j,
                k,
                d100k_sex,
                p_sex,
                age_groups=age_groups,
                offset=offset,
                precomputed_idx=precomputed_idx_subset,
                n_bootstrap=n_bootstrap,
                use_delong=True,
            )
            if out is None:
                continue
            for out_item in out:
                out_item["sex"] = sex
                all_aucs.append(out_item)
    return all_aucs


# Internal function that performs the AUC evaluation pipeline.
def evaluate_auc_pipeline(
    model,
    d100k,
    output_path,
    delphi_labels,
    diseases_of_interest=None,
    filter_min_total=100,
    disease_chunk_size=200,
    age_groups=np.arange(40, 80, 5),
    offset=0.1,
    batch_size=256,
    device="cpu",
    seed=1337,
    n_bootstrap=1,
    meta_info={},
    n_jobs=-1,
):
    """
    Runs the AUC evaluation pipeline.

    Args:
        model (torch.nn.Module): The loaded model set to eval().
        d100k (tuple): Data batch from get_batch.
        output_path (str | None): Directory where CSV files will be written. If None, files are not saved.
        delphi_labels (pd.DataFrame): DataFrame with label info (token names, etc.;
            "delphi_labels_chapters_colours_icd.csv").
        diseases_of_interest (np.ndarray or list, optional): If provided, these disease indices are used.
        filter_min_total (int): Minimum total token count to include a token.
        disease_chunk_size (int): Maximum chunk size for processing diseases.
        age_groups (np.ndarray): Age groups to use in calibration.
        offset (float): Offset used in get_calibration_auc.
        batch_size (int): Batch size for model forwarding.
        device (str): Device identifier.
        seed (int): Random seed for reproducibility.
        n_bootstrap (int): Number of bootstrap samples (1 for no bootstrap).
        meta_info (dict): Extra constant columns added to the unpooled output DataFrame.
        n_jobs (int): Number of parallel jobs to run.

    Returns:
        tuple: (df_auc_unpooled_merged, df_auc_merged) DataFrames.
    """

    assert n_bootstrap > 0, "n_bootstrap must be greater than 0"

    # Set random seeds
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Get common diseases
    if diseases_of_interest is None:
        diseases_of_interest = get_common_diseases(delphi_labels, filter_min_total)

    # Split diseases into chunks for processing
    num_chunks = (len(diseases_of_interest) + disease_chunk_size - 1) // disease_chunk_size
    diseases_chunks = np.array_split(diseases_of_interest, num_chunks)

    # Precompute prediction indices for calibration
    pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1

    p100k = []
    model.to(device)
    with torch.no_grad():
        for dd in tqdm(
            zip(*[torch.split(x, batch_size) for x in d100k]),
            desc="Model inference",
            total=d100k[0].shape[0] // batch_size + 1,
        ):
            dd = [x.to(device) for x in dd]
            outputs = model(dd[0], dd[1]).cpu().detach().numpy()
            p100k.append(outputs.astype("float16"))
    p100k = np.vstack(p100k)

    results = Parallel(n_jobs=n_jobs)(
        delayed(process_chunk)(
            disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk],
            pred_idx_precompute, age_groups, offset, n_bootstrap
        )
        for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks)
    )

    all_aucs = [item for sublist in results for item in sublist]

    df_auc_unpooled = pd.DataFrame(all_aucs)

    for key, value in meta_info.items():
        df_auc_unpooled[key] = value

    delphi_labels_subset = delphi_labels[['index', 'ICD-10 Chapter (short)', 'name', 'color', 'count']]
    df_auc_unpooled_merged = df_auc_unpooled.merge(delphi_labels_subset, left_on="token", right_on="index", how="inner")

    def aggregate_age_brackets_delong(group):
        # For normal distributions, when averaging n of them:
        # the variance of the sum is the sum of the variances, and
        # the variance of the average is that sum divided by n^2.
        n = len(group)
        mean = group['auc_delong'].mean()
        # Since we're taking the average, divide the combined variance by n^2
        var = group['auc_variance_delong'].sum() / (n**2)
        return pd.Series({
            'auc': mean,
            'auc_variance_delong': var,
            'n_samples': n,
            'n_diseased': group['n_diseased'].sum(),
            'n_healthy': group['n_healthy'].sum(),
        })

    print('Using DeLong method to calculate AUC confidence intervals..')

    df_auc = df_auc_unpooled.groupby(["token"]).apply(aggregate_age_brackets_delong).reset_index()
    df_auc_merged = df_auc.merge(delphi_labels, left_on="token", right_on="index", how="inner")

    if output_path is not None:
        Path(output_path).mkdir(exist_ok=True, parents=True)
        df_auc_merged.to_csv(f"{output_path}/df_both.csv", index=False)
        df_auc_unpooled_merged.to_csv(f"{output_path}/df_auc_unpooled.csv", index=False)

    return df_auc_unpooled_merged, df_auc_merged


def main():
    parser = argparse.ArgumentParser(description="Evaluate AUC")
    parser.add_argument("--model_name", type=str, default="n_embd_256_n_layer_16_n_head_16", help="Model checkpoint name")
    parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
    parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
    parser.add_argument("--offset", type=float, default=365.25, help="Offset in days for prediction")
    # Optional filtering/chunking parameters:
    parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens")
    parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
    parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
    args = parser.parse_args()

    model_name = args.model_name
    output_path = f'auc_evaluation_{model_name}'
    dataset_subset_size = args.dataset_subset_size
    offset = args.offset

    # Create output folder if it doesn't exist.
    Path(output_path).mkdir(exist_ok=True, parents=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed = 1337

    # Load model checkpoint and initialize model.
    model = load_model(
        config_path=f'config_{model_name}.json',
        device=device,
    )
    model.eval()
    model = model.to(device)

    # Load validation data.
    val_data_path = 'ukb_real_val.bin'

    val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    block_length = 128
    val_dataset = PatientEventDataset(val_data_arr, block_length)

    if dataset_subset_size == -1:
        dataset_subset_size = len(val_dataset)

    # Get a subset batch for evaluation.
    d100k = get_batch(val_dataset, slice(dataset_subset_size))

    # Load labels (external) to be passed in.
    delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")

    # Call the internal evaluation function.
    df_auc_unpooled, df_auc_merged = evaluate_auc_pipeline(
        model,
        d100k,
        output_path,
        delphi_labels,
        diseases_of_interest=None,
        filter_min_total=args.filter_min_total,
        disease_chunk_size=args.disease_chunk_size,
        device=device,
        seed=seed,
        offset=offset,
        n_bootstrap=args.n_bootstrap,
        n_jobs=args.n_jobs,
    )


if __name__ == "__main__":
    main()
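For programmatic use the pipeline can also be driven without the CLI. A minimal sketch along the lines of main(), with an illustrative subset size and file paths taken from the repository:

import numpy as np
import pandas as pd
from utils import load_model, get_batch, PatientEventDataset
from evaluate_auc import evaluate_auc_pipeline

model = load_model(config_path="config_n_embd_120_n_layer_12_n_head_12.json", device="cpu")
model.eval()

val_arr = np.memmap("ukb_real_val.bin", dtype=np.uint32, mode="r").reshape(-1, 3)
batch = get_batch(PatientEventDataset(val_arr, 128), slice(1000))  # small subset as a smoke test

labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")
df_unpooled, df_pooled = evaluate_auc_pipeline(
    model, batch, output_path=None, delphi_labels=labels,
    offset=365.25, n_bootstrap=1, n_jobs=1, meta_info={"run": "smoke_test"},
)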
evaluate_models.ipynb (new file, +718)
File diff suppressed because one or more lines are too long.
models.py (565)
@@ -1,7 +1,83 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple, Optional
import math

# =============================================================================
# 1. Component Modules (Building Blocks)
# =============================================================================
class CausalConv1d(nn.Module):
    def __init__(self, channels, kernel_size, groups=1):
        super().__init__()
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            channels, channels, kernel_size,
            padding=0, groups=groups
        )

    def forward(self, x):  # x: (B, C, L)
        x = F.pad(x, (self.pad, 0))  # pad only on the left to ensure causality
        x = x.contiguous()
        return self.conv(x)


class DepthwiseSeparableCausalConvBlock(nn.Module):
    def __init__(self, d_model, kernel_size=5, dropout=0.1):
        super().__init__()
        self.dw = CausalConv1d(d_model, kernel_size, groups=d_model)  # depthwise
        self.pw = nn.Conv1d(d_model, d_model, 1)  # pointwise
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # x: (B, L, D)
        y = x.transpose(1, 2).contiguous()  # (B, D, L)
        y = self.dw(y)  # (B, D, L)
        y = self.pw(y.contiguous())  # (B, D, L)
        y = y.transpose(1, 2).contiguous()  # (B, L, D)
        y = self.act(y)
        y = self.dropout(y)
        return self.ln(x + y)  # residual connection + layer norm (LN)


class TimeFeatureProjector(nn.Module):
    """
    Projects scalar time t and its increment Δt into d_model dimensions.
    Combines: linear-scale features + fixed-frequency sin/cos (Fourier time features).
    """
    def __init__(self, d_model, fourier_dim=32, dt_clip=1e6):
        super().__init__()
        self.dt_clip = dt_clip
        self.scalar_proj = nn.Linear(2, d_model, bias=False)  # [t_scaled, dt_scaled] -> D

        # Predefine a set of logarithmically spaced frequencies (tune for your time units if needed)
        k = fourier_dim // 2
        freqs = torch.logspace(-4, 2, steps=k) * 2 * math.pi  # frequency coverage ~1e-4 to 1e2
        self.register_buffer("freqs", freqs, persistent=False)

        self.fourier_proj = nn.Linear(2 * k, d_model, bias=False)  # [sin, cos] -> D
        self.gate = nn.Parameter(torch.zeros(1))  # learnable gate to smoothly introduce Fourier features
        self.ln = nn.LayerNorm(d_model)

    def forward(self, t):  # t: (B, L) continuous timestamps/steps
        # compute increments Δt and stabilize
        dt = t - F.pad(t, (1, 0), value=0.)[:, :-1]
        dt = torch.clamp(dt, min=0.)  # ensure non-negative
        # normalize/stabilize with log compression
        t_scaled = torch.log1p(torch.clamp(torch.abs(t), max=self.dt_clip))
        dt_scaled = torch.log1p(torch.clamp(dt, max=self.dt_clip))

        scal = torch.stack([t_scaled, dt_scaled], dim=-1)  # (B, L, 2)
        scal_feat = self.scalar_proj(scal)  # (B, L, D)

        # Fixed-frequency sin/cos to capture absolute/relative periodicity
        # If t is in steps, use directly; if in seconds, ensure units are consistent (e.g., divide by a time constant)
        wt = t[..., None] * self.freqs  # (B, L, K)
        sincos = torch.cat([torch.sin(wt), torch.cos(wt)], dim=-1)  # (B, L, 2K)
        fourier_feat = self.fourier_proj(sincos)  # (B, L, D)

        # gated fusion + layer norm
        h = scal_feat + torch.tanh(self.gate) * fourier_feat
        return self.ln(h)  # (B, L, D)
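# Shape note for TimeFeatureProjector: for t of shape (B, L), scal_feat and
# fourier_feat are both (B, L, D); since the gate is initialised to zero and
# tanh(0) == 0, training starts from the scalar features alone and blends the
# Fourier features in gradually.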

class Block(nn.Module):
    """ an unassuming Transformer block """
@@ -25,8 +101,10 @@ class Block(nn.Module):
    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
        normed_x = self.ln_1(x)

        # Build an additive attention mask to avoid backend issues with boolean masks on some GPUs.
        # custom_mask: True means allowed, False means masked. We convert to 0 for allowed and -large for masked.
        mask_bool = (~custom_mask).repeat_interleave(self.n_head, dim=0)  # True where we want to mask
        attn_mask = mask_bool.to(dtype=normed_x.dtype) * (-1e9)

        attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
        x = x + self.resid_dropout(attn_output)
@@ -58,14 +136,8 @@ class AgeSinusoidalEncoding(nn.Module):
        self.embedding_dim = embedding_dim

        # Pre-calculate the divisor term for the sinusoidal formula.
        i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.embedding_dim)
        self.register_buffer('divisor', divisor)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
@@ -80,49 +152,204 @@ class AgeSinusoidalEncoding(nn.Module):
            torch.Tensor: The encoded age tensor of shape
            (batch_size, sequence_length, embedding_dim).
        """
        t_years = t / 365.25
        args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1)
        output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device)
        output[:, :, 0::2] = torch.cos(args)
        output[:, :, 1::2] = torch.sin(args)
        return output

class LearnableAgeEncoding(nn.Module):
    """Combines fixed sinusoidal age encodings with a learnable MLP projection."""

    def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        self.base_dim = base_dim
        self.final_dim = final_dim or base_dim

        hidden_dim = hidden_dim or base_dim
        if hidden_dim <= 0:
            raise ValueError("hidden_dim must be a positive integer.")
        if self.final_dim <= 0:
            raise ValueError("final_dim must be a positive integer.")

        self.sinusoidal = AgeSinusoidalEncoding(base_dim)

        mlp_layers = [
            nn.Linear(base_dim, hidden_dim),
            nn.GELU(),
        ]
        if dropout > 0.0:
            mlp_layers.append(nn.Dropout(dropout))
        mlp_layers.append(nn.Linear(hidden_dim, self.final_dim))

        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        sin_embed = self.sinusoidal(t)
        flat_embed = sin_embed.reshape(-1, self.base_dim)
        projected = self.mlp(flat_embed)
        return projected.reshape(*sin_embed.shape[:-1], self.final_dim)

class PiecewiseLinearEncoder(nn.Module):
    """
    Encodes continuous variables using piecewise linear encoding.

    This module defines bins based on standard normal distribution quantiles,
    encodes an input by finding its bin, and calculates its position as a
    linear interpolation between boundaries. The result is projected to the
    final embedding dimension by a shared linear layer.
    """

    def __init__(self, num_bins: int, embedding_dim: int):
        """
        Initializes the PiecewiseLinearEncoder module.

        Args:
            num_bins (int): The number of bins for the encoding.
            embedding_dim (int): The dimensionality of the output embedding (D).
        """
        super().__init__()
        if num_bins <= 0:
            raise ValueError("num_bins must be a positive integer.")
        self.num_bins = num_bins
        self.embedding_dim = embedding_dim

        if num_bins > 1:
            quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1)
            normal_dist = torch.distributions.normal.Normal(0, 1)
            boundaries = normal_dist.icdf(quantiles)
        else:
            boundaries = torch.tensor([])

        boundaries = torch.cat([
            torch.tensor([float('-inf')]),
            boundaries,
            torch.tensor([float('inf')])
        ])
        self.register_buffer('boundaries', boundaries)

        self.linear = nn.Linear(num_bins, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the piecewise linear encoding.

        Args:
            x (torch.Tensor): Input tensor of shape (*, N), where * is any
                number of batch dimensions and N is the number of continuous
                features. Assumed to be pre-scaled.

        Returns:
            torch.Tensor: Encoded tensor of shape (*, N, D).
        """
        original_shape = x.shape
        x = x.reshape(-1, original_shape[-1])

        bin_indices = torch.searchsorted(self.boundaries, x, right=True) - 1
        bin_indices = bin_indices.clamp(0, self.num_bins - 1)

        lower_bounds = self.boundaries[bin_indices]
        upper_bounds = self.boundaries[bin_indices + 1]
        delta = upper_bounds - lower_bounds + 1e-8

        weight_upper = (x - lower_bounds) / delta
        weight_lower = 1.0 - weight_upper

        is_first_bin = (bin_indices == 0)
        is_last_bin = (bin_indices == self.num_bins - 1)

        weight_lower[is_first_bin] = 1.0
        weight_upper[is_first_bin] = 0.0
        weight_lower[is_last_bin] = 0.0
        weight_upper[is_last_bin] = 1.0

        encoded = torch.zeros(*x.shape, self.num_bins, device=x.device, dtype=x.dtype)
        encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1))

        upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1)
        encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1))

        encoded = encoded.view(*original_shape, self.num_bins)
        output = self.linear(encoded)
        return output
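# Worked example (hypothetical): with num_bins=4 the interior boundaries are the
# standard normal quartiles (about -0.674, 0.0, 0.674). An input of 0.3 lands in
# the third bin and is encoded with weight ~0.55 on that bin and ~0.45 on the
# fourth before the shared nn.Linear projects the 4-dim encoding to embedding_dim.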

class TemporalConvEncoder(nn.Module):
    """
    Inputs:
        x: (B, L) - event/token ids
        t: (B, L) - timestamps (real-valued) or step indices
    Output:
        h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_layers: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.1,
        fourier_dim: int = 32,
        pad_id: int = 0
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim)
        self.fuse = nn.Linear(2 * d_model, d_model, bias=False)  # fuse token and time features
        self.ln_in = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        blocks = []
        for _ in range(n_layers):
            blocks.append(DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, t, attention_mask=None):
        """
        attention_mask: (B, L) 1=keep, 0=padding
        """
        tok = self.token_emb(x)  # (B, L, D)
        tim = self.time_proj(t)  # (B, L, D)

        h = torch.cat([tok, tim], dim=-1)  # (B, L, 2D)
        h = self.fuse(h)  # (B, L, D)
        h = self.ln_in(h)
        h = self.dropout(h)

        # Optional: zero out padding positions before convolutions to avoid leakage
        if attention_mask is not None:
            h = h * attention_mask.unsqueeze(-1).type_as(h)

        # Multi-layer causal temporal convolutions (no look-ahead) to form relative position-aware context
        for blk in self.blocks:
            h = blk(h)  # (B, L, D)

        if attention_mask is not None:
            h = h * attention_mask.unsqueeze(-1).type_as(h)

        return h  # (B, L, D), directly usable as attention layer input


# =============================================================================
# 2. Main Model Architectures
# =============================================================================

class TimeAwareGPT2(nn.Module):
    """
    A time-aware GPT-2 model with custom temporal features.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: Optional[list[int]] = None):
        super().__init__()
        self.token_pdrop = token_pdrop
        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []

        self.wte = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoding(n_embd)
        self.drop = nn.Dropout(pdrop)

        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])

        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        self.n_embd = n_embd

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
@@ -138,46 +365,30 @@ class TimeAwareGPT2(nn.Module):
        """
        B, L = event_seq.size()

        token_embeddings = self.wte(event_seq)

        if self.training and self.token_pdrop > 0:
            drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
            token_embeddings[drop_mask] = 0.0

        pos_embeddings = self.age_encoder(time_seq.float())

        x = self.drop(token_embeddings + pos_embeddings)

        # Attention mask: a query i may attend to a key j only if j is strictly in
        # the past (time_mask) and j is not a padding token (padding_mask).
        t_i = time_seq.unsqueeze(-1)
        t_j = time_seq.unsqueeze(1)
        time_mask = (t_j < t_i)
        padding_mask = (event_seq != 0).unsqueeze(1)

        combined_mask = time_mask & padding_mask

        # Rows with no attendable position (e.g., the first event) fall back to self-attention.
        is_row_all_zero = ~combined_mask.any(dim=-1)
        is_not_padding = (event_seq != 0)
        force_self_attention = is_row_all_zero & is_not_padding
        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True

        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)

        x = self.ln_f(x)
        logits = self.head(x)

        return logits

    def get_num_params(self) -> float:
@@ -186,6 +397,222 @@ class TimeAwareGPT2(nn.Module):
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

    @torch.no_grad()
    def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
        """
        Take a conditioning sequence of indices x (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        self.eval()

        if termination_tokens is None:
            termination_tokens = [1269]

        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
        mask_time = -10000

        for _ in range(max_new_tokens):
            logits = self(x, t)
            logits = logits[:, -1, :]

            if self.ignore_tokens:
                logits[:, self.ignore_tokens] = -torch.inf

            if no_repeat:
                fill = x.clone()
                fill[fill == 1] = 0
                logits = logits.scatter(1, fill, -torch.inf)

            t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
            t_next_val, idx_next = t_next_dist.min(1)

            idx_next = idx_next.unsqueeze(1)
            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)

            x = torch.cat((x, idx_next), dim=1)
            t = torch.cat((t, age_next), dim=1)

            if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
                break

        pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)

        final_logits = self(x, t)
        x[pad] = 0
        t[pad] = mask_time

        if no_repeat:
            fill = x.clone()
            fill[fill == 1] = 0
            final_logits = torch.stack([final_logits[:, j].scatter(1, fill[:, :j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0, 1)

        return x, t, final_logits
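# Note on the sampling step in generate(): -torch.rand(...).log() draws Exp(1)
# noise, so -torch.exp(-logits) * log(U) is an exponential waiting time with
# rate exp(logit) for each token; taking the min over tokens is a competing-risks
# draw of the next event type and its waiting time.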

class TimeAwareGPT2Learnable(TimeAwareGPT2):
    """Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.age_encoder = LearnableAgeEncoding(
            base_dim=self.n_embd,
            hidden_dim=2 * self.n_embd,
            final_dim=self.n_embd,
        )

class TimeAwareGPT2TemporalConv(nn.Module):
    """
    A TimeAware GPT-2 variant that uses TemporalConvEncoder to encode
    event and time sequences before the Transformer attention blocks.

    Inputs:
        - event_seq: (B, L) token ids (0 treated as padding)
        - time_seq: (B, L) timestamps or step indices (float)

    Output:
        - logits: (B, L, vocab_size)
    """

    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_layer: int,
        n_head: int,
        pdrop: float,
        token_pdrop: float,
        ignore_tokens: Optional[list[int]] = None,
        *,
        conv_layers: int = 2,
        kernel_size: int = 5,
        conv_dropout: float = 0.1,
        fourier_dim: int = 32,
        pad_id: int = 0,
    ):
        super().__init__()
        self.token_pdrop = token_pdrop
        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
        self.n_embd = n_embd

        # Temporal convolutional encoder to build inputs_embeds
        self.temporal_encoder = TemporalConvEncoder(
            vocab_size=vocab_size,
            d_model=n_embd,
            n_layers=conv_layers,
            kernel_size=kernel_size,
            dropout=conv_dropout,
            fourier_dim=fourier_dim,
            pad_id=pad_id,
        )

        # Transformer stack on top of temporal features
        self.drop = nn.Dropout(pdrop)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
        B, L = event_seq.size()

        # Encoder features as inputs_embeds
        attention_mask = (event_seq != 0)
        x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask)
        x = self.drop(x)

        # Time-aware causal mask as before
        t_i = time_seq.unsqueeze(-1)
        t_j = time_seq.unsqueeze(1)
        time_mask = (t_j < t_i)
        padding_mask = (event_seq != 0).unsqueeze(1)
        combined_mask = time_mask & padding_mask

        # Ensure at least self-attention on non-padding rows
        is_row_all_zero = ~combined_mask.any(dim=-1)
        is_not_padding = (event_seq != 0)
        force_self_attention = is_row_all_zero & is_not_padding
        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True

        for block in self.blocks:
            x = block(x, custom_mask=combined_mask)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    def get_num_params(self) -> float:
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

    @torch.no_grad()
    def generate(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        max_new_tokens: int = 100,
        max_age: float = 85 * 365.25,
        no_repeat: bool = True,
        termination_tokens: Optional[list[int]] = None,
        top_k: Optional[int] = None,
    ):
        """Greedy-like generation with optional no-repeat and termination tokens."""
        self.eval()

        if termination_tokens is None:
            termination_tokens = [1269]

        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
        mask_time = -10000

        for _ in range(max_new_tokens):
            logits = self(x, t)
            logits = logits[:, -1, :]

            if self.ignore_tokens:
                logits[:, self.ignore_tokens] = -torch.inf

            if no_repeat:
                fill = x.clone()
                fill[fill == 1] = 0
                logits = logits.scatter(1, fill, -torch.inf)

            # Sample a time increment proxy as in the original implementation
            t_next_dist = torch.clamp(
                -torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(),
                min=0,
                max=365 * 80,
            )
            t_next_val, idx_next = t_next_dist.min(1)

            idx_next = idx_next.unsqueeze(1)
            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)

            x = torch.cat((x, idx_next), dim=1)
            t = torch.cat((t, age_next), dim=1)

            if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
                break

        pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)

        final_logits = self(x, t)
        x[pad] = 0
        t[pad] = mask_time

        if no_repeat:
            fill = x.clone()
            fill[fill == 1] = 0
            final_logits = torch.stack(
                [final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])]
            ).transpose(0, 1)

        return x, t, final_logits


# =============================================================================
# 3. Loss Function
# =============================================================================

class CombinedLoss(nn.Module):
    """
    Computes a two-part loss: a standard cross-entropy loss for event type
@@ -215,35 +642,23 @@ class CombinedLoss(nn.Module):
        Returns:
            A tuple containing the two scalar loss tensors: (loss_ce, loss_survival).
        """
        mask = torch.ones_like(x, dtype=torch.bool)
        for token_id in self.ignored_token_ids:
            mask = mask & (x != token_id)

        if not mask.any():
            return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)

        logits_for_ce = logits.permute(0, 2, 1)

        per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')

        loss_ce = per_element_ce[mask].mean()

        # Survival loss based on the exponential log-likelihood, with waiting times
        # and rates clipped via t_min for numerical stability.
        t_min = 0.1
        lse = torch.logsumexp(logits, dim=-1)
        lse = -torch.log(torch.exp(-lse) + t_min)
        ldt = -torch.log(t + t_min)
        loss_dt = -(lse - torch.exp(lse - ldt))
        loss_survival = loss_dt[mask].mean()

        return loss_ce, loss_survival
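The rewritten survival term is an exponential negative log-likelihood with a clipped rate: with lam = sum(exp(logits)), the effective rate is lam_eff = 1 / (1/lam + t_min), i.e. the expected waiting time 1/lam is floored at t_min, evaluated at the shifted time t + t_min. A small self-contained check (toy tensors, not repository data):

import torch

logits = torch.randn(2, 3, 5)  # (B, L, vocab); arbitrary toy values
t = torch.rand(2, 3)           # waiting times
t_min = 0.1

# The form used in CombinedLoss above
lse = torch.logsumexp(logits, dim=-1)
lse = -torch.log(torch.exp(-lse) + t_min)
ldt = -torch.log(t + t_min)
loss_dt = -(lse - torch.exp(lse - ldt))

# Closed form with the clipped rate
lam = torch.exp(torch.logsumexp(logits, dim=-1))
lam_eff = 1.0 / (1.0 / lam + t_min)
nll = -(torch.log(lam_eff) - lam_eff * (t + t_min))

assert torch.allclose(loss_dt, nll, atol=1e-5)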
160
plot_auc_boxplots_by_chapter.R
Normal file
160
plot_auc_boxplots_by_chapter.R
Normal file
@@ -0,0 +1,160 @@
# Compare AUC distributions between models by ICD-10 chapter (1-year and no-gap)
# Usage:
#   Rscript plot_auc_boxplots_by_chapter.R [one_year_csv] [no_gap_csv] [output_dir] [orientation]
# Defaults:
#   one_year_csv = "model_comparison_auc_1year.csv"
#   no_gap_csv   = "model_comparison_auc_no_gap.csv"
#   output_dir   = current working directory (".")
#   orientation  = "vertical" ("horizontal" flips the axes)

suppressPackageStartupMessages({
  library(ggplot2)
  library(cowplot)
})

args <- commandArgs(trailingOnly = TRUE)
one_year_csv <- if (length(args) >= 1) args[1] else "model_comparison_auc_1year.csv"
no_gap_csv <- if (length(args) >= 2) args[2] else "model_comparison_auc_no_gap.csv"
out_dir <- if (length(args) >= 3) args[3] else "."
orientation <- if (length(args) >= 4) tolower(args[4]) else "vertical"  # "horizontal" (flipped) or "vertical"

if (!dir.exists(out_dir)) {
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
}

read_csv_safe <- function(path) {
  tryCatch({
    read.csv(path, check.names = FALSE)
  }, error = function(e) {
    stop(sprintf("Failed to read CSV at '%s': %s", path, e$message))
  })
}

# Determine a chapter column name robustly
get_chapter_col <- function(df) {
  candidates <- c("ICD-10 Chapter (short)", "ICD-10 Chapter", "ICD10_chapter", "chapter", "ICD_chapter")
  for (c in candidates) {
    if (c %in% names(df)) return(c)
  }
  return(NA_character_)
}

# Compute a deterministic chapter ordering using the ICD-10 chapter numeral prefix,
# e.g., "I. Infectious Diseases", "II. Neoplasms", ..., with a fallback for "Death" and unknowns
compute_chapter_levels <- function(chapters) {
  ch <- as.character(chapters)
  roman_levels <- c(
    "I","II","III","IV","V","VI","VII","VIII","IX","X",
    "XI","XII","XIII","XIV","XV","XVI","XVII","XVIII","XIX","XX"
  )
  roman_map <- setNames(seq_along(roman_levels), roman_levels)
  # Extract the leading Roman numeral before a dot, like "XVI." -> "XVI"
  roman <- toupper(gsub("^\\s*([IVXLCDM]+)\\..*$", "\\1", ch))
  idx <- rep(NA_integer_, length(ch))
  hit <- roman %in% names(roman_map)
  idx[hit] <- roman_map[roman[hit]]
  # Special-case Death at the end
  idx[grepl("^\\s*Death\\b", ch, ignore.case = TRUE)] <- 99L
  # Unknowns to the very end
  idx[is.na(idx)] <- 100L
  # Order chapters by idx, stable within the same idx by first appearance
  o <- order(idx, match(ch, unique(ch)))
  unique(ch[o])
}

# Build a long-format data.frame with columns: chapter, model, auc.
# It includes any of the known model columns that exist in the input df.
build_long_df <- function(df) {
  model_cols <- c(
    auc_120 = "auc_120",
    auc_120_l = "auc_120_l",
    auc_256 = "auc_256",
    auc_256_l = "auc_256_l",
    auc_delphi = "auc_delphi"
  )
  pretty_names <- c(
    auc_120 = "GPT-2 120",
    auc_120_l = "GPT-2 120_L",
    auc_256 = "GPT-2 256",
    auc_256_l = "GPT-2 256_L",
    auc_delphi = "Delphi"
  )
  present <- model_cols[names(model_cols) %in% names(df)]
  if (length(present) == 0) stop("No known AUC columns found in input data.")
  chap_col <- get_chapter_col(df)
  if (is.na(chap_col)) {
    warning("No chapter column found; using a single 'All' group.")
    chapters <- rep("All", nrow(df))
  } else {
    chapters <- df[[chap_col]]
  }
  out_list <- list()
  for (key in names(model_cols)) {
    col <- model_cols[[key]]
    if (col %in% names(df)) {
      out_list[[length(out_list) + 1]] <- data.frame(
        chapter = chapters,
        model = pretty_names[[key]],
        auc = as.numeric(df[[col]]),
        stringsAsFactors = FALSE
      )
    }
  }
  long_df <- do.call(rbind, out_list)
  # Filter out-of-range or NA values
  long_df <- long_df[is.finite(long_df$auc) & long_df$auc >= 0 & long_df$auc <= 1, ]
  long_df$model <- factor(long_df$model, levels = c("GPT-2 120", "GPT-2 120_L", "GPT-2 256", "GPT-2 256_L", "Delphi"))
  return(long_df)
}

# Make the boxplot grouped by chapter
make_boxplot <- function(long_df, title_text, flip = TRUE) {
  # Order chapters by their ICD-10 chapter number prefix (Roman numerals)
  chap_levels <- compute_chapter_levels(long_df$chapter)
  long_df$chapter <- factor(long_df$chapter, levels = chap_levels)

  p <- ggplot(long_df, aes(x = chapter, y = auc, fill = model)) +
    geom_boxplot(outlier.shape = 19, outlier.size = 0.7, width = 0.75, alpha = 0.95) +
    scale_y_continuous(limits = c(0.3, 1.0), breaks = seq(0.3, 1.0, by = 0.1)) +
    labs(title = title_text, x = "ICD-10 Chapter", y = "AUC") +
    theme_minimal(base_size = 11) +
    theme(
      plot.title = element_text(hjust = 0.5),
      panel.grid.minor = element_blank(),
      legend.position = "bottom"
    ) +
    guides(fill = guide_legend(nrow = 1))
  if (flip) {
    p <- p + coord_flip()
  } else {
    # For vertical plots, angle x-axis labels for readability
    p <- p + theme(axis.text.x = element_text(angle = 45, hjust = 1))
  }
  p
}

# Build plots for 1-year and no-gap
one_year_df <- read_csv_safe(one_year_csv)
no_gap_df <- read_csv_safe(no_gap_csv)

one_year_long <- build_long_df(one_year_df)
no_gap_long <- build_long_df(no_gap_df)

flip_flag <- ifelse(orientation %in% c("horizontal", "flip", "flipped"), TRUE, FALSE)

p1 <- make_boxplot(one_year_long, "AUC by ICD-10 Chapter (1-year gap)", flip = flip_flag)
p2 <- make_boxplot(no_gap_long, "AUC by ICD-10 Chapter (no gap)", flip = flip_flag)

# Save individual plots
out_1year <- file.path(out_dir, "auc_boxplot_by_chapter_1year.png")
ggsave(out_1year, p1, width = 12, height = 10, dpi = 300, bg = "white")
cat(sprintf("Saved: %s\n", out_1year))

out_nogap <- file.path(out_dir, "auc_boxplot_by_chapter_no_gap.png")
ggsave(out_nogap, p2, width = 12, height = 10, dpi = 300, bg = "white")
cat(sprintf("Saved: %s\n", out_nogap))

# Save a side-by-side grid for quick comparison
grid <- plot_grid(p1, p2, labels = c("A", "B"), ncol = 2, align = "hv")
out_grid <- file.path(out_dir, "auc_boxplot_by_chapter_grid.png")
ggsave(out_grid, grid, width = 18, height = 10, dpi = 250, bg = "white")
cat(sprintf("Saved grid: %s\n", out_grid))
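A note on the fourth argument: with twenty-odd chapter labels the vertical layout can crowd the x-axis even with angled text, so passing horizontal (or flip/flipped) routes both panels through coord_flip(). For example: Rscript plot_auc_boxplots_by_chapter.R model_comparison_auc_1year.csv model_comparison_auc_no_gap.csv figures horizontal, where the figures output directory is illustrative.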
125
plot_model_comparison_1year.R
Normal file
@@ -0,0 +1,125 @@
# Plot AUC comparisons (1-year gap) between models and Delphi using ggplot2
# Usage:
#   Rscript plot_model_comparison_1year.R [path_to_csv] [output_dir]
# Defaults:
#   path_to_csv = "model_comparison_auc_1year.csv"
#   output_dir  = current working directory (".")

suppressPackageStartupMessages({
  library(ggplot2)
  library(cowplot)
})

args <- commandArgs(trailingOnly = TRUE)
csv_path <- if (length(args) >= 1) args[1] else "model_comparison_auc_1year.csv"
out_dir <- if (length(args) >= 2) args[2] else "."

if (!dir.exists(out_dir)) {
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
}

# Read data
# Expect columns including: auc_delphi, auc_256, auc_120, Colour (hex color), name, etc.
df <- tryCatch({
  read.csv(csv_path, check.names = FALSE)
}, error = function(e) {
  stop(sprintf("Failed to read CSV at '%s': %s", csv_path, e$message))
})

# Helper to compare any two AUC columns (x vs y).
# The axis limits live in coord_fixed(): adding a second coord_*() call would
# silently replace the first, so coord_cartesian() is not used alongside it.
make_xy_plot <- function(data, x_col, y_col, title_text, x_label, y_label) {
  ggplot(data, aes(x = .data[[x_col]], y = .data[[y_col]])) +
    geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) +
    geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
    geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
    geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) +
    scale_fill_identity() +
    coord_fixed(ratio = 1, xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) +
    labs(title = title_text, x = x_label, y = y_label) +
    theme_minimal(base_size = 10) +
    theme(
      plot.title = element_text(hjust = 0.5),
      panel.grid.minor = element_blank()
    )
}

# Helper to compare model AUC vs Delphi AUC (x = auc_delphi)
make_delphi_plot <- function(data, y_col, title_text, y_label) {
  ggplot(data, aes(x = auc_delphi, y = .data[[y_col]])) +
    geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) +
    geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
    geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
    geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) +
    scale_fill_identity() +
    coord_fixed(ratio = 1, xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) +
    labs(title = title_text, x = "AUC_Delphi", y = y_label) +
    theme_minimal(base_size = 10) +
    theme(
      plot.title = element_text(hjust = 0.5),
      panel.grid.minor = element_blank()
    )
}

# Placeholder empty plot if a required column is missing
empty_plot <- function(msg) {
  ggplot() + theme_void() + ggtitle(msg) + theme(plot.title = element_text(hjust = 0.5))
}

# Plot: AUC_120 vs AUC_120_L (1-year gap)
if (!all(c("auc_120", "auc_120_l") %in% names(df))) {
  warning("Columns 'auc_120' and/or 'auc_120_l' not found in CSV; skipping AUC_120 vs AUC_120_L plot.")
} else {
  p120_vs_120l <- make_xy_plot(
    data = df,
    x_col = "auc_120",
    y_col = "auc_120_l",
    title_text = "AUC_120 vs AUC_120_L, 1-year gap",
    x_label = "AUC_120",
    y_label = "AUC_120_L"
  )
  out_120_vs_120l <- file.path(out_dir, "model_comparison_auc_120_vs_120_l_1year.png")
  ggsave(filename = out_120_vs_120l, plot = p120_vs_120l, width = 7, height = 4, dpi = 600, bg = "white")
  cat(sprintf("Saved: %s\n", out_120_vs_120l))
}

# Plot: AUC_256 vs AUC_256_L (1-year gap)
if (!all(c("auc_256", "auc_256_l") %in% names(df))) {
  warning("Columns 'auc_256' and/or 'auc_256_l' not found in CSV; skipping AUC_256 vs AUC_256_L plot.")
} else {
  p256_vs_256l <- make_xy_plot(
    data = df,
    x_col = "auc_256",
    y_col = "auc_256_l",
    title_text = "AUC_256 vs AUC_256_L, 1-year gap",
    x_label = "AUC_256",
    y_label = "AUC_256_L"
  )
  out_256_vs_256l <- file.path(out_dir, "model_comparison_auc_256_vs_256_l_1year.png")
  ggsave(filename = out_256_vs_256l, plot = p256_vs_256l, width = 7, height = 4, dpi = 600, bg = "white")
  cat(sprintf("Saved: %s\n", out_256_vs_256l))
}

# ---- Combined 2x2 grid: (auc_120 vs delphi), (auc_256 vs delphi), (auc_120_l vs delphi), (auc_256_l vs delphi) ----

has_cols <- function(cols) all(cols %in% names(df))

p_120_vs_delphi <- if (has_cols(c("auc_delphi", "auc_120"))) make_delphi_plot(df, "auc_120", "AUC_120 vs Delphi (1 year)", "AUC_120") else empty_plot("Missing auc_120 or auc_delphi")
p_256_vs_delphi <- if (has_cols(c("auc_delphi", "auc_256"))) make_delphi_plot(df, "auc_256", "AUC_256 vs Delphi (1 year)", "AUC_256") else empty_plot("Missing auc_256 or auc_delphi")
p_120l_vs_delphi <- if (has_cols(c("auc_delphi", "auc_120_l"))) make_delphi_plot(df, "auc_120_l", "AUC_120_L vs Delphi (1 year)", "AUC_120_L") else empty_plot("Missing auc_120_l or auc_delphi")
p_256l_vs_delphi <- if (has_cols(c("auc_delphi", "auc_256_l"))) make_delphi_plot(df, "auc_256_l", "AUC_256_L vs Delphi (1 year)", "AUC_256_L") else empty_plot("Missing auc_256_l or auc_delphi")

grid_plot <- plot_grid(
  p_120_vs_delphi, p_256_vs_delphi,
  p_120l_vs_delphi, p_256l_vs_delphi,
  labels = c("A", "B", "C", "D"),
  ncol = 2, align = "hv"
)

out_grid <- file.path(out_dir, "model_comparison_auc_vs_delphi_1year_grid.png")
ggsave(filename = out_grid, plot = grid_plot, width = 12, height = 8, dpi = 300, bg = "white")
cat(sprintf("Saved grid: %s\n", out_grid))
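Run as, e.g., Rscript plot_model_comparison_1year.R model_comparison_auc_1year.csv figures (output directory illustrative). Missing AUC columns trigger a warning or an empty placeholder panel rather than an abort, so the 2x2 Delphi grid is still produced from partial result files.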
128
plot_model_comparison_no_gap.R
Normal file
@@ -0,0 +1,128 @@
# Plot AUC comparisons (no gap) between models and Delphi using ggplot2
# Usage:
#   Rscript plot_model_comparison_no_gap.R [path_to_csv] [output_dir]
# Defaults:
#   path_to_csv = "model_comparison_auc_no_gap.csv"
#   output_dir  = current working directory (".")

suppressPackageStartupMessages({
  library(ggplot2)
  library(cowplot)
})

args <- commandArgs(trailingOnly = TRUE)
csv_path <- if (length(args) >= 1) args[1] else "model_comparison_auc_no_gap.csv"
out_dir <- if (length(args) >= 2) args[2] else "."

if (!dir.exists(out_dir)) {
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
}

# Read data
# Expect columns including: auc_delphi, auc_256, auc_120, auc_256_l, auc_120_l, Colour (hex color), name, etc.
df <- tryCatch({
  read.csv(csv_path, check.names = FALSE)
}, error = function(e) {
  stop(sprintf("Failed to read CSV at '%s': %s", csv_path, e$message))
})

# Helper to compare any two AUC columns (x vs y).
# As in the 1-year script, the axis limits live in coord_fixed() so a second
# coord_*() call does not silently discard them.
make_xy_plot <- function(data, x_col, y_col, title_text, x_label, y_label) {
  ggplot(data, aes(x = .data[[x_col]], y = .data[[y_col]])) +
    geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) +
    geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
    geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
    geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) +
    scale_fill_identity() +
    coord_fixed(ratio = 1, xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) +
    labs(title = title_text, x = x_label, y = y_label) +
    theme_minimal(base_size = 10) +
    theme(
      plot.title = element_text(hjust = 0.5),
      panel.grid.minor = element_blank()
    )
}

# Helper to compare model AUC vs Delphi AUC (x = auc_delphi)
make_delphi_plot <- function(data, y_col, title_text, y_label) {
  ggplot(data, aes(x = auc_delphi, y = .data[[y_col]])) +
    geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) +
    geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
    geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
    geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) +
    scale_fill_identity() +
    coord_fixed(ratio = 1, xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) +
    labs(title = title_text, x = "AUC_Delphi", y = y_label) +
    theme_minimal(base_size = 10) +
    theme(
      plot.title = element_text(hjust = 0.5),
      panel.grid.minor = element_blank()
    )
}

# Placeholder empty plot if a required column is missing
empty_plot <- function(msg) {
  ggplot() + theme_void() + ggtitle(msg) + theme(plot.title = element_text(hjust = 0.5))
}

# Individual Delphi comparison plots
has_cols <- function(cols) all(cols %in% names(df))

# AUC_120 vs AUC_Delphi (no gap)
if (has_cols(c("auc_delphi", "auc_120"))) {
  p120 <- make_delphi_plot(df, "auc_120", "AUC_120 vs AUC_Delphi (no gap)", "AUC_120")
  out_120 <- file.path(out_dir, "fig_auc_120_vs_delphi_no_gap.png")
  ggsave(filename = out_120, plot = p120, width = 7, height = 4, dpi = 600, bg = "white")
  cat(sprintf("Saved: %s\n", out_120))
} else {
  warning("Missing columns for AUC_120 vs Delphi plot.")
}

# AUC_256 vs AUC_Delphi (no gap)
if (has_cols(c("auc_delphi", "auc_256"))) {
  p256 <- make_delphi_plot(df, "auc_256", "AUC_256 vs AUC_Delphi (no gap)", "AUC_256")
  out_256 <- file.path(out_dir, "model_comparison_auc_256_vs_delphi_no_gap.png")
  ggsave(filename = out_256, plot = p256, width = 7, height = 4, dpi = 600, bg = "white")
  cat(sprintf("Saved: %s\n", out_256))
} else {
  warning("Missing columns for AUC_256 vs Delphi plot.")
}

# AUC_120_L vs AUC_Delphi (no gap)
if (has_cols(c("auc_delphi", "auc_120_l"))) {
  p120l <- make_delphi_plot(df, "auc_120_l", "AUC_120_L vs AUC_Delphi (no gap)", "AUC_120_L")
  out_120l <- file.path(out_dir, "fig_auc_120_l_vs_delphi_no_gap.png")
  ggsave(filename = out_120l, plot = p120l, width = 7, height = 4, dpi = 600, bg = "white")
  cat(sprintf("Saved: %s\n", out_120l))
} else {
  warning("Missing columns for AUC_120_L vs Delphi plot.")
}

# AUC_256_L vs AUC_Delphi (no gap)
if (has_cols(c("auc_delphi", "auc_256_l"))) {
  p256l <- make_delphi_plot(df, "auc_256_l", "AUC_256_L vs AUC_Delphi (no gap)", "AUC_256_L")
  out_256l <- file.path(out_dir, "model_comparison_auc_256_l_vs_delphi_no_gap.png")
  ggsave(filename = out_256l, plot = p256l, width = 7, height = 4, dpi = 600, bg = "white")
  cat(sprintf("Saved: %s\n", out_256l))
} else {
  warning("Missing columns for AUC_256_L vs Delphi plot.")
}

# 2x2 grid of Delphi comparisons
p_120_vs_delphi <- if (has_cols(c("auc_delphi", "auc_120"))) make_delphi_plot(df, "auc_120", "AUC_120 vs Delphi (no gap)", "AUC_120") else empty_plot("Missing auc_120 or auc_delphi")
p_256_vs_delphi <- if (has_cols(c("auc_delphi", "auc_256"))) make_delphi_plot(df, "auc_256", "AUC_256 vs Delphi (no gap)", "AUC_256") else empty_plot("Missing auc_256 or auc_delphi")
p_120l_vs_delphi <- if (has_cols(c("auc_delphi", "auc_120_l"))) make_delphi_plot(df, "auc_120_l", "AUC_120_L vs Delphi (no gap)", "AUC_120_L") else empty_plot("Missing auc_120_l or auc_delphi")
p_256l_vs_delphi <- if (has_cols(c("auc_delphi", "auc_256_l"))) make_delphi_plot(df, "auc_256_l", "AUC_256_L vs Delphi (no gap)", "AUC_256_L") else empty_plot("Missing auc_256_l or auc_delphi")

grid_plot <- plot_grid(
  p_120_vs_delphi, p_256_vs_delphi,
  p_120l_vs_delphi, p_256l_vs_delphi,
  labels = c("A", "B", "C", "D"),
  ncol = 2, align = "hv"
)

out_grid <- file.path(out_dir, "model_comparison_auc_vs_delphi_no_gap_grid.png")
ggsave(filename = out_grid, plot = grid_plot, width = 12, height = 8, dpi = 300, bg = "white")
cat(sprintf("Saved grid: %s\n", out_grid))
5
requirements.txt
Normal file
@@ -0,0 +1,5 @@
torch
numpy
tqdm
matplotlib
joblib
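Install the Python dependencies with pip install -r requirements.txt. The R plotting scripts additionally need ggplot2 and cowplot, installed separately from CRAN (e.g. install.packages(c("ggplot2", "cowplot"))).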
91
train.py
@@ -1,13 +1,15 @@
 import torch
 import torch.nn as nn
-from torch.optim import Adam
+from torch.optim import AdamW
 from torch.utils.data import DataLoader
 import numpy as np
 import math
 import tqdm
 import matplotlib.pyplot as plt
+import json
+import argparse

-from models import TimeAwareGPT2, CombinedLoss
+from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv, CombinedLoss
 from utils import PatientEventDataset

 # --- Configuration ---
@@ -15,33 +17,80 @@ class TrainConfig:
     # Data parameters
     train_data_path = 'ukb_real_train.bin'
     val_data_path = 'ukb_real_val.bin'
-    block_length = 24 # Sequence length
+    block_length = 48 # Sequence length

     # Model parameters
-    n_embd = 256
-    n_layer = 8
-    n_head = 8
+    n_embd = 120
+    n_layer = 12
+    n_head = 12
     pdrop = 0.1
     token_pdrop = 0.1
+    model_name = 'TimeAwareGPT2'

     # Training parameters
     max_epoch = 200
     batch_size = 128
     lr_initial = 6e-4
     lr_final = 6e-5
+    weight_decay = 2e-1
     warmup_epochs = 10
-    early_stopping_patience = 5
+    early_stopping_patience = 10
+    betas = (0.9, 0.99)

     # Loss parameters
     # 0 = padding, 1 = "no event"
-    ignored_token_ids = [0, 1]
+    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs

     # System parameters
     device = 'cuda' if torch.cuda.is_available() else 'cpu'

 # --- Main Training Script ---
 def main():
+    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
+    parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
+    parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
+    parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
+    parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
+    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
+    parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
+    parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
+    parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
+    parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
+    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
+    parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
+    parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
+    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
+    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable', 'TimeAwareGPT2TemporalConv'], default='TimeAwareGPT2', help='Model architecture to train.')
+
+    args = parser.parse_args()
+
     config = TrainConfig()
+    config.n_layer = args.n_layer
+    config.n_embd = args.n_embd
+    config.n_head = args.n_head
+    config.max_epoch = args.max_epoch
+    config.batch_size = args.batch_size
+    config.lr_initial = args.lr_initial
+    config.lr_final = args.lr_final
+    config.weight_decay = args.weight_decay
+    config.warmup_epochs = args.warmup_epochs
+    config.early_stopping_patience = args.early_stopping_patience
+    config.pdrop = args.pdrop
+    config.token_pdrop = args.token_pdrop
+    config.betas = tuple(args.betas)
+    config.model_name = args.model
+
+    model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
+    model_filename = f"best_model_{model_suffix}.pt"
+    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
+
+    # --- 0. Save Configuration ---
+    # Include model class in config filename for clarity/distinction across architectures
+    config_filename = f"config_{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
+    config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
+    with open(config_filename, 'w') as f:
+        json.dump(config_dict, f, indent=4)
+    print(f"Configuration saved to {config_filename}")
+
     # --- 1. Data Loading ---
     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
@@ -60,7 +109,13 @@ def main():

     # --- 2. Model, Optimizer, and Loss Initialization ---
     print(f"Initializing model on {config.device}...")
-    model = TimeAwareGPT2(
+    model_cls = {
+        'TimeAwareGPT2': TimeAwareGPT2,
+        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
+        'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
+    }[config.model_name]
+
+    model = model_cls(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
         n_layer=config.n_layer,
@@ -72,7 +127,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = Adam(model.parameters(), lr=config.lr_initial)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)

     # --- 3. Training Loop ---
     best_val_loss = float('inf')
@@ -170,7 +225,7 @@ def main():
             best_val_loss = total_val_loss
             patience_counter = 0
             print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
-            torch.save(model.state_dict(), 'best_model_checkpoint.pt')
+            torch.save(model.state_dict(), checkpoint_filename)
         else:
             if epoch >= config.warmup_epochs:
                 patience_counter += 1
@@ -183,12 +238,20 @@ def main():
     # --- Save Best Model at the End ---
     if best_val_loss != float('inf'):
         print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
-        model.load_state_dict(torch.load('best_model_checkpoint.pt'))
-        print("Saving final best model to best_model.pt")
-        torch.save(model.state_dict(), 'best_model.pt')
+        model.load_state_dict(torch.load(checkpoint_filename))
+        print(f"Saving final best model to {model_filename}")
+        torch.save(model.state_dict(), model_filename)
     else:
         print("\nTraining finished. No best model to save as validation loss never improved.")

+    # --- Save losses to a txt file ---
+    losses_filename = f"losses_{model_suffix}.txt"
+    with open(losses_filename, 'w') as f:
+        f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
+        for i in range(len(train_losses_total)):
+            f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
+    print(f"\nLosses saved to {losses_filename}")
+
     # --- Plot and Save Loss Curves ---
     num_epochs = len(train_losses_total)
     epochs = range(1, num_epochs + 1)
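With the argparse wiring above, one command trains any of the three architectures, e.g. python train.py --model TimeAwareGPT2TemporalConv --n_embd 120 --n_layer 12 --n_head 12 (values illustrative). The chosen hyperparameters flow into the config JSON, the checkpoint names, and the losses file through the shared model_suffix, so different runs no longer overwrite each other's best_model.pt.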
364
train_ddp.py
Normal file
@@ -0,0 +1,364 @@
import os
import json
import math
import argparse
from typing import Tuple

import torch
import torch.distributed as dist
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

import numpy as np
import tqdm
import matplotlib.pyplot as plt

from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
from utils import PatientEventDataset


class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1
    model_name = 'TimeAwareGPT2'

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_epochs = 10
    early_stopping_patience = 10
    betas = (0.9, 0.99)

    # Loss parameters (ignored tokens)
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]


def setup_ddp(backend: str | None = None):
    """Initialize torch.distributed from environment variables set by torchrun."""
    if backend is None:
        if torch.cuda.is_available() and os.name != 'nt':
            backend = 'nccl'
        else:
            backend = 'gloo'
    dist.init_process_group(backend=backend)

    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f'cuda:{local_rank}')
    else:
        device = torch.device('cpu')

    return rank, world_size, local_rank, device


def cleanup_ddp():
    if dist.is_initialized():
        dist.destroy_process_group()


def cosine_lr(epoch: int, cfg: TrainConfig) -> float:
    if epoch < cfg.warmup_epochs:
        return cfg.lr_initial
    progress = (epoch - cfg.warmup_epochs) / max(1, (cfg.max_epoch - cfg.warmup_epochs))
    return cfg.lr_final + 0.5 * (cfg.lr_initial - cfg.lr_final) * (1 + math.cos(math.pi * progress))


def allreduce_avg(value: torch.Tensor, world_size: int) -> torch.Tensor:
    """All-reduce sum then divide by world_size."""
    value = value.clone().to(torch.float64)
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= world_size
    return value.to(torch.float32)


def main():
    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model (DDP). Use torchrun to launch.')
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--n_embd', type=int, default=120)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--max_epoch', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr_initial', type=float, default=6e-4)
    parser.add_argument('--lr_final', type=float, default=6e-5)
    parser.add_argument('--weight_decay', type=float, default=2e-1)
    parser.add_argument('--warmup_epochs', type=int, default=10)
    parser.add_argument('--early_stopping_patience', type=int, default=10)
    parser.add_argument('--pdrop', type=float, default=0.1)
    parser.add_argument('--token_pdrop', type=float, default=0.1)
    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99])
    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2')
    parser.add_argument('--backend', type=str, default=None, help='DDP backend (nccl/gloo). Default auto-selects.')

    args = parser.parse_args()

    rank, world_size, local_rank, device = setup_ddp(args.backend)

    # Build config
    cfg = TrainConfig()
    cfg.n_layer = args.n_layer
    cfg.n_embd = args.n_embd
    cfg.n_head = args.n_head
    cfg.max_epoch = args.max_epoch
    cfg.batch_size = args.batch_size
    cfg.lr_initial = args.lr_initial
    cfg.lr_final = args.lr_final
    cfg.weight_decay = args.weight_decay
    cfg.warmup_epochs = args.warmup_epochs
    cfg.early_stopping_patience = args.early_stopping_patience
    cfg.pdrop = args.pdrop
    cfg.token_pdrop = args.token_pdrop
    cfg.betas = tuple(args.betas)
    cfg.model_name = args.model

    # Filenames (shared across ranks)
    model_suffix = f"{cfg.model_name}_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}"
    model_filename = f"best_model_{model_suffix}.pt"
    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
    config_filename = f"config_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}.json"

    # Save config only on rank 0
    if rank == 0:
        with open(config_filename, 'w') as f:
            json.dump({k: v for k, v in vars(cfg).items() if not k.startswith('__')}, f, indent=4)
        print(f"[rank 0] Configuration saved to {config_filename}")

    # Load data (all ranks)
    if rank == 0:
        print(f"Loading data from {cfg.train_data_path} and {cfg.val_data_path}...")
    train_data_arr = np.memmap(cfg.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(cfg.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if rank == 0:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, cfg.block_length)
    val_dataset = PatientEventDataset(val_data_arr, cfg.block_length)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)

    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=4, pin_memory=torch.cuda.is_available())
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, sampler=val_sampler, num_workers=4, pin_memory=torch.cuda.is_available())

    # Model, loss, optimizer
    model_cls = {
        'TimeAwareGPT2': TimeAwareGPT2,
        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
    }[cfg.model_name]

    model = model_cls(
        vocab_size=vocab_size,
        n_embd=cfg.n_embd,
        n_layer=cfg.n_layer,
        n_head=cfg.n_head,
        pdrop=cfg.pdrop,
        token_pdrop=cfg.token_pdrop,
    ).to(device)

    ddp_model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None, output_device=local_rank if torch.cuda.is_available() else None)

    loss_fn = CombinedLoss(cfg.ignored_token_ids)
    optimizer = AdamW(ddp_model.parameters(), lr=cfg.lr_initial, weight_decay=cfg.weight_decay, betas=cfg.betas)

    best_val_loss = float('inf')
    patience_counter = 0

    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if rank == 0:
        print("Starting DDP training...")

    for epoch in range(cfg.max_epoch):
        # Update sampler epoch for shuffling
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        # Set LR
        lr = cosine_lr(epoch, cfg)
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # Train
        ddp_model.train()
        train_loss_ce_acc = torch.zeros(1, device=device)
        train_loss_surv_acc = torch.zeros(1, device=device)
        train_steps = 0

        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Train]", disable=(rank != 0))
        for event_seq, time_seq in pbar:
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)

            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = ddp_model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.detach()
            train_loss_surv_acc += loss_survival.detach()
            train_steps += 1

            if rank == 0:
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        # Aggregate train losses across ranks
        if train_steps == 0:
            train_steps = 1
        steps_tensor = torch.tensor([train_steps], device=device, dtype=torch.float64)
        dist.all_reduce(steps_tensor, op=dist.ReduceOp.SUM)
        train_loss_ce_mean = allreduce_avg(train_loss_ce_acc, world_size) / (steps_tensor.item() / world_size)
        train_loss_surv_mean = allreduce_avg(train_loss_surv_acc, world_size) / (steps_tensor.item() / world_size)

        if rank == 0:
            train_losses_ce.append(train_loss_ce_mean.item())
            train_losses_surv.append(train_loss_surv_mean.item())
            train_losses_total.append(train_loss_ce_mean.item() + train_loss_surv_mean.item())

        # Validation
        ddp_model.eval()
        val_loss_ce_acc = torch.zeros(1, device=device)
        val_loss_surv_acc = torch.zeros(1, device=device)
        val_steps = 0

        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Val]", disable=(rank != 0))
            for event_seq, time_seq in pbar_val:
                event_seq = event_seq.to(device, non_blocking=True)
                time_seq = time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = ddp_model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.detach()
                val_loss_surv_acc += loss_survival.detach()
                val_steps += 1

        if val_steps == 0:
            val_steps = 1
        vsteps_tensor = torch.tensor([val_steps], device=device, dtype=torch.float64)
        dist.all_reduce(vsteps_tensor, op=dist.ReduceOp.SUM)
        val_loss_ce_mean = allreduce_avg(val_loss_ce_acc, world_size) / (vsteps_tensor.item() / world_size)
        val_loss_surv_mean = allreduce_avg(val_loss_surv_acc, world_size) / (vsteps_tensor.item() / world_size)
        total_val_loss = (val_loss_ce_mean + val_loss_surv_mean).item()

        if rank == 0:
            val_losses_ce.append(val_loss_ce_mean.item())
            val_losses_surv.append(val_loss_surv_mean.item())
            val_losses_total.append(total_val_loss)

            print(
                f"Epoch {epoch+1} Summary:\n"
                f"  Train Loss: {train_losses_total[-1]:.4f} (CE: {train_losses_ce[-1]:.4f}, Surv: {train_losses_surv[-1]:.4f})\n"
                f"  Val Loss: {total_val_loss:.4f} (CE: {val_losses_ce[-1]:.4f}, Surv: {val_losses_surv[-1]:.4f})\n"
                f"  Learning Rate: {lr:.6f}"
            )

            # Early stopping on rank 0; broadcast decision
            improved = total_val_loss < best_val_loss
            if improved:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                torch.save(ddp_model.module.state_dict(), checkpoint_filename)
            else:
                if epoch >= cfg.warmup_epochs:
                    patience_counter += 1
                    print(f"Validation loss did not improve. Patience: {patience_counter}/{cfg.early_stopping_patience}")

            stop_flag = torch.tensor([1 if patience_counter >= cfg.early_stopping_patience else 0], device=device)
        else:
            stop_flag = torch.zeros(1, device=device)

        # Broadcast stop flag and best loss to all ranks
        dist.broadcast(stop_flag, src=0)
        if stop_flag.item() > 0:
            if rank == 0:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    # Save best model at the end (rank 0)
    if rank == 0 and best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        state = torch.load(checkpoint_filename, map_location='cpu')
        ddp_model.module.load_state_dict(state)
        print(f"Saving final best model to {model_filename}")
        torch.save(ddp_model.module.state_dict(), model_filename)

        # Save losses to file
        losses_filename = f"losses_{model_suffix}.txt"
        with open(losses_filename, 'w') as f:
            f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
            for i in range(len(train_losses_total)):
                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
        print(f"\nLosses saved to {losses_filename}")

        # Plot curves
        num_epochs = len(train_losses_total)
        epochs = range(1, num_epochs + 1)
        plt.figure(figsize=(18, 5))

        plt.subplot(1, 3, 1)
        plt.plot(epochs, train_losses_ce, label='Train CE')
        plt.plot(epochs, val_losses_ce, label='Val CE')
        plt.title('Cross-Entropy Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.subplot(1, 3, 2)
        plt.plot(epochs, train_losses_surv, label='Train Survival')
        plt.plot(epochs, val_losses_surv, label='Val Survival')
        plt.title('Survival Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.subplot(1, 3, 3)
        plt.plot(epochs, train_losses_total, label='Train Total')
        plt.plot(epochs, val_losses_total, label='Val Total')
        plt.title('Total Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.tight_layout()
        plt.savefig('loss_curves.png')
        print("\nLoss curves saved to loss_curves.png")

    cleanup_ddp()


if __name__ == '__main__':
    main()
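setup_ddp reads RANK, LOCAL_RANK, and WORLD_SIZE from the environment, so the script is meant to be launched via torchrun, e.g. torchrun --nproc_per_node=4 train_ddp.py --n_embd 120 --n_layer 12 --n_head 12 (GPU count illustrative). Per-rank loss sums are combined through allreduce_avg and divided by the global step count, so the logged means match what a single-process run over the full data would report.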
218
train_iter.py
Normal file
@@ -0,0 +1,218 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48 # Sequence length

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.0
    token_pdrop = 0.0

    # Training parameters
    max_iter = 200000
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_iter = 1000

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Main Training Script ---
def main():
    config = TrainConfig()

    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.pt"

    # --- 0. Save Configuration ---
    config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.json"
    config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
    with open(config_filename, 'w') as f:
        json.dump(config_dict, f, indent=4)
    print(f"Configuration saved to {config_filename}")

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    train_iter_loader = iter(itertools.cycle(train_loader))

    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)

    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))

    # --- 3. Training Loop ---

    # Lists to store losses
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []

    print("Starting training...")
    pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training")
    for iter_num in pbar:
        # --- Learning Rate Scheduling ---
        if iter_num < config.warmup_iter:
            lr = config.lr_initial
        else:
            progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Step ---
        model.train()

        event_seq, time_seq = next(train_iter_loader)
        event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

        # Prepare inputs and targets
        input_events = event_seq[:, :-1]
        input_times = time_seq[:, :-1]
        target_events = event_seq[:, 1:]
        target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

        # Forward pass
        logits = model(input_events, input_times)
        loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
        loss = loss_ce + loss_survival

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses_ce.append(loss_ce.item())
        train_losses_surv.append(loss_survival.item())
        train_losses_total.append(loss.item())

        pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

    print("\nTraining finished.")

    # --- 4. Final Validation ---
    print("Running final validation...")
    model.eval()
    val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
    val_steps = 0

    with torch.no_grad():
        pbar_val = tqdm.tqdm(val_loader, desc="Final Validation")
        for event_seq, time_seq in pbar_val:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

            val_loss_ce_acc += loss_ce.item()
            val_loss_surv_acc += loss_survival.item()
            val_steps += 1
            pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

    avg_val_loss_ce = val_loss_ce_acc / val_steps
    avg_val_loss_surv = val_loss_surv_acc / val_steps
    total_val_loss = avg_val_loss_ce + avg_val_loss_surv

    print(f"Final Validation Summary: \n"
          f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")

    # --- 5. Save Model ---
    print(f"Saving final model to {model_filename}")
    torch.save(model.state_dict(), model_filename)

    # --- 6. Save and Plot Losses ---
    losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.txt"
    with open(losses_filename, 'w') as f:
        f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
        for i in range(len(train_losses_total)):
            f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
    print(f"\nLosses saved to {losses_filename}")

    # Plot and Save Loss Curves
    iterations = range(1, len(train_losses_total) + 1)

    plt.figure(figsize=(18, 5))

    # Plot CE Loss
    plt.subplot(1, 3, 1)
    plt.plot(iterations, train_losses_ce, label='Train CE')
    plt.title('Cross-Entropy Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Survival Loss
    plt.subplot(1, 3, 2)
    plt.plot(iterations, train_losses_surv, label='Train Survival')
    plt.title('Survival Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Total Loss
    plt.subplot(1, 3, 3)
    plt.plot(iterations, train_losses_total, label='Train Total')
    plt.title('Total Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('loss_curves_iter.png')
    print("\nLoss curves saved to loss_curves_iter.png")


if __name__ == '__main__':
    main()
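All three trainers share the same warmup-then-cosine decay; written out for a step $i$ (an epoch in train.py and train_ddp.py, an iteration here), with warmup point $w$ and horizon $M$:

$$\mathrm{lr}(i) = \begin{cases} \mathrm{lr}_{\text{initial}}, & i < w \\ \mathrm{lr}_{\text{final}} + \tfrac{1}{2}\left(\mathrm{lr}_{\text{initial}} - \mathrm{lr}_{\text{final}}\right)\left(1 + \cos\!\left(\pi \, \tfrac{i - w}{M - w}\right)\right), & i \ge w \end{cases}$$

so the rate holds at $6 \times 10^{-4}$ through warmup and decays smoothly to $6 \times 10^{-5}$ by the final step.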
163
utils.py
@@ -1,7 +1,11 @@
|
|||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv
|
||||||
|
|
||||||
|
|
||||||
class PatientEventDataset(torch.utils.data.Dataset):
|
class PatientEventDataset(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
@@ -39,17 +43,22 @@ class PatientEventDataset(torch.utils.data.Dataset):
         """
         return len(self.patient_ids)
 
-    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+    def __getitem__(self, idx):
         """
-        Retrieves, processes, and returns a single patient's event sequence.
+        Retrieves, processes, and returns a single patient's event sequence,
+        or a list of sequences if a slice is provided.
 
         Args:
-            idx (int): The index of the patient to retrieve.
+            idx (int or slice): The index or slice of the patient(s) to retrieve.
 
         Returns:
-            A tuple of two torch.long tensors: (event_sequence, time_sequence),
-            both of shape (block_length,).
+            If idx is an int, a tuple of two torch.long tensors:
+            (event_sequence, time_sequence), both of shape (block_length,).
+            If idx is a slice, a list of such tuples.
         """
+        if isinstance(idx, slice):
+            return [self[i] for i in range(*idx.indices(len(self)))]
+
         # 1. Retrieve and Sort
         patient_id = self.patient_ids[idx]
         records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
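The slice path added to __getitem__ is what the batching helper later in this diff relies on. A minimal illustration of the new behaviour, assuming `dataset` is an already-constructed PatientEventDataset (the variable name is illustrative):

# Integer indexing returns a single (event_sequence, time_sequence) pair.
events, times = dataset[0]

# Slice indexing now returns a list of such pairs, one per patient,
# via range(*idx.indices(len(self))) as in the diff above.
batch = dataset[0:4]
assert len(batch) == 4
assert all(ev.shape == tm.shape for ev, tm in batch)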
@@ -102,3 +111,147 @@ class PatientEventDataset(torch.utils.data.Dataset):
         time_tensor = torch.tensor(time_stamps, dtype=torch.long)
 
         return event_tensor, time_tensor
+
+
+def load_model(config_path: str, device: str = 'cpu'):
+    """
+    Load a trained model based on the training configuration, inferring the
+    checkpoint filename from the configuration.
+
+    train.py may save any of the variants 'TimeAwareGPT2',
+    'TimeAwareGPT2Learnable', or 'TimeAwareGPT2TemporalConv'. This function:
+    - Reads the config JSON to get architecture hyperparameters
+    - Selects the model class from config.model_name (defaulting to TimeAwareGPT2 if absent)
+    - Infers the checkpoint path from the config values
+    - Infers vocab_size from the checkpoint
+    - Loads the weights and returns the model in eval mode on the requested device
+
+    Args:
+        config_path: Path to the JSON configuration file saved during training.
+        device: 'cpu' or 'cuda'.
+
+    Returns:
+        torch.nn.Module: Loaded model ready for inference.
+    """
+    # 1) Read config
+    with open(config_path, 'r') as f:
+        config_dict = json.load(f)
+
+    # Access config entries with attribute-style access while staying tolerant to missing keys
+    class AttrDict(dict):
+        def __getattr__(self, item):
+            try:
+                return self[item]
+            except KeyError:
+                raise AttributeError(item)
+
+    config = AttrDict(config_dict)
+
+    # 2) Decide model class (train.py supports several variants)
+    model_name = getattr(config, 'model_name', 'TimeAwareGPT2')
+    model_cls = {
+        'TimeAwareGPT2': TimeAwareGPT2,
+        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
+        'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv,
+    }.get(model_name, TimeAwareGPT2)
+
+    # 3) Infer checkpoint filename from config
+    n_embd = getattr(config, 'n_embd')
+    n_layer = getattr(config, 'n_layer')
+    n_head = getattr(config, 'n_head')
+
+    # Newer naming (includes model_name) used by train.py when model_name is present
+    suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
+    ckpt_with_model = f"best_model_{suffix_with_model}.pt"
+
+    # Older naming (without model_name) matches existing repo files
+    suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
+    ckpt_legacy = f"best_model_{suffix_legacy}.pt"
+
+    # Prefer whichever file exists on disk
+    if os.path.exists(ckpt_with_model):
+        model_path = ckpt_with_model
+    elif os.path.exists(ckpt_legacy):
+        model_path = ckpt_legacy
+    else:
+        # Fall back to the newer naming; the load below will fail loudly if neither exists
+        model_path = ckpt_with_model
+        print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}")
+
+    # 4) Infer vocab_size from checkpoint
+    state_preview = torch.load(model_path, map_location='cpu')
+    if 'wte.weight' in state_preview:
+        vocab_size = state_preview['wte.weight'].shape[0]
+    elif 'head.weight' in state_preview:
+        vocab_size = state_preview['head.weight'].shape[0]
+    else:
+        # Fall back to the largest dimension of any 2D tensor in the state dict
+        candidate = None
+        for k, v in state_preview.items():
+            if isinstance(v, torch.Tensor) and v.ndim == 2:
+                V = max(v.shape)
+                if candidate is None or V > candidate:
+                    candidate = V
+        if candidate is None:
+            raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
+        vocab_size = candidate
+
+    # 5) Build model from config (tolerant to missing fields)
+    pdrop = getattr(config, 'pdrop', 0.1)
+    token_pdrop = getattr(config, 'token_pdrop', 0.1)
+
+    model = model_cls(
+        vocab_size=vocab_size,
+        n_embd=n_embd,
+        n_layer=n_layer,
+        n_head=n_head,
+        pdrop=pdrop,
+        token_pdrop=token_pdrop,
+    ).to(device)
+
+    # 6) Load weights
+    state_dict = torch.load(model_path, map_location=device)
+    missing, unexpected = model.load_state_dict(state_dict, strict=False)
+
+    if missing:
+        print(f"Warning: Missing keys when loading state_dict: {missing}")
+    if unexpected:
+        print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")
+
+    model.eval()
+    try:
+        num_params_m = model.get_num_params()
+        print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.")
+    except Exception:
+        pass
+    return model
+
+
+def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Retrieves a batch of data from a PatientEventDataset and prepares it for model training.
+
+    Args:
+        dataset (PatientEventDataset): The dataset to retrieve data from.
+        batch_slice (slice): The slice defining the batch of patients to retrieve.
+
+    Returns:
+        A tuple of four tensors, each of shape (batch_size, sequence_length - 1):
+        - input_events
+        - input_times
+        - target_events
+        - target_times
+    """
+    batch_data = dataset[batch_slice]
+
+    # Shift by one position: inputs are all but the last event, targets all but the first.
+    input_events = [item[0][:-1] for item in batch_data]
+    input_times = [item[1][:-1] for item in batch_data]
+    target_events = [item[0][1:] for item in batch_data]
+    target_times = [item[1][1:] for item in batch_data]
+
+    input_events = torch.stack(input_events)
+    input_times = torch.stack(input_times)
+    target_events = torch.stack(target_events)
+    target_times = torch.stack(target_times)
+
+    return input_events, input_times, target_events, target_times
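Together, load_model and get_batch give a short path from a saved config to a forward pass. A sketch under two assumptions that this diff does not itself guarantee: `train_dataset` is a PatientEventDataset constructed elsewhere, and the model's forward accepts (events, times):

import torch
from utils import load_model, get_batch

# Hypothetical config filename following the repo's naming scheme.
model = load_model('config_n_embd_120_n_layer_12_n_head_12.json', device='cpu')

# `train_dataset` is assumed to be a PatientEventDataset built elsewhere.
input_events, input_times, target_events, target_times = get_batch(train_dataset, slice(0, 32))

with torch.no_grad():
    # Forward signature is an assumption; adapt to the actual TimeAwareGPT2 API.
    logits = model(input_events, input_times)
print(logits.shape)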