Compare commits

..

23 Commits

Author SHA1 Message Date
93cf2018d2 evaluate_auc: adapt to new utils.load_model signature (infer checkpoint from config) 2025-10-22 11:40:46 +08:00
6801e5bdbb utils.load_model: take only config; infer checkpoint name from config (with legacy fallback) and vocab from checkpoint 2025-10-22 11:39:10 +08:00
92a5bd4a83 evaluate_auc: use new utils.load_model (infer vocab, model variants) and dynamic device; remove unused imports 2025-10-22 11:35:52 +08:00
dfdf64da9a Rewrite load_model to match train.py: support model variants and infer vocab_size from checkpoint; load state dict robustly 2025-10-22 11:30:59 +08:00
bd88daa8c2 update models and training scripts 2025-10-22 08:36:55 +08:00
e348086e52 update 2025-10-21 10:30:18 +08:00
a8aa5a2bd6 update 2025-10-21 09:20:43 +08:00
ddb7dbfc67 update 2025-10-20 16:22:15 +08:00
88cccdad2e feat: Optimize AUC evaluation with parallel processing 2025-10-20 16:16:50 +08:00
8f44018bae update evaluate 2025-10-20 13:47:50 +08:00
1c9e2a2fb3 feat: print model config and add evaluation notebook 2025-10-20 10:14:50 +08:00
6b782b86e1 feat: Add model checkpoints and configurations 2025-10-20 09:38:24 +08:00
9a9de170d1 delete 2025-10-18 22:35:42 +08:00
7e57e5d3b1 refactor: Update survival loss calculation in CombinedLoss 2025-10-18 15:21:10 +08:00
14865ac5b6 Refactor: Remove Jupyter Notebook cell markers 2025-10-18 13:32:26 +08:00
dbc3000192 add evaluation scripts. 2025-10-18 13:26:56 +08:00
082c719975 feat(models): Refactor generate function in TimeAwareGPT2 with competing risks sampling 2025-10-18 12:42:14 +08:00
a631ac6d59 feat: Add load_model function and update training script
Added a `load_model` function to `utils.py` to allow loading of trained models from configuration and state dictionary files.

The `train_iter.py` script was also modified, likely to incorporate or test this new functionality.
2025-10-18 11:07:59 +08:00
f7356b183c feat: Add command-line arguments to train.py 2025-10-18 10:23:12 +08:00
3390bc025e feat: Add iteration-based training scripts (single and multi-GPU) 2025-10-18 10:05:37 +08:00
a832a45c62 config: Tune hyperparameters for multi-GPU training
Increase model size (n_embd, n_layer, n_head) for the multi-GPU configuration.

Explicitly set AdamW betas to (0.9, 0.99).
2025-10-17 15:37:42 +08:00
d760c45baf feat: Add multi-GPU training and improve config/ignore
Add train_multigpu.py for distributed data parallel training.

Update train.py to save the training configuration to a JSON file.

Generalize .gitignore to exclude all *.pt checkpoint files.

Delete obsolete train_dpp.py file.
2025-10-17 14:09:34 +08:00
053f86f4da config: Add weight decay to training configuration
Adds a weight_decay parameter to the TrainConfig and applies it to the AdamW optimizer.
2025-10-17 13:47:37 +08:00
14 changed files with 2890 additions and 499 deletions

2
.gitignore vendored
View File

@@ -5,7 +5,7 @@
__pycache__/ __pycache__/
# Model checkpoints # Model checkpoints
best_model_checkpoint.pt *.pt
# Large data files # Large data files
ukb_delphi.txt ukb_delphi.txt

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,18 @@
{
"n_layer": 12,
"n_embd": 120,
"n_head": 12,
"max_epoch": 200,
"batch_size": 128,
"lr_initial": 0.0006,
"lr_final": 6e-05,
"weight_decay": 0.2,
"warmup_epochs": 10,
"early_stopping_patience": 10,
"pdrop": 0.0,
"token_pdrop": 0.0,
"betas": [
0.9,
0.99
]
}

View File

@@ -0,0 +1,18 @@
{
"n_layer": 16,
"n_embd": 256,
"n_head": 16,
"max_epoch": 200,
"batch_size": 128,
"lr_initial": 0.0006,
"lr_final": 6e-05,
"weight_decay": 0.2,
"warmup_epochs": 10,
"early_stopping_patience": 10,
"pdrop": 0.0,
"token_pdrop": 0.0,
"betas": [
0.9,
0.99
]
}

File diff suppressed because it is too large Load Diff

496
evaluate_auc.py Normal file
View File

