Compare commits: d4d25ac9c7...main (18 commits)

SHA1:
e348086e52
a8aa5a2bd6
ddb7dbfc67
88cccdad2e
8f44018bae
1c9e2a2fb3
6b782b86e1
9a9de170d1
7e57e5d3b1
14865ac5b6
dbc3000192
082c719975
a631ac6d59
f7356b183c
3390bc025e
a832a45c62
d760c45baf
053f86f4da
2 changes: .gitignore (vendored)

@@ -5,7 +5,7 @@
 __pycache__/

 # Model checkpoints
-best_model_checkpoint.pt
+*.pt

 # Large data files
 ukb_delphi.txt
BIN  best_model_n_embd_120_n_layer_12_n_head_12.pt (new file)
Binary file not shown.

BIN  best_model_n_embd_256_n_layer_16_n_head_16.pt (new file)
Binary file not shown.
18 changes: config_n_embd_120_n_layer_12_n_head_12.json (new file)

@@ -0,0 +1,18 @@
{
    "n_layer": 12,
    "n_embd": 120,
    "n_head": 12,
    "max_epoch": 200,
    "batch_size": 128,
    "lr_initial": 0.0006,
    "lr_final": 6e-05,
    "weight_decay": 0.2,
    "warmup_epochs": 10,
    "early_stopping_patience": 10,
    "pdrop": 0.0,
    "token_pdrop": 0.0,
    "betas": [
        0.9,
        0.99
    ]
}
18 changes: config_n_embd_256_n_layer_16_n_head_16.json (new file)

@@ -0,0 +1,18 @@
{
    "n_layer": 16,
    "n_embd": 256,
    "n_head": 16,
    "max_epoch": 200,
    "batch_size": 128,
    "lr_initial": 0.0006,
    "lr_final": 6e-05,
    "weight_decay": 0.2,
    "warmup_epochs": 10,
    "early_stopping_patience": 10,
    "pdrop": 0.0,
    "token_pdrop": 0.0,
    "betas": [
        0.9,
        0.99
    ]
}
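Note: these config files are plain JSON; a minimal sketch of consuming one (the same AttrDict trick that utils.load_model below uses for dot-style access — file name taken from the diff above):

import json

with open("config_n_embd_120_n_layer_12_n_head_12.json") as f:
    cfg_dict = json.load(f)

class AttrDict(dict):
    # Dot-style access to keys; mirrors the helper in utils.load_model below
    __getattr__ = dict.__getitem__

cfg = AttrDict(cfg_dict)
print(cfg.n_embd, cfg.n_layer, cfg.n_head)  # 120 12 12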
1271 changes: delphi_labels_chapters_colours_icd.csv (new file)

File diff suppressed because it is too large.
499 changes: evaluate_auc.py (new file)

@@ -0,0 +1,499 @@
import scipy.stats
import scipy
import warnings
import torch
from models import TimeAwareGPT2
from tqdm import tqdm
import pandas as pd
import numpy as np
import argparse
from utils import load_model, get_batch, PatientEventDataset
from pathlib import Path
from joblib import Parallel, delayed


def auc(x1, x2):
    n1 = len(x1)
    n2 = len(x2)
    if n1 == 0 or n2 == 0:
        return np.nan
    R1 = np.concatenate([x1, x2]).argsort().argsort()[:n1].sum() + n1
    U1 = n1 * n2 + 0.5 * n1 * (n1 + 1) - R1
    return U1 / n1 / n2


def get_common_diseases(delphi_labels, filter_min_total=100):
    chapters_of_interest = [
        "I. Infectious Diseases",
        "II. Neoplasms",
        "III. Blood & Immune Disorders",
        "IV. Metabolic Diseases",
        "V. Mental Disorders",
        "VI. Nervous System Diseases",
        "VII. Eye Diseases",
        "VIII. Ear Diseases",
        "IX. Circulatory Diseases",
        "X. Respiratory Diseases",
        "XI. Digestive Diseases",
        "XII. Skin Diseases",
        "XIII. Musculoskeletal Diseases",
        "XIV. Genitourinary Diseases",
        "XV. Pregnancy & Childbirth",
        "XVI. Perinatal Conditions",
        "XVII. Congenital Abnormalities",
        "Death",
    ]
    labels_df = delphi_labels[
        delphi_labels["ICD-10 Chapter (short)"].isin(chapters_of_interest) * (delphi_labels["count"] > filter_min_total)
    ]
    return labels_df["index"].tolist()


def optimized_bootstrapped_auc_gpu(case, control, n_bootstrap=1000):
    """
    Computes bootstrapped AUC estimates using PyTorch on CUDA.

    Parameters:
        case: 1D tensor of scores for positive cases
        control: 1D tensor of scores for controls
        n_bootstrap: Number of bootstrap replicates

    Returns:
        List of length n_bootstrap containing the AUC of each bootstrap replicate
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. This function requires a GPU.")

    # Convert inputs to CUDA tensors
    if not torch.is_tensor(case):
        case = torch.tensor(case, device="cuda", dtype=torch.float32)
    else:
        case = case.to("cuda", dtype=torch.float32)

    if not torch.is_tensor(control):
        control = torch.tensor(control, device="cuda", dtype=torch.float32)
    else:
        control = control.to("cuda", dtype=torch.float32)

    n_case = case.size(0)
    n_control = control.size(0)
    total = n_case + n_control

    # Generate bootstrap samples
    boot_idx_case = torch.randint(0, n_case, (n_bootstrap, n_case), device="cuda")
    boot_idx_control = torch.randint(0, n_control, (n_bootstrap, n_control), device="cuda")

    boot_case = case[boot_idx_case]
    boot_control = control[boot_idx_control]

    combined = torch.cat([boot_case, boot_control], dim=1)

    # Mask to identify case entries
    mask = torch.zeros((n_bootstrap, total), dtype=torch.bool, device="cuda")
    mask[:, :n_case] = True

    # Compute ranks and AUC
    ranks = combined.argsort(dim=1).argsort(dim=1)
    case_ranks_sum = torch.sum(ranks.float() * mask.float(), dim=1)
    min_case_rank_sum = n_case * (n_case - 1) / 2.0
    U = case_ranks_sum - min_case_rank_sum
    aucs = U / (n_case * n_control)
    return aucs.cpu().tolist()


# AUC comparison adapted from
# https://github.com/Netflix/vmaf/
def compute_midrank(x):
    """Computes midranks.
    Args:
        x - a 1D numpy array
    Returns:
        array of midranks
    """
    J = np.argsort(x)
    Z = x[J]
    N = len(x)
    T = np.zeros(N, dtype=np.float32)
    i = 0
    while i < N:
        j = i
        while j < N and Z[j] == Z[i]:
            j += 1
        T[i:j] = 0.5 * (i + j - 1)
        i = j
    T2 = np.empty(N, dtype=np.float32)
    # Note(kazeevn) +1 is due to Python using 0-based indexing
    # instead of 1-based in the AUC formula in the paper
    T2[J] = T + 1
    return T2


def fastDeLong(predictions_sorted_transposed, label_1_count):
    """
    The fast version of DeLong's method for computing the covariance of
    unadjusted AUC.
    Args:
        predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
            sorted such as the examples with label "1" are first
    Returns:
        (AUC value, DeLong covariance)
    Reference:
        @article{sun2014fast,
            title={Fast Implementation of DeLong's Algorithm for
                Comparing the Areas Under Correlated Receiver Operating Characteristic Curves},
            author={Xu Sun and Weichao Xu},
            journal={IEEE Signal Processing Letters},
            volume={21},
            number={11},
            pages={1389--1393},
            year={2014},
            publisher={IEEE}
        }
    """
    # Short variables are named as they are in the paper
    m = label_1_count
    n = predictions_sorted_transposed.shape[1] - m
    positive_examples = predictions_sorted_transposed[:, :m]
    negative_examples = predictions_sorted_transposed[:, m:]
    k = predictions_sorted_transposed.shape[0]

    tx = np.empty([k, m], dtype=np.float32)
    ty = np.empty([k, n], dtype=np.float32)
    tz = np.empty([k, m + n], dtype=np.float32)
    for r in range(k):
        tx[r, :] = compute_midrank(positive_examples[r, :])
        ty[r, :] = compute_midrank(negative_examples[r, :])
        tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :])
    aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n
    v01 = (tz[:, :m] - tx[:, :]) / n
    v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m
    sx = np.cov(v01)
    sy = np.cov(v10)
    delongcov = sx / m + sy / n
    return aucs, delongcov


def compute_ground_truth_statistics(ground_truth):
    assert np.array_equal(np.unique(ground_truth), [0, 1])
    order = (-ground_truth).argsort()
    label_1_count = int(ground_truth.sum())
    return order, label_1_count


def get_auc_delong_var(healthy_scores, diseased_scores):
    """
    Computes ROC AUC value and variance using DeLong's method.

    Args:
        healthy_scores: Values for class 0 (healthy/controls)
        diseased_scores: Values for class 1 (diseased/cases)
    Returns:
        AUC value and variance
    """
    # Create ground truth labels (1 for diseased, 0 for healthy)
    ground_truth = np.array([1] * len(diseased_scores) + [0] * len(healthy_scores))
    predictions = np.concatenate([diseased_scores, healthy_scores])

    # Compute statistics needed for DeLong method
    order, label_1_count = compute_ground_truth_statistics(ground_truth)
    predictions_sorted_transposed = predictions[np.newaxis, order]

    # Calculate AUC and covariance
    aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count)
    assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers"

    return aucs[0], delongcov


def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5), precomputed_idx=None, n_bootstrap=1, use_delong=False):
    age_step = age_groups[1] - age_groups[0]

    # Indexes of cases with disease k
    wk = np.where(d[2] == k)

    if len(wk[0]) < 2:
        return None

    # For controls, we need to exclude cases with disease k
    wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None])

    wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]]))  # All cases and controls

    # We need to take into account the offset t and use the tokens for prediction that are at least t before the event
    if precomputed_idx is None:
        pred_idx = (d[1][wall[0]] <= d[3][wall].reshape(-1, 1) - offset).sum(1) - 1
    else:
        pred_idx = precomputed_idx[wall]  # It's actually much faster to precompute this

    valid_indices = pred_idx != -1
    pred_idx = pred_idx[valid_indices]
    wall = (wall[0][valid_indices], wall[1][valid_indices])

    z = d[1][(wall[0], pred_idx)]  # Times of the tokens for prediction
    zk = d[3][wall]  # Target times

    x = p[..., j][(wall[0], pred_idx)]

    p_idx = wall[0]

    out = []

    for i, aa in enumerate(age_groups):
        a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step)

        if not np.any(a):
            continue

        selected_groups = p_idx[a]
        _, unique_indices = np.unique(selected_groups, return_index=True)

        a_filtered = a[a]
        a_filtered[:] = False
        a_filtered[unique_indices] = True
        a[a] = a_filtered

        is_case = np.zeros_like(x, dtype=bool)
        is_case[:len(wk[0])] = True

        control = x[~is_case & a]
        case = x[is_case & a]

        if len(control) == 0 or len(case) == 0:
            continue

        if use_delong:
            auc_value_delong, auc_variance_delong = get_auc_delong_var(control, case)
            auc_delong_dict = {"auc_delong": auc_value_delong, "auc_variance_delong": auc_variance_delong}
        else:
            auc_delong_dict = {}

        if n_bootstrap > 1:
            aucs_bootstrapped = optimized_bootstrapped_auc_gpu(case, control, n_bootstrap)

        for bootstrap_idx in range(n_bootstrap):
            # With a single replicate, fall back to the DeLong point estimate (requires use_delong=True)
            y = auc_value_delong if n_bootstrap == 1 else aucs_bootstrapped[bootstrap_idx]
            out_item = {
                "token": k,
                "auc": y,
                "age": aa,
                "n_healthy": len(control),
                "n_diseased": len(case),
            }
            # Record the replicate index before merging, so it is kept in the appended dict
            if n_bootstrap > 1:
                out_item["bootstrap_idx"] = bootstrap_idx
            out.append(out_item | auc_delong_dict)
    return out


def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap):
    all_aucs = []
    for sex, sex_idx in [("female", 2), ("male", 3)]:
        sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
        p_sex = p100k[sex_mask]
        d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k]
        precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
        for j, k in tqdm(
            list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
        ):
            out = get_calibration_auc(
                j,
                k,
                d100k_sex,
                p_sex,
                age_groups=age_groups,
                offset=offset,
                precomputed_idx=precomputed_idx_subset,
                n_bootstrap=n_bootstrap,
                use_delong=True,
            )
            if out is None:
                continue
            for out_item in out:
                out_item["sex"] = sex
                all_aucs.append(out_item)
    return all_aucs


# New internal function that performs the AUC evaluation pipeline.
def evaluate_auc_pipeline(
    model,
    d100k,
    output_path,
    delphi_labels,
    diseases_of_interest=None,
    filter_min_total=100,
    disease_chunk_size=200,
    age_groups=np.arange(40, 80, 5),
    offset=0.1,
    batch_size=256,
    device="cpu",
    seed=1337,
    n_bootstrap=1,
    meta_info={},
    n_jobs=-1,
):
    """
    Runs the AUC evaluation pipeline.

    Args:
        model (torch.nn.Module): The loaded model set to eval().
        d100k (tuple): Data batch from get_batch.
        output_path (str | None): Directory where CSV files will be written. If None, files will not be saved.
        delphi_labels (pd.DataFrame): DataFrame with label info (token names, etc.; "delphi_labels_chapters_colours_icd.csv").
        diseases_of_interest (np.ndarray or list, optional): If provided, these disease indices are used.
        filter_min_total (int): Minimum total token count to include a token.
        disease_chunk_size (int): Maximum chunk size for processing diseases.
        age_groups (np.ndarray): Age groups to use in calibration.
        offset (float): Offset used in get_calibration_auc.
        batch_size (int): Batch size for model forwarding.
        device (str): Device identifier.
        seed (int): Random seed for reproducibility.
        n_bootstrap (int): Number of bootstrap samples (1 for no bootstrap).
        n_jobs (int): Number of parallel jobs to run.
    Returns:
        tuple: (df_auc_unpooled_merged, df_auc_merged) DataFrames.
    """

    assert n_bootstrap > 0, "n_bootstrap must be greater than 0"

    # Set random seeds
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Get common diseases
    if diseases_of_interest is None:
        diseases_of_interest = get_common_diseases(delphi_labels, filter_min_total)

    # Split diseases into chunks for processing
    num_chunks = (len(diseases_of_interest) + disease_chunk_size - 1) // disease_chunk_size
    diseases_chunks = np.array_split(diseases_of_interest, num_chunks)

    # Precompute prediction indices for calibration
    pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1

    p100k = []
    model.to(device)
    with torch.no_grad():
        for dd in tqdm(
            zip(*[torch.split(x, batch_size) for x in d100k]),
            desc="Model inference",
            total=d100k[0].shape[0] // batch_size + 1,
        ):
            dd = [x.to(device) for x in dd]
            outputs = model(dd[0], dd[1]).cpu().detach().numpy()
            p100k.append(outputs.astype("float16"))
    p100k = np.vstack(p100k)

    results = Parallel(n_jobs=n_jobs)(
        delayed(process_chunk)(
            disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk], pred_idx_precompute, age_groups, offset, n_bootstrap
        )
        for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks)
    )

    all_aucs = [item for sublist in results for item in sublist]

    df_auc_unpooled = pd.DataFrame(all_aucs)

    for key, value in meta_info.items():
        df_auc_unpooled[key] = value

    delphi_labels_subset = delphi_labels[['index', 'ICD-10 Chapter (short)', 'name', 'color', 'count']]
    df_auc_unpooled_merged = df_auc_unpooled.merge(delphi_labels_subset, left_on="token", right_on="index", how="inner")

    def aggregate_age_brackets_delong(group):
        # For normal distributions, when averaging n of them:
        # the variance of the sum is the sum of variances, and
        # the variance of the average is the sum of variances divided by n^2.
        n = len(group)
        mean = group['auc_delong'].mean()
        # Since we're taking the average, divide combined variance by n^2
        var = group['auc_variance_delong'].sum() / (n**2)
        return pd.Series({
            'auc': mean,
            'auc_variance_delong': var,
            'n_samples': n,
            'n_diseased': group['n_diseased'].sum(),
            'n_healthy': group['n_healthy'].sum(),
        })

    print('Using DeLong method to calculate AUC confidence intervals..')

    df_auc = df_auc_unpooled.groupby(["token"]).apply(aggregate_age_brackets_delong).reset_index()
    df_auc_merged = df_auc.merge(delphi_labels, left_on="token", right_on="index", how="inner")

    if output_path is not None:
        Path(output_path).mkdir(exist_ok=True, parents=True)
        df_auc_merged.to_csv(f"{output_path}/df_both.csv", index=False)
        df_auc_unpooled_merged.to_csv(f"{output_path}/df_auc_unpooled.csv", index=False)

    return df_auc_unpooled_merged, df_auc_merged


def main():
    parser = argparse.ArgumentParser(description="Evaluate AUC")
    parser.add_argument("--model_name", type=str, default="n_embd_256_n_layer_16_n_head_16", help="Model checkpoint name")
    parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
    parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
    parser.add_argument("--offset", type=float, default=365.25, help="Offset in days for prediction")
    # Optional filtering/chunking parameters:
    parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens")
    parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
    parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
    args = parser.parse_args()

    model_name = args.model_name
    output_path = f'auc_evaluation_{model_name}'
    dataset_subset_size = args.dataset_subset_size
    offset = args.offset

    # Create output folder if it doesn't exist.
    Path(output_path).mkdir(exist_ok=True, parents=True)

    device = "cuda"
    seed = 1337

    # Load model checkpoint and initialize model.
    model = load_model(f'config_{model_name}.json',
                       f'best_model_{model_name}.pt',
                       1270)
    model.eval()
    model = model.to(device)

    # Load validation data.
    val_data_path = 'ukb_real_val.bin'

    val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    block_length = 128
    val_dataset = PatientEventDataset(val_data_arr, block_length)

    if dataset_subset_size == -1:
        dataset_subset_size = len(val_dataset)

    # Get a subset batch for evaluation.
    d100k = get_batch(val_dataset, slice(dataset_subset_size))

    # Load labels (external) to be passed in.
    delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")

    # Call the internal evaluation function.
    df_auc_unpooled, df_auc_merged = evaluate_auc_pipeline(
        model,
        d100k,
        output_path,
        delphi_labels,
        diseases_of_interest=None,
        filter_min_total=args.filter_min_total,
        disease_chunk_size=args.disease_chunk_size,
        device=device,
        seed=seed,
        offset=offset,
        n_bootstrap=args.n_bootstrap,
        n_jobs=args.n_jobs,
    )


if __name__ == "__main__":
    main()
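Note: the rank-based auc helper above is the Mann-Whitney estimator: with x1 as controls and x2 as cases it equals the fraction of (case, control) pairs in which the case scores higher. A minimal sanity check against the brute-force pairwise definition (hypothetical data; assumes evaluate_auc.py imports cleanly from the working directory, and that scores have no ties, since argsort does not assign midranks):

import numpy as np
from evaluate_auc import auc  # the rank-based helper defined above

rng = np.random.default_rng(0)
healthy = rng.normal(0.0, 1.0, 500)   # controls (x1)
diseased = rng.normal(0.5, 1.0, 300)  # cases (x2)

# Brute-force pairwise definition: P(case score > control score)
pairwise = (diseased[None, :] > healthy[:, None]).mean()

assert abs(auc(healthy, diseased) - pairwise) < 1e-12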
552 changes: evaluate_models.ipynb (new file)

File diff suppressed because one or more lines are too long.
65 changes: models.py

@@ -177,9 +177,10 @@ class TimeAwareGPT2(nn.Module):
     A time-aware GPT-2 model with custom temporal features.
     """

-    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float):
+    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None):
         super().__init__()
         self.token_pdrop = token_pdrop
+        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []

         self.wte = nn.Embedding(vocab_size, n_embd)
         self.age_encoder = AgeSinusoidalEncoding(n_embd)
@@ -234,6 +235,58 @@
         """
         return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

+    @torch.no_grad()
+    def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
+        """
+        Take a conditioning sequence of indices x (LongTensor of shape (b,t)) and complete
+        the sequence max_new_tokens times, feeding the predictions back into the model each time.
+        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+        """
+        self.eval()
+
+        if termination_tokens is None:
+            termination_tokens = [1269]
+
+        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
+        mask_time = -10000
+
+        for _ in range(max_new_tokens):
+            logits = self(x, t)
+            logits = logits[:, -1, :]
+
+            if self.ignore_tokens:
+                logits[:, self.ignore_tokens] = -torch.inf
+
+            if no_repeat:
+                fill = x.clone()
+                fill[fill == 1] = 0
+                logits = logits.scatter(1, fill, -torch.inf)
+
+            t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
+            t_next_val, idx_next = t_next_dist.min(1)
+
+            idx_next = idx_next.unsqueeze(1)
+            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
+
+            x = torch.cat((x, idx_next), dim=1)
+            t = torch.cat((t, age_next), dim=1)
+
+            if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
+                break
+
+        pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
+
+        final_logits = self(x, t)
+        x[pad] = 0
+        t[pad] = mask_time
+
+        if no_repeat:
+            fill = x.clone()
+            fill[fill == 1] = 0
+            final_logits = torch.stack([final_logits[:,j].scatter(1, fill[:,:j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0,1)
+
+        return x, t, final_logits
+
 class CovariateAwareGPT2(nn.Module):
     """
     Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
@@ -367,8 +420,12 @@ class CombinedLoss(nn.Module):
         per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
         loss_ce = per_element_ce[mask].mean()

-        intensity = torch.sum(torch.exp(logits), dim=2)
-        per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)
-        loss_survival = per_element_survival[mask].mean()
+        # Survival loss based on exponential log-likelihood
+        t_min = 0.1
+        lse = torch.logsumexp(logits, dim=-1)
+        lse = -torch.log(torch.exp(-lse) + t_min)
+        ldt = -torch.log(t + t_min)
+        loss_dt = -(lse - torch.exp(lse - ldt))
+        loss_survival = loss_dt[mask].mean()

         return loss_ce, loss_survival
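Note: the sampling step in generate draws, for every token i, a waiting time t_i = exp(-logit_i) * (-ln U), i.e. an Exponential with rate exp(logit_i), and keeps the earliest one; the winning token then occurs with probability rate_i / sum(rates). A small numpy check of that property (illustrative values only):

import numpy as np

rng = np.random.default_rng(0)
logits = np.array([1.0, 0.0, -1.0])
rates = np.exp(logits)

# t_i = exp(-logit_i) * (-ln U)  ~  Exponential(rate = exp(logit_i))
n = 200_000
t = -np.exp(-logits) * np.log(rng.random((n, 3)))
winners = t.argmin(axis=1)

empirical = np.bincount(winners, minlength=3) / n
print(empirical)            # ~ [0.665, 0.245, 0.090]
print(rates / rates.sum())  # theoretical winning probabilities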
1 change: requirements.txt

@@ -2,3 +2,4 @@ torch
 numpy
 tqdm
 matplotlib
+joblib
44 changes: train.py

@@ -6,6 +6,8 @@ import numpy as np
 import math
 import tqdm
 import matplotlib.pyplot as plt
+import json
+import argparse

 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
@@ -29,8 +31,10 @@ class TrainConfig:
     batch_size = 128
     lr_initial = 6e-4
     lr_final = 6e-5
+    weight_decay = 2e-1
     warmup_epochs = 10
     early_stopping_patience = 10
+    betas = (0.9, 0.99)

     # Loss parameters
     # 0 = padding, 1 = "no event"
@@ -41,11 +45,49 @@

 # --- Main Training Script ---
 def main():
+    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
+    parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
+    parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
+    parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
+    parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
+    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
+    parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
+    parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
+    parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
+    parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
+    parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
+    parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
+    parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
+    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
+
+    args = parser.parse_args()
+
     config = TrainConfig()
+    config.n_layer = args.n_layer
+    config.n_embd = args.n_embd
+    config.n_head = args.n_head
+    config.max_epoch = args.max_epoch
+    config.batch_size = args.batch_size
+    config.lr_initial = args.lr_initial
+    config.lr_final = args.lr_final
+    config.weight_decay = args.weight_decay
+    config.warmup_epochs = args.warmup_epochs
+    config.early_stopping_patience = args.early_stopping_patience
+    config.pdrop = args.pdrop
+    config.token_pdrop = args.token_pdrop
+    config.betas = tuple(args.betas)

+
+    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
+    checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
+
+    # --- 0. Save Configuration ---
+    config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
+    config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
+    with open(config_filename, 'w') as f:
+        json.dump(config_dict, f, indent=4)
+    print(f"Configuration saved to {config_filename}")

     # --- 1. Data Loading ---
     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
     train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
@@ -75,7 +117,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = AdamW(model.parameters(), lr=config.lr_initial)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)

     # --- 3. Training Loop ---
     best_val_loss = float('inf')
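Note: the learning-rate schedule configured here (constant warmup, then cosine decay from lr_initial to lr_final; the update loop itself appears in train_dpp.py and train_iter.py below) reduces to a small pure function, sketched for clarity with the defaults from this diff:

import math

def lr_at_epoch(epoch, lr_initial=6e-4, lr_final=6e-5, warmup_epochs=10, max_epoch=200):
    # Constant during warmup, then cosine decay down to lr_final
    if epoch < warmup_epochs:
        return lr_initial
    progress = (epoch - warmup_epochs) / (max_epoch - warmup_epochs)
    return lr_final + 0.5 * (lr_initial - lr_final) * (1 + math.cos(math.pi * progress))

assert lr_at_epoch(0) == 6e-4                  # warmup plateau
assert abs(lr_at_epoch(200) - 6e-5) < 1e-12    # fully decayed at max_epoch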
400 changes: train_dpp.py (file deleted)

@@ -1,400 +0,0 @@
# train.py (DDP-ready)
import os
import math
import argparse
import numpy as np
import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters (device is set dynamically in main() based on local_rank)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


def setup_distributed(backend: str = "nccl"):
    """
    Initializes distributed mode if launched by torchrun with WORLD_SIZE > 1.
    Returns (is_distributed, world_size, rank, local_rank).
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    is_distributed = world_size > 1
    if is_distributed:
        if not dist.is_initialized():
            dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        local_rank = 0
    return is_distributed, world_size, rank, local_rank


def cleanup_distributed():
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def all_reduce_mean(value: float, device, world_size: int):
    """
    value is a Python float (this process's sum/mean); returns the float
    averaged over all processes.
    """
    tensor = torch.tensor([value], dtype=torch.float32, device=device)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= world_size
    return float(tensor.item())


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo", "mpi"])
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Initialize distributed training
    is_dist, world_size, rank, local_rank = setup_distributed(args.backend)

    # Basic environment setup
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    torch.backends.cudnn.benchmark = True

    config = TrainConfig()
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    config.device = device

    is_main = (rank == 0)

    # --- 1. Data Loading ---
    if is_main:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # Distributed samplers
    if is_dist:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        persistent_workers=True,  # num_workers > 0
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        persistent_workers=True,  # num_workers > 0
    )

    # --- 2. Model, Optimizer, and Loss Initialization ---
    if is_main:
        print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(device)

    if is_main and hasattr(model, "get_num_params"):
        print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # Wrap with DDP
    if is_dist:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank] if device.type == "cuda" else None,
            output_device=local_rank if device.type == "cuda" else None,
            find_unused_parameters=False,
        )

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    # Collect metrics and plot only on the main process
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if is_main:
        print("Starting training...")

    stop_training = False

    for epoch in range(config.max_epoch):
        # Set the epoch on the distributed sampler to ensure shuffling across epochs
        if is_dist:
            train_sampler.set_epoch(epoch)

        # --- Learning Rate Scheduling ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        if is_main:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader  # tqdm disabled on non-main processes

        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        for batch in pbar:
            event_seq, time_seq = batch
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += float(loss_ce.item())
            train_loss_surv_acc += float(loss_survival.item())
            train_steps += 1

            if is_main and isinstance(pbar, tqdm.tqdm):
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        # Per-process means
        avg_train_loss_ce_local = train_loss_ce_acc / max(train_steps, 1)
        avg_train_loss_surv_local = train_loss_surv_acc / max(train_steps, 1)

        # Average across all processes
        if is_dist:
            avg_train_loss_ce = all_reduce_mean(avg_train_loss_ce_local, device, world_size)
            avg_train_loss_surv = all_reduce_mean(avg_train_loss_surv_local, device, world_size)
        else:
            avg_train_loss_ce = avg_train_loss_ce_local
            avg_train_loss_surv = avg_train_loss_surv_local

        if is_main:
            train_losses_ce.append(avg_train_loss_ce)
            train_losses_surv.append(avg_train_loss_surv)
            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)

        # --- Validation Phase ---
        if is_main:
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
        else:
            pbar_val = val_loader

        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            for batch in pbar_val:
                event_seq, time_seq = batch
                event_seq = event_seq.to(device, non_blocking=True)
                time_seq = time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += float(loss_ce.item())
                val_loss_surv_acc += float(loss_survival.item())
                val_steps += 1

                if is_main and isinstance(pbar_val, tqdm.tqdm):
                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

        avg_val_loss_ce_local = val_loss_ce_acc / max(val_steps, 1)
        avg_val_loss_surv_local = val_loss_surv_acc / max(val_steps, 1)

        if is_dist:
            avg_val_loss_ce = all_reduce_mean(avg_val_loss_ce_local, device, world_size)
            avg_val_loss_surv = all_reduce_mean(avg_val_loss_surv_local, device, world_size)
        else:
            avg_val_loss_ce = avg_val_loss_ce_local
            avg_val_loss_surv = avg_val_loss_surv_local

        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        # Print and record on the main process
        if is_main:
            print(f"Epoch {epoch+1} Summary: \n"
                  f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f"  Learning Rate: {lr:.6f}")
            val_losses_ce.append(avg_val_loss_ce)
            val_losses_surv.append(avg_val_loss_surv)
            val_losses_total.append(total_val_loss)

        # --- Early Stopping Check (based on the aggregated total_val_loss) ---
        improved = False
        if is_main:
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                improved = True
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                # DDP: save module.state_dict()
                state_dict = model.module.state_dict() if isinstance(model, nn.parallel.DistributedDataParallel) else model.state_dict()
                torch.save(state_dict, 'best_model_checkpoint.pt')
            else:
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                    print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
            stop_training = patience_counter >= config.early_stopping_patience

        # Broadcast the stop flag to all processes to ensure a consistent exit
        if is_dist:
            flag_tensor = torch.tensor([1 if stop_training else 0], device=device, dtype=torch.int32)
            dist.broadcast(flag_tensor, src=0)
            stop_training = bool(int(flag_tensor.item()))

        if stop_training:
            if is_main:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    # --- Save Best Model at the End (main process only) ---
    if is_main:
        if best_val_loss != float('inf'):
            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
            # For convenience, rebuild a single-device model on the main process, load the weights, and re-save
            model_single = TimeAwareGPT2(
                vocab_size=vocab_size,
                n_embd=config.n_embd,
                n_layer=config.n_layer,
                n_head=config.n_head,
                pdrop=config.pdrop,
                token_pdrop=config.token_pdrop
            ).to('cpu')
            model_single.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
            print("Saving final best model to best_model.pt")
            torch.save(model_single.state_dict(), 'best_model.pt')
        else:
            print("\nTraining finished. No best model to save as validation loss never improved.")

        # --- Plot and Save Loss Curves ---
        num_epochs = len(train_losses_total)
        if num_epochs > 0:
            epochs = range(1, num_epochs + 1)
            plt.figure(figsize=(18, 5))

            # Plot CE Loss
            plt.subplot(1, 3, 1)
            plt.plot(epochs, train_losses_ce, label='Train CE')
            plt.plot(epochs, val_losses_ce, label='Val CE')
            plt.title('Cross-Entropy Loss')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)

            # Plot Survival Loss
            plt.subplot(1, 3, 2)
            plt.plot(epochs, train_losses_surv, label='Train Survival')
            plt.plot(epochs, val_losses_surv, label='Val Survival')
            plt.title('Survival Loss')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)

            # Plot Total Loss
            plt.subplot(1, 3, 3)
            plt.plot(epochs, train_losses_total, label='Train Total')
            plt.plot(epochs, val_losses_total, label='Val Total')
            plt.title('Total Loss')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)

            plt.tight_layout()
            plt.savefig('loss_curves.png')
            print("\nLoss curves saved to loss_curves.png")

    # Clean up distributed state
    cleanup_distributed()


if __name__ == '__main__':
    main()
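Note: a script written this way is launched with torchrun, which sets the WORLD_SIZE, RANK, and LOCAL_RANK environment variables that setup_distributed reads, e.g. `torchrun --nproc_per_node=4 train_dpp.py --backend nccl`; run directly as `python train_dpp.py`, it falls back to single-process training.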
218 changes: train_iter.py (new file)

@@ -0,0 +1,218 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48  # Sequence length

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.0
    token_pdrop = 0.0

    # Training parameters
    max_iter = 200000
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_iter = 1000

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # Example ignored token IDs

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Main Training Script ---
def main():
    config = TrainConfig()

    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.pt"

    # --- 0. Save Configuration ---
    config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.json"
    # Read from vars(TrainConfig): the settings are class attributes, so vars(config) on the instance would be empty
    config_dict = {k: v for k, v in vars(TrainConfig).items() if not k.startswith('__')}
    with open(config_filename, 'w') as f:
        json.dump(config_dict, f, indent=4)
    print(f"Configuration saved to {config_filename}")

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    train_iter_loader = iter(itertools.cycle(train_loader))

    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)

    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))

    # --- 3. Training Loop ---

    # Lists to store losses
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []

    print("Starting training...")
    pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training")
    for iter_num in pbar:
        # --- Learning Rate Scheduling ---
        if iter_num < config.warmup_iter:
            lr = config.lr_initial
        else:
            progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Step ---
        model.train()

        event_seq, time_seq = next(train_iter_loader)
        event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

        # Prepare inputs and targets
        input_events = event_seq[:, :-1]
        input_times = time_seq[:, :-1]
        target_events = event_seq[:, 1:]
        target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

        # Forward pass
        logits = model(input_events, input_times)
        loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
        loss = loss_ce + loss_survival

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses_ce.append(loss_ce.item())
        train_losses_surv.append(loss_survival.item())
        train_losses_total.append(loss.item())

        pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

    print("\nTraining finished.")

    # --- 4. Final Validation ---
    print("Running final validation...")
    model.eval()
    val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
    val_steps = 0

    with torch.no_grad():
        pbar_val = tqdm.tqdm(val_loader, desc="Final Validation")
        for event_seq, time_seq in pbar_val:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

            val_loss_ce_acc += loss_ce.item()
            val_loss_surv_acc += loss_survival.item()
            val_steps += 1
            pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

    avg_val_loss_ce = val_loss_ce_acc / val_steps
    avg_val_loss_surv = val_loss_surv_acc / val_steps
    total_val_loss = avg_val_loss_ce + avg_val_loss_surv

    print(f"Final Validation Summary: \n"
          f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")

    # --- 5. Save Model ---
    print(f"Saving final model to {model_filename}")
    torch.save(model.state_dict(), model_filename)

    # --- 6. Save and Plot Losses ---
    losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.txt"
    with open(losses_filename, 'w') as f:
        f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
        for i in range(len(train_losses_total)):
            f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
    print(f"\nLosses saved to {losses_filename}")

    # Plot and Save Loss Curves
    iterations = range(1, len(train_losses_total) + 1)

    plt.figure(figsize=(18, 5))

    # Plot CE Loss
    plt.subplot(1, 3, 1)
    plt.plot(iterations, train_losses_ce, label='Train CE')
    plt.title('Cross-Entropy Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Survival Loss
    plt.subplot(1, 3, 2)
    plt.plot(iterations, train_losses_surv, label='Train Survival')
    plt.title('Survival Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Total Loss
    plt.subplot(1, 3, 3)
    plt.plot(iterations, train_losses_total, label='Train Total')
    plt.title('Total Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('loss_curves_iter.png')
    print("\nLoss curves saved to loss_curves_iter.png")


if __name__ == '__main__':
    main()
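Note: one subtlety in the loader setup above: itertools.cycle caches the items from its first pass, so after the first epoch the same batches repeat in the same order even though the DataLoader was created with shuffle=True. A toy illustration with a list standing in for the loader:

import itertools

batches = ["b0", "b1", "b2"]             # stand-in for DataLoader output
stream = iter(itertools.cycle(batches))
print([next(stream) for _ in range(7)])  # ['b0', 'b1', 'b2', 'b0', 'b1', 'b2', 'b0']

Re-creating the iterator each epoch (or looping `while True: for batch in train_loader: ...`) would reshuffle every pass.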
95 changes: utils.py

@@ -2,6 +2,9 @@ import torch
 import numpy as np
 import random
 from collections import defaultdict
+import json
+from models import TimeAwareGPT2
+

 class PatientEventDataset(torch.utils.data.Dataset):
     """
@@ -39,17 +42,22 @@
         """
         return len(self.patient_ids)

-    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+    def __getitem__(self, idx):
         """
-        Retrieves, processes, and returns a single patient's event sequence.
+        Retrieves, processes, and returns a single patient's event sequence,
+        or a list of sequences if a slice is provided.

         Args:
-            idx (int): The index of the patient to retrieve.
+            idx (int or slice): The index or slice of the patient(s) to retrieve.

         Returns:
-            A tuple of two torch.long tensors: (event_sequence, time_sequence),
-            both of shape (block_length,).
+            If idx is an int, a tuple of two torch.long tensors:
+            (event_sequence, time_sequence), both of shape (block_length,).
+            If idx is a slice, a list of such tuples.
         """
+        if isinstance(idx, slice):
+            return [self[i] for i in range(*idx.indices(len(self)))]
+
         # 1. Retrieve and Sort
         patient_id = self.patient_ids[idx]
         records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
@@ -102,3 +110,80 @@
         time_tensor = torch.tensor(time_stamps, dtype=torch.long)

         return event_tensor, time_tensor
+
+def load_model(config_path, model_path, vocab_size, device='cpu'):
+    """
+    Loads a trained TimeAwareGPT2 model from a configuration file and a state dictionary.
+
+    Args:
+        config_path (str): Path to the JSON configuration file.
+        model_path (str): Path to the saved model state dictionary (.pt file).
+        vocab_size (int): The vocabulary size used during training.
+        device (str): The device to load the model onto ('cpu' or 'cuda').
+
+    Returns:
+        (TimeAwareGPT2): The loaded and initialized model.
+    """
+    with open(config_path, 'r') as f:
+        config_dict = json.load(f)
+
+    print(f"Model config: {config_dict}")
+
+    # Create a config object from the dictionary
+    class AttrDict(dict):
+        def __init__(self, *args, **kwargs):
+            super(AttrDict, self).__init__(*args, **kwargs)
+            self.__dict__ = self
+
+    config = AttrDict(config_dict)
+
+    # Initialize the model with parameters from the config
+    model = TimeAwareGPT2(
+        vocab_size=vocab_size,
+        n_embd=config.n_embd,
+        n_layer=config.n_layer,
+        n_head=config.n_head,
+        pdrop=config.pdrop,
+        token_pdrop=config.token_pdrop
+    ).to(device)
+
+    # Load the saved state dictionary
+    model.load_state_dict(torch.load(model_path, map_location=device))
+
+    # Set the model to evaluation mode
+    model.eval()
+
+    print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
+    return model
+
+
+def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Retrieves a batch of data from a PatientEventDataset and prepares it for model training.
+
+    Args:
+        dataset (PatientEventDataset): The dataset to retrieve data from.
+        batch_slice (slice): The slice defining the batch of patients to retrieve.
+
+    Returns:
+        A tuple containing four tensors:
+        - input_events: (batch_size, sequence_length - 1)
+        - input_times: (batch_size, sequence_length - 1)
+        - target_events: (batch_size, sequence_length - 1)
+        - target_times: (batch_size, sequence_length - 1)
+    """
+    batch_data = dataset[batch_slice]
+
+    input_events = [item[0][:-1] for item in batch_data]
+    input_times = [item[1][:-1] for item in batch_data]
+    target_events = [item[0][1:] for item in batch_data]
+    target_times = [item[1][1:] for item in batch_data]
+
+    input_events = torch.stack(input_events)
+    input_times = torch.stack(input_times)
+    target_events = torch.stack(target_events)
+    target_times = torch.stack(target_times)
+
+    return input_events, input_times, target_events, target_times
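Note: putting utils together, evaluation code obtains a whole-split batch like this (a sketch; assumes ukb_real_val.bin exists in the working directory, as in evaluate_auc.py above):

import numpy as np
from utils import PatientEventDataset, get_batch

val_arr = np.memmap('ukb_real_val.bin', dtype=np.uint32, mode='r').reshape(-1, 3)
val_dataset = PatientEventDataset(val_arr, 128)

# Slice-based indexing (enabled by the __getitem__ change above) feeds get_batch
input_events, input_times, target_events, target_times = get_batch(val_dataset, slice(0, 256))
print(input_events.shape)  # torch.Size([256, 127])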