@@ -0,0 +1,496 @@
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import argparse
from utils import load_model, get_batch, PatientEventDataset
from pathlib import Path
from joblib import Parallel, delayed
def auc(x1, x2):
n1 = len(x1)
n2 = len(x2)
R1 = np.concatenate([x1, x2]).argsort().argsort()[:n1].sum() + n1
U1 = n1 * n2 + 0.5 * n1 * (n1 + 1) - R1
if n1 == 0 or n2 == 0:
return np.nan
return U1 / n1 / n2
def get_common_diseases(delphi_labels, filter_min_total=100):
chapters_of_interest = [
"I. Infectious Diseases",
"II. Neoplasms",
"III. Blood & Immune Disorders",
"IV. Metabolic Diseases",
"V. Mental Disorders",
"VI. Nervous System Diseases",
"VII. Eye Diseases",
"VIII. Ear Diseases",
"IX. Circulatory Diseases",
"X. Respiratory Diseases",
"XI. Digestive Diseases",
"XII. Skin Diseases",
"XIII. Musculoskeletal Diseases",
"XIV. Genitourinary Diseases",
"XV. Pregnancy & Childbirth",
"XVI. Perinatal Conditions",
"XVII. Congenital Abnormalities",
"Death",
]
labels_df = delphi_labels[
delphi_labels["ICD-10 Chapter (short)"].isin(chapters_of_interest) * (delphi_labels["count"] > filter_min_total)
]
return labels_df["index"].tolist()
def optimized_bootstrapped_auc_gpu(case, control, n_bootstrap=1000):
"""
Computes bootstrapped AUC estimates using PyTorch on CUDA.
Parameters:
case: 1D tensor of scores for positive cases
control: 1D tensor of scores for controls
n_bootstrap: Number of bootstrap replicates
Returns:
Tensor of shape (n_bootstrap,) containing AUC for each bootstrap replicate
"""
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available. This function requires a GPU.")
# Convert inputs to CUDA tensors
if not torch.is_tensor(case):
case = torch.tensor(case, device="cuda", dtype=torch.float32)
else:
case = case.to("cuda", dtype=torch.float32)
if not torch.is_tensor(control):
control = torch.tensor(control, device="cuda", dtype=torch.float32)
else:
control = control.to("cuda", dtype=torch.float32)
n_case = case.size(0)
n_control = control.size(0)
total = n_case + n_control
# Generate bootstrap samples
boot_idx_case = torch.randint(0, n_case, (n_bootstrap, n_case), device="cuda")
boot_idx_control = torch.randint(0, n_control, (n_bootstrap, n_control), device="cuda")
boot_case = case[boot_idx_case]
boot_control = control[boot_idx_control]
combined = torch.cat([boot_case, boot_control], dim=1)
# Mask to identify case entries
mask = torch.zeros((n_bootstrap, total), dtype=torch.bool, device="cuda")
mask[:, :n_case] = True
# Compute ranks and AUC
ranks = combined.argsort(dim=1).argsort(dim=1)
case_ranks_sum = torch.sum(ranks.float() * mask.float(), dim=1)
min_case_rank_sum = n_case * (n_case - 1) / 2.0
U = case_ranks_sum - min_case_rank_sum
aucs = U / (n_case * n_control)
return aucs.cpu().tolist()
# AUC comparison adapted from
# https://github.com/Netflix/vmaf/
def compute_midrank(x):
"""Computes midranks.
Args:
x - a 1D numpy array
Returns:
array of midranks
"""
J = np.argsort(x)
Z = x[J]
N = len(x)
T = np.zeros(N, dtype=np.float32)
i = 0
while i < N:
j = i
while j < N and Z[j] == Z[i]:
j += 1
T[i:j] = 0.5 * (i + j - 1)
i = j
T2 = np.empty(N, dtype=np.float32)
# Note(kazeevn) +1 is due to Python using 0-based indexing
# instead of 1-based in the AUC formula in the paper
T2[J] = T + 1
return T2
def fastDeLong(predictions_sorted_transposed, label_1_count):
"""
The fast version of DeLong's method for computing the covariance of
unadjusted AUC.
Args:
predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
sorted such as the examples with label "1" are first
Returns:
(AUC value, DeLong covariance)
Reference:
@article{sun2014fast,
title={Fast Implementation of DeLong's Algorithm for
Comparing the Areas Under Correlated Receiver Operating Characteristic Curves},
author={Xu Sun and Weichao Xu},
journal={IEEE Signal Processing Letters},
volume={21},
number={11},
pages={1389--1393},
year={2014},
publisher={IEEE}
}
"""
# Short variables are named as they are in the paper
m = label_1_count
n = predictions_sorted_transposed.shape[1] - m
positive_examples = predictions_sorted_transposed[:, :m]
negative_examples = predictions_sorted_transposed[:, m:]
k = predictions_sorted_transposed.shape[0]
tx = np.empty([k, m], dtype=np.float32)
ty = np.empty([k, n], dtype=np.float32)
tz = np.empty([k, m + n], dtype=np.float32)
for r in range(k):
tx[r, :] = compute_midrank(positive_examples[r, :])
ty[r, :] = compute_midrank(negative_examples[r, :])
tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :])
aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n
v01 = (tz[:, :m] - tx[:, :]) / n
v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m
sx = np.cov(v01)
sy = np.cov(v10)
delongcov = sx / m + sy / n
return aucs, delongcov
def compute_ground_truth_statistics(ground_truth):
assert np.array_equal(np.unique(ground_truth), [0, 1])
order = (-ground_truth).argsort()
label_1_count = int(ground_truth.sum())
return order, label_1_count
def get_auc_delong_var(healthy_scores, diseased_scores):
"""
Computes ROC AUC value and variance using DeLong's method
Args:
healthy_scores: Values for class 0 (healthy/controls)
diseased_scores: Values for class 1 (diseased/cases)
Returns:
AUC value and variance
"""
# Create ground truth labels (1 for diseased, 0 for healthy)
ground_truth = np.array([1] * len(diseased_scores) + [0] * len(healthy_scores))
predictions = np.concatenate([diseased_scores, healthy_scores])
# Compute statistics needed for DeLong method
order, label_1_count = compute_ground_truth_statistics(ground_truth)
predictions_sorted_transposed = predictions[np.newaxis, order]
# Calculate AUC and covariance
aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count)
assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers"
return aucs[0], delongcov
def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5), precomputed_idx=None, n_bootstrap=1, use_delong=False):
age_step = age_groups[1] - age_groups[0]
# Indexes of cases with disease k
wk = np.where(d[2] == k)
if len(wk[0]) < 2:
return None
# For controls, we need to exclude cases with disease k
wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None])
wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]])) # All cases and controls
# We need to take into account the offset t and use the tokens for prediction that are at least t before the event
if precomputed_idx is None:
pred_idx = (d[1][wall[0]] <= d[3][wall].reshape(-1, 1) - offset).sum(1) - 1
else:
pred_idx = precomputed_idx[wall] # It's actually much faster to precompute this
valid_indices = pred_idx != -1
pred_idx = pred_idx[valid_indices]
wall = (wall[0][valid_indices], wall[1][valid_indices])
z = d[1][(wall[0], pred_idx)] # Times of the tokens for prediction
zk = d[3][wall] # Target times
x = p[..., j][(wall[0], pred_idx)]
p_idx = wall[0]
out = []
for i, aa in enumerate(age_groups):
a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step)
if not np.any(a):
continue
selected_groups = p_idx[a]
_, unique_indices = np.unique(selected_groups, return_index=True)
a_filtered = a[a]
a_filtered[:] = False
a_filtered[unique_indices] = True
a[a] = a_filtered
is_case = np.zeros_like(x, dtype=bool)
is_case[:len(wk[0])] = True
control = x[~is_case & a]
case = x[is_case & a]
if len(control) == 0 or len(case) == 0:
continue
if use_delong:
auc_value_delong, auc_variance_delong = get_auc_delong_var(control, case)
auc_delong_dict = {"auc_delong": auc_value_delong, "auc_variance_delong": auc_variance_delong}
else:
auc_delong_dict = {}
if n_bootstrap > 1:
aucs_bootstrapped = optimized_bootstrapped_auc_gpu(case, control, n_bootstrap)
for bootstrap_idx in range(n_bootstrap):
y = auc_value_delong if n_bootstrap == 1 else aucs_bootstrapped[bootstrap_idx]
out_item = {
"token": k,
"auc": y,
"age": aa,
"n_healthy": len(control),
"n_diseased": len(case),
}
out.append(out_item | auc_delong_dict)
if n_bootstrap > 1:
out_item["bootstrap_idx"] = bootstrap_idx
return out
def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap):
all_aucs = []
for sex, sex_idx in [("female", 2), ("male", 3)]:
sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
p_sex = p100k[sex_mask]
d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k]
precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
for j, k in tqdm(
list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
):
out = get_calibration_auc(
j,
k,
d100k_sex,
p_sex,
age_groups=age_groups,
offset=offset,
precomputed_idx=precomputed_idx_subset,
n_bootstrap=n_bootstrap,
use_delong=True,
)
if out is None:
continue
for out_item in out:
out_item["sex"] = sex
all_aucs.append(out_item)
return all_aucs
# New internal function that performs the AUC evaluation pipeline.
def evaluate_auc_pipeline(
model,
d100k,
output_path,
delphi_labels,
diseases_of_interest=None,
filter_min_total=100,
disease_chunk_size=200,
age_groups=np.arange(40, 80, 5),
offset=0.1,
batch_size=256,
device="cpu",
seed=1337,
n_bootstrap=1,
meta_info={},
n_jobs=-1,
):
"""
Runs the AUC evaluation pipeline.
Args:
model (torch.nn.Module): The loaded model set to eval().
d100k (tuple): Data batch from get_batch.
delphi_labels (pd.DataFrame): DataFrame with label info (token names, etc. "delphi_labels_chapters_colours_icd.csv").
output_path (str | None): Directory where CSV files will be written. If None, files will not be saved.
diseases_of_interest (np.ndarray or list, optional): If provided, these disease indices are used.
filter_min_total (int): Minimum total token count to include a token.
disease_chunk_size (int): Maximum chunk size for processing diseases.
age_groups (np.ndarray): Age groups to use in calibration.
offset (float): Offset used in get_calibration_auc.
batch_size (int): Batch size for model forwarding.
device (str): Device identifier.
seed (int): Random seed for reproducibility.
n_bootstrap (int): Number of bootstrap samples. (1 for no bootstrap)
n_jobs (int): Number of parallel jobs to run.
Returns:
tuple: (df_auc_unpooled, df_auc, df_both) DataFrames.
"""
assert n_bootstrap > 0, "n_bootstrap must be greater than 0"
# Set random seeds
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Get common diseases
if diseases_of_interest is None:
diseases_of_interest = get_common_diseases(delphi_labels, filter_min_total)
# Split diseases into chunks for processing
num_chunks = (len(diseases_of_interest) + disease_chunk_size - 1) // disease_chunk_size
diseases_chunks = np.array_split(diseases_of_interest, num_chunks)
# Precompute prediction indices for calibration
pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1
p100k = []
model.to(device)
with torch.no_grad():
for dd in tqdm(
zip(*[torch.split(x, batch_size) for x in d100k]),
desc=f"Model inference",
total=d100k[0].shape[0] // batch_size + 1,
):
dd = [x.to(device) for x in dd]
outputs = model(dd[0], dd[1]).cpu().detach().numpy()
p100k.append(outputs.astype("float16"))
p100k = np.vstack(p100k)
results = Parallel(n_jobs=n_jobs)(
delayed(process_chunk)(
disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk], pred_idx_precompute, age_groups, offset, n_bootstrap
)
for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks)
)
all_aucs = [item for sublist in results for item in sublist]
df_auc_unpooled = pd.DataFrame(all_aucs)
for key, value in meta_info.items():
df_auc_unpooled[key] = value
delphi_labels_subset = delphi_labels[['index', 'ICD-10 Chapter (short)', 'name', 'color', 'count']]
df_auc_unpooled_merged = df_auc_unpooled.merge(delphi_labels_subset, left_on="token", right_on="index", how="inner")
def aggregate_age_brackets_delong(group):
# For normal distributions, when averaging n of them:
# The variance of the sum is the sum of variances
# The variance of the average is the sum of variances divided by n^2
n = len(group)
mean = group['auc_delong'].mean()
# Since we're taking the average, divide combined variance by n^2
var = group['auc_variance_delong'].sum() / (n**2)
return pd.Series({
'auc': mean,
'auc_variance_delong': var,
'n_samples': n,
'n_diseased': group['n_diseased'].sum(),
'n_healthy': group['n_healthy'].sum(),
})
print('Using DeLong method to calculate AUC confidence intervals..')
df_auc = df_auc_unpooled.groupby(["token"]).apply(aggregate_age_brackets_delong).reset_index()
df_auc_merged = df_auc.merge(delphi_labels, left_on="token", right_on="index", how="inner")
if output_path is not None:
Path(output_path).mkdir(exist_ok=True, parents=True)
df_auc_merged.to_csv(f"{output_path}/df_both.csv", index=False)
df_auc_unpooled_merged.to_csv(f"{output_path}/df_auc_unpooled.csv", index=False)
return df_auc_unpooled_merged, df_auc_merged
def main():
parser = argparse.ArgumentParser(description="Evaluate AUC")
parser.add_argument("--model_name", type=str, default="n_embd_256_n_layer_16_n_head_16", help="Model checkpoint name")
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
parser.add_argument("--offset", type=float, default=365.25, help="Offset in days for prediction")
# Optional filtering/chunking parameters:
parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens")
parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
args = parser.parse_args()
model_name = args.model_name
output_path = f'auc_evaluation_{model_name}'
dataset_subset_size = args.dataset_subset_size
offset = args.offset
# Create output folder if it doesn't exist.
Path(output_path).mkdir(exist_ok=True, parents=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 1337
# Load model checkpoint and initialize model.
model = load_model(
config_path=f'config_{model_name}.json',
device=device,
)
model.eval()
model = model.to(device)
# Load training and validation data.
val_data_path = 'ukb_real_val.bin'
val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
block_length = 128
val_dataset = PatientEventDataset(val_data_arr, block_length)
if dataset_subset_size == -1:
dataset_subset_size = len(val_dataset)
# Get a subset batch for evaluation.
d100k = get_batch(val_dataset, slice(dataset_subset_size))
# Load labels (external) to be passed in.
delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")
# Call the internal evaluation function.
df_auc_unpooled, df_auc_merged = evaluate_auc_pipeline(
model,
d100k,
output_path,
delphi_labels,
diseases_of_interest=None,
filter_min_total=args.filter_min_total,
disease_chunk_size=args.disease_chunk_size,
device=device,
seed=seed,
offset=offset,
n_bootstrap=args.n_bootstrap,
n_jobs=args.n_jobs,
)
if __name__ == "__main__":
main()

552
evaluate_models.ipynb Normal file

File diff suppressed because one or more lines are too long

183
models.py
View File

@@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from typing import Tuple from typing import Tuple, Optional
# ============================================================================= # =============================================================================
# 1. Component Modules (Building Blocks) # 1. Component Modules (Building Blocks)
@@ -85,6 +85,39 @@ class AgeSinusoidalEncoding(nn.Module):
output[:, :, 1::2] = torch.sin(args) output[:, :, 1::2] = torch.sin(args)
return output return output
class LearnableAgeEncoding(nn.Module):
"""Combines fixed sinusoidal age encodings with a learnable MLP projection."""
def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0):
super().__init__()
self.base_dim = base_dim
self.final_dim = final_dim or base_dim
hidden_dim = hidden_dim or base_dim
if hidden_dim <= 0:
raise ValueError("hidden_dim must be a positive integer.")
if self.final_dim <= 0:
raise ValueError("final_dim must be a positive integer.")
self.sinusoidal = AgeSinusoidalEncoding(base_dim)
mlp_layers = [
nn.Linear(base_dim, hidden_dim),
nn.GELU(),
]
if dropout > 0.0:
mlp_layers.append(nn.Dropout(dropout))
mlp_layers.append(nn.Linear(hidden_dim, self.final_dim))
self.mlp = nn.Sequential(*mlp_layers)
def forward(self, t: torch.Tensor) -> torch.Tensor:
sin_embed = self.sinusoidal(t)
flat_embed = sin_embed.reshape(-1, self.base_dim)
projected = self.mlp(flat_embed)
return projected.reshape(*sin_embed.shape[:-1], self.final_dim)
class PiecewiseLinearEncoder(nn.Module): class PiecewiseLinearEncoder(nn.Module):
""" """
Encodes continuous variables using piecewise linear encoding. Encodes continuous variables using piecewise linear encoding.
@@ -177,9 +210,10 @@ class TimeAwareGPT2(nn.Module):
A time-aware GPT-2 model with custom temporal features. A time-aware GPT-2 model with custom temporal features.
""" """
def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float): def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None):
super().__init__() super().__init__()
self.token_pdrop = token_pdrop self.token_pdrop = token_pdrop
self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
self.wte = nn.Embedding(vocab_size, n_embd) self.wte = nn.Embedding(vocab_size, n_embd)
self.age_encoder = AgeSinusoidalEncoding(n_embd) self.age_encoder = AgeSinusoidalEncoding(n_embd)
@@ -234,94 +268,71 @@ class TimeAwareGPT2(nn.Module):
""" """
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
class CovariateAwareGPT2(nn.Module): @torch.no_grad()
def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
""" """
Extends TimeAwareGPT2 to incorporate static and time-varying covariates. Take a conditioning sequence of indices x (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
""" """
self.eval()
def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, if termination_tokens is None:
pdrop: float, token_pdrop: float, num_bins: int): termination_tokens = [1269]
"""
Initializes the CovariateAwareGPT2 model.
Args: termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
vocab_size (int): Size of the event vocabulary. mask_time = -10000
n_embd (int): Embedding dimensionality.
n_layer (int): Number of transformer layers.
n_head (int): Number of attention heads.
pdrop (float): Dropout probability for layers.
token_pdrop (float): Dropout probability for input token embeddings.
num_bins (int): Number of bins for the PiecewiseLinearEncoder.
"""
super().__init__()
self.token_pdrop = token_pdrop
self.wte = nn.Embedding(vocab_size, n_embd) for _ in range(max_new_tokens):
self.age_encoder = AgeSinusoidalEncoding(n_embd) logits = self(x, t)
self.drop = nn.Dropout(pdrop) logits = logits[:, -1, :]
self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)])
self.n_embd = n_embd
self.cov_encoder = PiecewiseLinearEncoder(num_bins=num_bins, embedding_dim=n_embd)
self.ln_f = nn.LayerNorm(2 * n_embd) if self.ignore_tokens:
self.head = nn.Sequential( logits[:, self.ignore_tokens] = -torch.inf
nn.Linear(2 * n_embd, n_embd),
nn.GELU(), if no_repeat:
nn.Linear(n_embd, vocab_size) fill = x.clone()
fill[fill == 1] = 0
logits = logits.scatter(1, fill, -torch.inf)
t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
t_next_val, idx_next = t_next_dist.min(1)
idx_next = idx_next.unsqueeze(1)
age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
x = torch.cat((x, idx_next), dim=1)
t = torch.cat((t, age_next), dim=1)
if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
break
pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
final_logits = self(x, t)
x[pad] = 0
t[pad] = mask_time
if no_repeat:
fill = x.clone()
fill[fill == 1] = 0
final_logits = torch.stack([final_logits[:,j].scatter(1, fill[:,:j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0,1)
return x, t, final_logits
class TimeAwareGPT2Learnable(TimeAwareGPT2):
"""Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.age_encoder = LearnableAgeEncoding(
base_dim=self.n_embd,
hidden_dim=2 * self.n_embd,
final_dim=self.n_embd,
) )
def forward(self, x: torch.Tensor, t: torch.Tensor, cov: torch.Tensor, cov_t: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the CovariateAwareGPT2 model.
Args:
x (torch.Tensor): Event sequence tensor of shape (B, L).
t (torch.Tensor): Time sequence tensor of shape (B, L).
cov (torch.Tensor): Covariate tensor of shape (B, N).
cov_t (torch.Tensor): Covariate time tensor of shape (B).
Returns:
torch.Tensor: Logits of shape (B, L, vocab_size).
"""
B, L = x.size()
cov_encoded = self.cov_encoder(cov).sum(dim=1).unsqueeze(1)
cov_t_encoded = self.age_encoder(t - cov_t.unsqueeze(1))
cov_embed = cov_encoded + cov_t_encoded
token_embeddings = self.wte(x)
if self.training and self.token_pdrop > 0:
drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop
token_embeddings[drop_mask] = 0.0
pos_embeddings = self.age_encoder(t.float())
seq_embed = self.drop(token_embeddings + pos_embeddings)
t_i = t.unsqueeze(-1)
t_j = t.unsqueeze(1)
time_mask = (t_j < t_i)
padding_mask = (x != 0).unsqueeze(1)
combined_mask = time_mask & padding_mask
is_row_all_zero = ~combined_mask.any(dim=-1)
is_not_padding = (x != 0)
force_self_attention = is_row_all_zero & is_not_padding
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
block_output = seq_embed
for block in self.blocks:
block_output = block(block_output, custom_mask=combined_mask)
integrated_embed = torch.cat([block_output, cov_embed], dim=-1)
final_output = self.ln_f(integrated_embed)
logits = self.head(final_output)
return logits
def get_num_params(self) -> float:
"""
Returns the number of trainable parameters in the model in millions.
"""
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
# ============================================================================= # =============================================================================
# 3. Loss Function # 3. Loss Function
@@ -367,8 +378,12 @@ class CombinedLoss(nn.Module):
per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
loss_ce = per_element_ce[mask].mean() loss_ce = per_element_ce[mask].mean()
intensity = torch.sum(torch.exp(logits), dim=2) # Survival loss based on exponential log-likelihood
per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t) t_min = 0.1
loss_survival = per_element_survival[mask].mean() lse = torch.logsumexp(logits, dim=-1)
lse = -torch.log(torch.exp(-lse) + t_min)
ldt = -torch.log(t + t_min)
loss_dt = -(lse - torch.exp(lse - ldt))
loss_survival = loss_dt[mask].mean()
return loss_ce, loss_survival return loss_ce, loss_survival

View File

@@ -2,3 +2,4 @@ torch
numpy numpy
tqdm tqdm
matplotlib matplotlib
joblib

View File

@@ -6,8 +6,10 @@ import numpy as np
import math import math
import tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import json
import argparse
from models import TimeAwareGPT2, CombinedLoss from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
from utils import PatientEventDataset from utils import PatientEventDataset
# --- Configuration --- # --- Configuration ---
@@ -23,14 +25,17 @@ class TrainConfig:
n_head = 12 n_head = 12
pdrop = 0.1 pdrop = 0.1
token_pdrop = 0.1 token_pdrop = 0.1
model_name = 'TimeAwareGPT2'
# Training parameters # Training parameters
max_epoch = 200 max_epoch = 200
batch_size = 128 batch_size = 128
lr_initial = 6e-4 lr_initial = 6e-4
lr_final = 6e-5 lr_final = 6e-5
weight_decay = 2e-1
warmup_epochs = 10 warmup_epochs = 10
early_stopping_patience = 10 early_stopping_patience = 10
betas = (0.9, 0.99)
# Loss parameters # Loss parameters
# 0 = padding, 1 = "no event" # 0 = padding, 1 = "no event"
@@ -41,10 +46,50 @@ class TrainConfig:
# --- Main Training Script --- # --- Main Training Script ---
def main(): def main():
config = TrainConfig() parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.')
parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.')
parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.')
parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.')
parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.')
parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.')
parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.')
parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.')
parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.')
parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.')
parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.')
parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.')
parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.')
parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2', help='Model architecture to train.')
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt" args = parser.parse_args()
checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
config = TrainConfig()
config.n_layer = args.n_layer
config.n_embd = args.n_embd
config.n_head = args.n_head
config.max_epoch = args.max_epoch
config.batch_size = args.batch_size
config.lr_initial = args.lr_initial
config.lr_final = args.lr_final
config.weight_decay = args.weight_decay
config.warmup_epochs = args.warmup_epochs
config.early_stopping_patience = args.early_stopping_patience
config.pdrop = args.pdrop
config.token_pdrop = args.token_pdrop
config.betas = tuple(args.betas)
config.model_name = args.model
model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
model_filename = f"best_model_{model_suffix}.pt"
checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
# --- 0. Save Configuration ---
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
with open(config_filename, 'w') as f:
json.dump(config_dict, f, indent=4)
print(f"Configuration saved to {config_filename}")
# --- 1. Data Loading --- # --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
@@ -63,7 +108,12 @@ def main():
# --- 2. Model, Optimizer, and Loss Initialization --- # --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {config.device}...") print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2( model_cls = {
'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
}[config.model_name]
model = model_cls(
vocab_size=vocab_size, vocab_size=vocab_size,
n_embd=config.n_embd, n_embd=config.n_embd,
n_layer=config.n_layer, n_layer=config.n_layer,
@@ -75,7 +125,7 @@ def main():
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.") print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids) loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = AdamW(model.parameters(), lr=config.lr_initial) optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas)
# --- 3. Training Loop --- # --- 3. Training Loop ---
best_val_loss = float('inf') best_val_loss = float('inf')
@@ -193,7 +243,7 @@ def main():
print("\nTraining finished. No best model to save as validation loss never improved.") print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Save losses to a txt file --- # --- Save losses to a txt file ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt" losses_filename = f"losses_{model_suffix}.txt"
with open(losses_filename, 'w') as f: with open(losses_filename, 'w') as f:
f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n") f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
for i in range(len(train_losses_total)): for i in range(len(train_losses_total)):

View File

@@ -1,400 +0,0 @@
# train.py (DDP-ready)
import os
import math
import argparse
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 24 # Sequence length
# Model parameters
n_embd = 256
n_layer = 8
n_head = 8
pdrop = 0.1
token_pdrop = 0.1
# Training parameters
max_epoch = 200
batch_size = 128
lr_initial = 6e-4
lr_final = 6e-5
warmup_epochs = 10
early_stopping_patience = 5
# Loss parameters
# 0 = padding, 1 = "no event"
ignored_token_ids = [0, 1]
# System parameters (device 将在 main() 内按 local_rank 动态设置)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def setup_distributed(backend: str = "nccl"):
"""
如果由 torchrun 启动且 WORLD_SIZE>1则初始化分布式。
返回 (is_distributed, world_size, rank, local_rank)
"""
world_size = int(os.environ.get("WORLD_SIZE", "1"))
is_distributed = world_size > 1
if is_distributed:
if not dist.is_initialized():
dist.init_process_group(backend=backend, init_method="env://")
rank = dist.get_rank()
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
else:
rank = 0
local_rank = 0
return is_distributed, world_size, rank, local_rank
def cleanup_distributed():
if dist.is_available() and dist.is_initialized():
dist.destroy_process_group()
def all_reduce_mean(value: float, device, world_size: int):
"""
value 是 Python float本进程的和/均值),返回所有进程平均后的 float。
"""
tensor = torch.tensor([value], dtype=torch.float32, device=device)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
tensor /= world_size
return float(tensor.item())
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo", "mpi"])
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
# 分布式初始化
is_dist, world_size, rank, local_rank = setup_distributed(args.backend)
# 基本环境
torch.manual_seed(args.seed + rank)
np.random.seed(args.seed + rank)
torch.backends.cudnn.benchmark = True
config = TrainConfig()
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
config.device = device
is_main = (rank == 0)
# --- 1. Data Loading ---
if is_main:
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
# Infer vocab_size from the data (max label + 1)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
if is_main:
print(f"Inferred vocabulary size: {vocab_size}")
train_dataset = PatientEventDataset(train_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
# 分布式采样器
if is_dist:
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
else:
train_sampler = None
val_sampler = None
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=(train_sampler is None),
sampler=train_sampler,
num_workers=4,
pin_memory=True,
drop_last=False,
persistent_workers=True if 4 > 0 else False,
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
sampler=val_sampler,
num_workers=4,
pin_memory=True,
drop_last=False,
persistent_workers=True if 4 > 0 else False,
)
# --- 2. Model, Optimizer, and Loss Initialization ---
if is_main:
print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(device)
if is_main and hasattr(model, "get_num_params"):
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = Adam(model.parameters(), lr=config.lr_initial)
# DDP 包装
if is_dist:
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[local_rank] if device.type == "cuda" else None,
output_device=local_rank if device.type == "cuda" else None,
find_unused_parameters=False,
)
# --- 3. Training Loop ---
best_val_loss = float('inf')
patience_counter = 0
# 只在主进程收集与画图
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
val_losses_ce, val_losses_surv, val_losses_total = [], [], []
if is_main:
print("Starting training...")
stop_training = False
for epoch in range(config.max_epoch):
# 设置 epoch 给分布式采样器,确保跨 epoch shuffle
if is_dist:
train_sampler.set_epoch(epoch)
# --- Learning Rate Scheduling ---
if epoch < config.warmup_epochs:
lr = config.lr_initial
else:
progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# --- Training Phase ---
if is_main:
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
else:
pbar = train_loader # 非主进程禁用 tqdm
model.train()
train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
train_steps = 0
for batch in pbar:
event_seq, time_seq = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
# Prepare inputs and targets
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
# Forward pass
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
# Backward pass and optimization
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
train_loss_ce_acc += float(loss_ce.item())
train_loss_surv_acc += float(loss_survival.item())
train_steps += 1
if is_main and isinstance(pbar, tqdm.tqdm):
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
# 进程内均值
avg_train_loss_ce_local = train_loss_ce_acc / max(train_steps, 1)
avg_train_loss_surv_local = train_loss_surv_acc / max(train_steps, 1)
# 所有进程平均
if is_dist:
avg_train_loss_ce = all_reduce_mean(avg_train_loss_ce_local, device, world_size)
avg_train_loss_surv = all_reduce_mean(avg_train_loss_surv_local, device, world_size)
else:
avg_train_loss_ce = avg_train_loss_ce_local
avg_train_loss_surv = avg_train_loss_surv_local
if is_main:
train_losses_ce.append(avg_train_loss_ce)
train_losses_surv.append(avg_train_loss_surv)
train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
# --- Validation Phase ---
if is_main:
pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
else:
pbar_val = val_loader
model.eval()
val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
val_steps = 0
with torch.no_grad():
for batch in pbar_val:
event_seq, time_seq = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_loss_ce_acc += float(loss_ce.item())
val_loss_surv_acc += float(loss_survival.item())
val_steps += 1
if is_main and isinstance(pbar_val, tqdm.tqdm):
pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
avg_val_loss_ce_local = val_loss_ce_acc / max(val_steps, 1)
avg_val_loss_surv_local = val_loss_surv_acc / max(val_steps, 1)
if is_dist:
avg_val_loss_ce = all_reduce_mean(avg_val_loss_ce_local, device, world_size)
avg_val_loss_surv = all_reduce_mean(avg_val_loss_surv_local, device, world_size)
else:
avg_val_loss_ce = avg_val_loss_ce_local
avg_val_loss_surv = avg_val_loss_surv_local
total_val_loss = avg_val_loss_ce + avg_val_loss_surv
# 主进程打印与记录
if is_main:
print(f"Epoch {epoch+1} Summary: \n"
f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
f" Learning Rate: {lr:.6f}")
val_losses_ce.append(avg_val_loss_ce)
val_losses_surv.append(avg_val_loss_surv)
val_losses_total.append(total_val_loss)
# --- Early Stopping Check (基于聚合后的 total_val_loss) ---
improved = False
if is_main:
if total_val_loss < best_val_loss:
best_val_loss = total_val_loss
patience_counter = 0
improved = True
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
# DDP: 保存 module.state_dict()
state_dict = model.module.state_dict() if isinstance(model, nn.parallel.DistributedDataParallel) else model.state_dict()
torch.save(state_dict, 'best_model_checkpoint.pt')
else:
if epoch >= config.warmup_epochs:
patience_counter += 1
print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
stop_training = patience_counter >= config.early_stopping_patience
# 把 improved/stop 广播到所有进程,确保一致退出
if is_dist:
flag_tensor = torch.tensor([1 if stop_training else 0], device=device, dtype=torch.int32)
dist.broadcast(flag_tensor, src=0)
stop_training = bool(int(flag_tensor.item()))
if stop_training:
if is_main:
print("\nEarly stopping triggered due to no improvement in validation loss.")
break
# --- Save Best Model at the End (只主进程) ---
if is_main:
if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
# 为了易用,这里在主进程上重新构建单卡模型加载权重再保存
model_single = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to('cpu')
model_single.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
print("Saving final best model to best_model.pt")
torch.save(model_single.state_dict(), 'best_model.pt')
else:
print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Plot and Save Loss Curves ---
num_epochs = len(train_losses_total)
if num_epochs > 0:
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(18, 5))
# Plot CE Loss
plt.subplot(1, 3, 1)
plt.plot(epochs, train_losses_ce, label='Train CE')
plt.plot(epochs, val_losses_ce, label='Val CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Survival Loss
plt.subplot(1, 3, 2)
plt.plot(epochs, train_losses_surv, label='Train Survival')
plt.plot(epochs, val_losses_surv, label='Val Survival')
plt.title('Survival Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Total Loss
plt.subplot(1, 3, 3)
plt.plot(epochs, train_losses_total, label='Train Total')
plt.plot(epochs, val_losses_total, label='Val Total')
plt.title('Total Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves.png')
print("\nLoss curves saved to loss_curves.png")
# 清理分布式
cleanup_distributed()
if __name__ == '__main__':
main()

218
train_iter.py Normal file
View File

@@ -0,0 +1,218 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 48 # Sequence length
# Model parameters
n_embd = 120
n_layer = 12
n_head = 12
pdrop = 0.0
token_pdrop = 0.0
# Training parameters
max_iter = 200000
batch_size = 128
lr_initial = 6e-4
lr_final = 6e-5
weight_decay = 2e-1
warmup_iter = 1000
# Loss parameters
# 0 = padding, 1 = "no event"
ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Example ignored token IDs
# System parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Main Training Script ---
def main():
config = TrainConfig()
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.pt"
# --- 0. Save Configuration ---
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.json"
config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
with open(config_filename, 'w') as f:
json.dump(config_dict, f, indent=4)
print(f"Configuration saved to {config_filename}")
# --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
# Infer vocab_size from the data (max label + 1)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
print(f"Inferred vocabulary size: {vocab_size}")
train_dataset = PatientEventDataset(train_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
train_iter_loader = iter(itertools.cycle(train_loader))
# --- 2. Model, Optimizer, and Loss Initialization ---
print(f"Initializing model on {config.device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(config.device)
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))
# --- 3. Training Loop ---
# Lists to store losses
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
print("Starting training...")
pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training")
for iter_num in pbar:
# --- Learning Rate Scheduling ---
if iter_num < config.warmup_iter:
lr = config.lr_initial
else:
progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# --- Training Step ---
model.train()
event_seq, time_seq = next(train_iter_loader)
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
# Prepare inputs and targets
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
# Forward pass
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_losses_ce.append(loss_ce.item())
train_losses_surv.append(loss_survival.item())
train_losses_total.append(loss.item())
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
print("\nTraining finished.")
# --- 4. Final Validation ---
print("Running final validation...")
model.eval()
val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
val_steps = 0
with torch.no_grad():
pbar_val = tqdm.tqdm(val_loader, desc="Final Validation")
for event_seq, time_seq in pbar_val:
event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_loss_ce_acc += loss_ce.item()
val_loss_surv_acc += loss_survival.item()
val_steps += 1
pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
avg_val_loss_ce = val_loss_ce_acc / val_steps
avg_val_loss_surv = val_loss_surv_acc / val_steps
total_val_loss = avg_val_loss_ce + avg_val_loss_surv
print(f"Final Validation Summary: \n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")
# --- 5. Save Model ---
print(f"Saving final model to {model_filename}")
torch.save(model.state_dict(), model_filename)
# --- 6. Save and Plot Losses ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.txt"
with open(losses_filename, 'w') as f:
f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# Plot and Save Loss Curves
iterations = range(1, len(train_losses_total) + 1)
plt.figure(figsize=(18, 5))
# Plot CE Loss
plt.subplot(1, 3, 1)
plt.plot(iterations, train_losses_ce, label='Train CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Survival Loss
plt.subplot(1, 3, 2)
plt.plot(iterations, train_losses_surv, label='Train Survival')
plt.title('Survival Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Total Loss
plt.subplot(1, 3, 3)
plt.plot(iterations, train_losses_total, label='Train Total')
plt.title('Total Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves_iter.png')
print("\nLoss curves saved to loss_curves_iter.png")
if __name__ == '__main__':
main()

162
utils.py
View File

@@ -1,7 +1,11 @@
import os
import torch import torch
import numpy as np import numpy as np
import random import random
from collections import defaultdict from collections import defaultdict
import json
from models import TimeAwareGPT2, TimeAwareGPT2Learnable
class PatientEventDataset(torch.utils.data.Dataset): class PatientEventDataset(torch.utils.data.Dataset):
""" """
@@ -39,17 +43,22 @@ class PatientEventDataset(torch.utils.data.Dataset):
""" """
return len(self.patient_ids) return len(self.patient_ids)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: def __getitem__(self, idx):
""" """
Retrieves, processes, and returns a single patient's event sequence. Retrieves, processes, and returns a single patient's event sequence,
or a list of sequences if a slice is provided.
Args: Args:
idx (int): The index of the patient to retrieve. idx (int or slice): The index or slice of the patient(s) to retrieve.
Returns: Returns:
A tuple of two torch.long tensors: (event_sequence, time_sequence), If idx is an int, a tuple of two torch.long tensors:
both of shape (block_length,). (event_sequence, time_sequence), both of shape (block_length,).
If idx is a slice, a list of such tuples.
""" """
if isinstance(idx, slice):
return [self[i] for i in range(*idx.indices(len(self)))]
# 1. Retrieve and Sort # 1. Retrieve and Sort
patient_id = self.patient_ids[idx] patient_id = self.patient_ids[idx]
records = sorted(self.patient_events[patient_id], key=lambda x: x[0]) records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
@@ -102,3 +111,146 @@ class PatientEventDataset(torch.utils.data.Dataset):
time_tensor = torch.tensor(time_stamps, dtype=torch.long) time_tensor = torch.tensor(time_stamps, dtype=torch.long)
return event_tensor, time_tensor return event_tensor, time_tensor
def load_model(config_path: str, device: str = 'cpu'):
"""
Load a trained model based on the training configuration, inferring the
checkpoint filename from the configuration.
According to train.py, models may be either 'TimeAwareGPT2' or
'TimeAwareGPT2Learnable'. This function:
- Reads the config JSON to get architecture hyperparameters
- Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
- Infers the checkpoint path from the config values
- Infers vocab_size from the checkpoint
- Loads weights and returns the model in eval mode on the requested device
Args:
config_path: Path to the JSON configuration file saved during training.
device: 'cpu' or 'cuda'.
Returns:
torch.nn.Module: Loaded model ready for inference.
"""
# 1) Read config
with open(config_path, 'r') as f:
config_dict = json.load(f)
# Access config entries with attribute-style access while staying tolerant to missing keys
class AttrDict(dict):
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(item)
config = AttrDict(config_dict)
# 2) Decide model class (train.py supports two variants)
model_name = getattr(config, 'model_name', 'TimeAwareGPT2')
model_cls = {
'TimeAwareGPT2': TimeAwareGPT2,
'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
}.get(model_name, TimeAwareGPT2)
# 3) Infer checkpoint filename from config
n_embd = getattr(config, 'n_embd')
n_layer = getattr(config, 'n_layer')
n_head = getattr(config, 'n_head')
# Newer naming (includes model_name) used by train.py when model_name is present
suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
ckpt_with_model = f"best_model_{suffix_with_model}.pt"
# Older naming (without model_name) matches existing repo files
suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
ckpt_legacy = f"best_model_{suffix_legacy}.pt"
# Prefer file that exists on disk
if os.path.exists(ckpt_with_model):
model_path = ckpt_with_model
elif os.path.exists(ckpt_legacy):
model_path = ckpt_legacy
else:
# Fall back to including model_name; if not present in config earlier, user may still have saved this way
model_path = ckpt_with_model
print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}")
# 4) Infer vocab_size from checkpoint
state_preview = torch.load(model_path, map_location='cpu')
if 'wte.weight' in state_preview:
vocab_size = state_preview['wte.weight'].shape[0]
elif 'head.weight' in state_preview:
vocab_size = state_preview['head.weight'].shape[0]
else:
candidate = None
for k, v in state_preview.items():
if isinstance(v, torch.Tensor) and v.ndim == 2:
V = max(v.shape)
if candidate is None or V > candidate:
candidate = V
if candidate is None:
raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
vocab_size = candidate
# 5) Build model from config (tolerant to missing fields)
pdrop = getattr(config, 'pdrop', 0.1)
token_pdrop = getattr(config, 'token_pdrop', 0.1)
model = model_cls(
vocab_size=vocab_size,
n_embd=n_embd,
n_layer=n_layer,
n_head=n_head,
pdrop=pdrop,
token_pdrop=token_pdrop,
).to(device)
# 6) Load weights
state_dict = torch.load(model_path, map_location=device)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
print(f"Warning: Missing keys when loading state_dict: {missing}")
if unexpected:
print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")
model.eval()
try:
num_params_m = model.get_num_params()
print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.")
except Exception:
pass
return model
def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Retrieves a batch of data from a PatientEventDataset and prepares it for model training.
Args:
dataset (PatientEventDataset): The dataset to retrieve data from.
batch_slice (slice): The slice defining the batch of patients to retrieve.
ignore_tokens (list, optional): A list of token IDs to be ignored in the target events.
These tokens will be replaced with -100. Defaults to None.
Returns:
A tuple containing four tensors:
- input_events: (batch_size, sequence_length - 1)
- input_tims: (batch_size, sequence_length - 1)
- target_events: (batch_size, sequence_length - 1)
- target_times: (batch_size, sequence_length - 1)
"""
batch_data = dataset[batch_slice]
input_events = [item[0][:-1] for item in batch_data]
input_tims = [item[1][:-1] for item in batch_data]
target_events = [item[0][1:] for item in batch_data]
target_times = [item[1][1:] for item in batch_data]
input_events = torch.stack(input_events)
input_tims = torch.stack(input_tims)
target_events = torch.stack(target_events)
target_times = torch.stack(target_times)
return input_events, input_tims, target_events, target_times