Compare commits
	
		
			47 Commits
		
	
	
		
			c7296381b8
			...
			main
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 02be03f784 | |||
| 4d1fc63667 | |||
| dd58ced9b9 | |||
| 3bef72f50b | |||
| a81da36657 | |||
| b954b4b3e7 | |||
| f8e0104d6b | |||
| 262a7db0da | |||
| 9917b3ab63 | |||
| 8316326d7e | |||
| 6dd5eb95c7 | |||
| 5b0642eb6e | |||
| 93cf2018d2 | |||
| 6801e5bdbb | |||
| 92a5bd4a83 | |||
| dfdf64da9a | |||
| bd88daa8c2 | |||
| e348086e52 | |||
| a8aa5a2bd6 | |||
| ddb7dbfc67 | |||
| 88cccdad2e | |||
| 8f44018bae | |||
| 1c9e2a2fb3 | |||
| 6b782b86e1 | |||
| 9a9de170d1 | |||
| 7e57e5d3b1 | |||
| 14865ac5b6 | |||
| dbc3000192 | |||
| 082c719975 | |||
| a631ac6d59 | |||
| f7356b183c | |||
| 3390bc025e | |||
| a832a45c62 | |||
| d760c45baf | |||
| 053f86f4da | |||
| d4d25ac9c7 | |||
| fe0304a96a | |||
| 7e8d8d307b | |||
| fc0aef4e71 | |||
| 02d84a7eca | |||
| cb7575a229 | |||
| e2495f43b0 | |||
| 6e0713048a | |||
| eec406d79f | |||
| e3e533c9ec | |||
| b5172392cb | |||
| 6b0b86d9d0 | 
							
								
								
									
										17
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | |||||||
|  | # IDE settings | ||||||
|  | .idea/ | ||||||
|  |  | ||||||
|  | # Python cache | ||||||
|  | __pycache__/ | ||||||
|  |  | ||||||
|  | # Model checkpoints | ||||||
|  | *.pt | ||||||
|  |  | ||||||
|  | # Large data files | ||||||
|  | ukb_delphi.txt | ||||||
|  | ukb_real.bin | ||||||
|  |  | ||||||
|  | # Small data files | ||||||
|  | fields.txt | ||||||
|  | icd10_codes_mod.tsv | ||||||
|  | labels.csv | ||||||
							
								
								
									
										
											BIN
										
									
								
								best_model_n_embd_120_n_layer_12_n_head_12.pt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								best_model_n_embd_120_n_layer_12_n_head_12.pt
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								best_model_n_embd_256_n_layer_16_n_head_16.pt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								best_model_n_embd_256_n_layer_16_n_head_16.pt
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										19
									
								
								config_n_embd_120_n_layer_12_n_head_12.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								config_n_embd_120_n_layer_12_n_head_12.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | |||||||
|  | { | ||||||
|  |     "model_name": "TimeAwareGPT2", | ||||||
|  |     "n_layer": 12, | ||||||
|  |     "n_embd": 120, | ||||||
|  |     "n_head": 12, | ||||||
|  |     "max_epoch": 200, | ||||||
|  |     "batch_size": 128, | ||||||
|  |     "lr_initial": 0.0006, | ||||||
|  |     "lr_final": 6e-05, | ||||||
|  |     "weight_decay": 0.2, | ||||||
|  |     "warmup_epochs": 10, | ||||||
|  |     "early_stopping_patience": 10, | ||||||
|  |     "pdrop": 0.0, | ||||||
|  |     "token_pdrop": 0.0, | ||||||
|  |     "betas": [ | ||||||
|  |         0.9, | ||||||
|  |         0.99 | ||||||
|  |     ] | ||||||
|  | } | ||||||
							
								
								
									
										19
									
								
								config_n_embd_256_n_layer_16_n_head_16.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								config_n_embd_256_n_layer_16_n_head_16.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | |||||||
|  | { | ||||||
|  |     "model_name": "TimeAwareGPT2", | ||||||
|  |     "n_layer": 16, | ||||||
|  |     "n_embd": 256, | ||||||
|  |     "n_head": 16, | ||||||
|  |     "max_epoch": 200, | ||||||
|  |     "batch_size": 128, | ||||||
|  |     "lr_initial": 0.0006, | ||||||
|  |     "lr_final": 6e-05, | ||||||
|  |     "weight_decay": 0.2, | ||||||
|  |     "warmup_epochs": 10, | ||||||
|  |     "early_stopping_patience": 10, | ||||||
|  |     "pdrop": 0.0, | ||||||
|  |     "token_pdrop": 0.0, | ||||||
|  |     "betas": [ | ||||||
|  |         0.9, | ||||||
|  |         0.99 | ||||||
|  |     ] | ||||||
|  | } | ||||||
							
								
								
									
										1271
									
								
								delphi_labels_chapters_colours_icd.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1271
									
								
								delphi_labels_chapters_colours_icd.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										496
									
								
								evaluate_auc.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										496
									
								
								evaluate_auc.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,496 @@ | |||||||
|  | import torch | ||||||
|  | from tqdm import tqdm | ||||||
|  | import pandas as pd | ||||||
|  | import numpy as np | ||||||
|  | import argparse | ||||||
|  | from utils import load_model, get_batch, PatientEventDataset | ||||||
|  | from pathlib import Path | ||||||
|  | from joblib import Parallel, delayed | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def auc(x1, x2): | ||||||
|  |     n1 = len(x1) | ||||||
|  |     n2 = len(x2) | ||||||
|  |     R1 = np.concatenate([x1, x2]).argsort().argsort()[:n1].sum() + n1 | ||||||
|  |     U1 = n1 * n2 + 0.5 * n1 * (n1 + 1) - R1 | ||||||
|  |     if n1 == 0 or n2 == 0: | ||||||
|  |         return np.nan | ||||||
|  |     return U1 / n1 / n2 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_common_diseases(delphi_labels, filter_min_total=100): | ||||||
|  |     chapters_of_interest = [ | ||||||
|  |         "I. Infectious Diseases", | ||||||
|  |         "II. Neoplasms", | ||||||
|  |         "III. Blood & Immune Disorders", | ||||||
|  |         "IV. Metabolic Diseases", | ||||||
|  |         "V. Mental Disorders", | ||||||
|  |         "VI. Nervous System Diseases", | ||||||
|  |         "VII. Eye Diseases", | ||||||
|  |         "VIII. Ear Diseases", | ||||||
|  |         "IX. Circulatory Diseases", | ||||||
|  |         "X. Respiratory Diseases", | ||||||
|  |         "XI. Digestive Diseases", | ||||||
|  |         "XII. Skin Diseases", | ||||||
|  |         "XIII. Musculoskeletal Diseases", | ||||||
|  |         "XIV. Genitourinary Diseases", | ||||||
|  |         "XV. Pregnancy & Childbirth", | ||||||
|  |         "XVI. Perinatal Conditions", | ||||||
|  |         "XVII. Congenital Abnormalities", | ||||||
|  |         "Death", | ||||||
|  |     ] | ||||||
|  |     labels_df = delphi_labels[ | ||||||
|  |         delphi_labels["ICD-10 Chapter (short)"].isin(chapters_of_interest) * (delphi_labels["count"] > filter_min_total) | ||||||
|  |     ] | ||||||
|  |     return labels_df["index"].tolist() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def optimized_bootstrapped_auc_gpu(case, control, n_bootstrap=1000): | ||||||
|  |     """ | ||||||
|  |     Computes bootstrapped AUC estimates using PyTorch on CUDA. | ||||||
|  |  | ||||||
|  |     Parameters: | ||||||
|  |         case: 1D tensor of scores for positive cases | ||||||
|  |         control: 1D tensor of scores for controls | ||||||
|  |         n_bootstrap: Number of bootstrap replicates | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         Tensor of shape (n_bootstrap,) containing AUC for each bootstrap replicate | ||||||
|  |     """ | ||||||
|  |     if not torch.cuda.is_available(): | ||||||
|  |         raise RuntimeError("CUDA is not available. This function requires a GPU.") | ||||||
|  |  | ||||||
|  |     # Convert inputs to CUDA tensors | ||||||
|  |     if not torch.is_tensor(case): | ||||||
|  |         case = torch.tensor(case, device="cuda", dtype=torch.float32) | ||||||
|  |     else: | ||||||
|  |         case = case.to("cuda", dtype=torch.float32) | ||||||
|  |  | ||||||
|  |     if not torch.is_tensor(control): | ||||||
|  |         control = torch.tensor(control, device="cuda", dtype=torch.float32) | ||||||
|  |     else: | ||||||
|  |         control = control.to("cuda", dtype=torch.float32) | ||||||
|  |  | ||||||
|  |     n_case = case.size(0) | ||||||
|  |     n_control = control.size(0) | ||||||
|  |     total = n_case + n_control | ||||||
|  |  | ||||||
|  |     # Generate bootstrap samples | ||||||
|  |     boot_idx_case = torch.randint(0, n_case, (n_bootstrap, n_case), device="cuda") | ||||||
|  |     boot_idx_control = torch.randint(0, n_control, (n_bootstrap, n_control), device="cuda") | ||||||
|  |  | ||||||
|  |     boot_case = case[boot_idx_case] | ||||||
|  |     boot_control = control[boot_idx_control] | ||||||
|  |  | ||||||
|  |     combined = torch.cat([boot_case, boot_control], dim=1) | ||||||
|  |  | ||||||
|  |     # Mask to identify case entries | ||||||
|  |     mask = torch.zeros((n_bootstrap, total), dtype=torch.bool, device="cuda") | ||||||
|  |     mask[:, :n_case] = True | ||||||
|  |  | ||||||
|  |     # Compute ranks and AUC | ||||||
|  |     ranks = combined.argsort(dim=1).argsort(dim=1) | ||||||
|  |     case_ranks_sum = torch.sum(ranks.float() * mask.float(), dim=1) | ||||||
|  |     min_case_rank_sum = n_case * (n_case - 1) / 2.0 | ||||||
|  |     U = case_ranks_sum - min_case_rank_sum | ||||||
|  |     aucs = U / (n_case * n_control) | ||||||
|  |     return aucs.cpu().tolist() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # AUC comparison adapted from | ||||||
|  | # https://github.com/Netflix/vmaf/ | ||||||
|  | def compute_midrank(x): | ||||||
|  |     """Computes midranks. | ||||||
|  |     Args: | ||||||
|  |        x - a 1D numpy array | ||||||
|  |     Returns: | ||||||
|  |        array of midranks | ||||||
|  |     """ | ||||||
|  |     J = np.argsort(x) | ||||||
|  |     Z = x[J] | ||||||
|  |     N = len(x) | ||||||
|  |     T = np.zeros(N, dtype=np.float32) | ||||||
|  |     i = 0 | ||||||
|  |     while i < N: | ||||||
|  |         j = i | ||||||
|  |         while j < N and Z[j] == Z[i]: | ||||||
|  |             j += 1 | ||||||
|  |         T[i:j] = 0.5 * (i + j - 1) | ||||||
|  |         i = j | ||||||
|  |     T2 = np.empty(N, dtype=np.float32) | ||||||
|  |     # Note(kazeevn) +1 is due to Python using 0-based indexing | ||||||
|  |     # instead of 1-based in the AUC formula in the paper | ||||||
|  |     T2[J] = T + 1 | ||||||
|  |     return T2 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def fastDeLong(predictions_sorted_transposed, label_1_count): | ||||||
|  |     """ | ||||||
|  |     The fast version of DeLong's method for computing the covariance of | ||||||
|  |     unadjusted AUC. | ||||||
|  |     Args: | ||||||
|  |        predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples] | ||||||
|  |           sorted such as the examples with label "1" are first | ||||||
|  |     Returns: | ||||||
|  |        (AUC value, DeLong covariance) | ||||||
|  |     Reference: | ||||||
|  |      @article{sun2014fast, | ||||||
|  |        title={Fast Implementation of DeLong's Algorithm for | ||||||
|  |               Comparing the Areas Under Correlated Receiver Operating Characteristic Curves}, | ||||||
|  |        author={Xu Sun and Weichao Xu}, | ||||||
|  |        journal={IEEE Signal Processing Letters}, | ||||||
|  |        volume={21}, | ||||||
|  |        number={11}, | ||||||
|  |        pages={1389--1393}, | ||||||
|  |        year={2014}, | ||||||
|  |        publisher={IEEE} | ||||||
|  |      } | ||||||
|  |     """ | ||||||
|  |     # Short variables are named as they are in the paper | ||||||
|  |     m = label_1_count | ||||||
|  |     n = predictions_sorted_transposed.shape[1] - m | ||||||
|  |     positive_examples = predictions_sorted_transposed[:, :m] | ||||||
|  |     negative_examples = predictions_sorted_transposed[:, m:] | ||||||
|  |     k = predictions_sorted_transposed.shape[0] | ||||||
|  |  | ||||||
|  |     tx = np.empty([k, m], dtype=np.float32) | ||||||
|  |     ty = np.empty([k, n], dtype=np.float32) | ||||||
|  |     tz = np.empty([k, m + n], dtype=np.float32) | ||||||
|  |     for r in range(k): | ||||||
|  |         tx[r, :] = compute_midrank(positive_examples[r, :]) | ||||||
|  |         ty[r, :] = compute_midrank(negative_examples[r, :]) | ||||||
|  |         tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :]) | ||||||
|  |     aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n | ||||||
|  |     v01 = (tz[:, :m] - tx[:, :]) / n | ||||||
|  |     v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m | ||||||
|  |     sx = np.cov(v01) | ||||||
|  |     sy = np.cov(v10) | ||||||
|  |     delongcov = sx / m + sy / n | ||||||
|  |     return aucs, delongcov | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def compute_ground_truth_statistics(ground_truth): | ||||||
|  |     assert np.array_equal(np.unique(ground_truth), [0, 1]) | ||||||
|  |     order = (-ground_truth).argsort() | ||||||
|  |     label_1_count = int(ground_truth.sum()) | ||||||
|  |     return order, label_1_count | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_auc_delong_var(healthy_scores, diseased_scores): | ||||||
|  |     """ | ||||||
|  |     Computes ROC AUC value and variance using DeLong's method | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         healthy_scores: Values for class 0 (healthy/controls) | ||||||
|  |         diseased_scores: Values for class 1 (diseased/cases) | ||||||
|  |     Returns: | ||||||
|  |         AUC value and variance | ||||||
|  |     """ | ||||||
|  |     # Create ground truth labels (1 for diseased, 0 for healthy) | ||||||
|  |     ground_truth = np.array([1] * len(diseased_scores) + [0] * len(healthy_scores)) | ||||||
|  |     predictions = np.concatenate([diseased_scores, healthy_scores]) | ||||||
|  |  | ||||||
|  |     # Compute statistics needed for DeLong method | ||||||
|  |     order, label_1_count = compute_ground_truth_statistics(ground_truth) | ||||||
|  |     predictions_sorted_transposed = predictions[np.newaxis, order] | ||||||
|  |  | ||||||
|  |     # Calculate AUC and covariance | ||||||
|  |     aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count) | ||||||
|  |     assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers" | ||||||
|  |  | ||||||
|  |     return aucs[0], delongcov | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5), precomputed_idx=None, n_bootstrap=1, use_delong=False): | ||||||
|  |     age_step = age_groups[1] - age_groups[0] | ||||||
|  |  | ||||||
|  |     # Indexes of cases with disease k | ||||||
|  |     wk = np.where(d[2] == k) | ||||||
|  |  | ||||||
|  |     if len(wk[0]) < 2: | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |     # For controls, we need to exclude cases with disease k | ||||||
|  |     wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None]) | ||||||
|  |  | ||||||
|  |     wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]]))  # All cases and controls | ||||||
|  |  | ||||||
|  |     # We need to take into account the offset t and use the tokens for prediction that are at least t before the event | ||||||
|  |     if precomputed_idx is None: | ||||||
|  |         pred_idx = (d[1][wall[0]] <= d[3][wall].reshape(-1, 1) - offset).sum(1) - 1 | ||||||
|  |     else: | ||||||
|  |         pred_idx = precomputed_idx[wall]  # It's actually much faster to precompute this | ||||||
|  |  | ||||||
|  |     valid_indices = pred_idx != -1 | ||||||
|  |     pred_idx = pred_idx[valid_indices] | ||||||
|  |     wall = (wall[0][valid_indices], wall[1][valid_indices]) | ||||||
|  |      | ||||||
|  |     z = d[1][(wall[0], pred_idx)]  # Times of the tokens for prediction | ||||||
|  |     zk = d[3][wall]  # Target times | ||||||
|  |  | ||||||
|  |     x = p[..., j][(wall[0], pred_idx)] | ||||||
|  |  | ||||||
|  |     p_idx = wall[0] | ||||||
|  |  | ||||||
|  |     out = [] | ||||||
|  |  | ||||||
|  |     for i, aa in enumerate(age_groups): | ||||||
|  |         a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step) | ||||||
|  |          | ||||||
|  |         if not np.any(a): | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         selected_groups = p_idx[a] | ||||||
|  |         _, unique_indices = np.unique(selected_groups, return_index=True) | ||||||
|  |          | ||||||
|  |         a_filtered = a[a] | ||||||
|  |         a_filtered[:] = False | ||||||
|  |         a_filtered[unique_indices] = True | ||||||
|  |         a[a] = a_filtered | ||||||
|  |  | ||||||
|  |         is_case = np.zeros_like(x, dtype=bool) | ||||||
|  |         is_case[:len(wk[0])] = True | ||||||
|  |          | ||||||
|  |         control = x[~is_case & a] | ||||||
|  |         case = x[is_case & a] | ||||||
|  |  | ||||||
|  |         if len(control) == 0 or len(case) == 0: | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         if use_delong: | ||||||
|  |             auc_value_delong, auc_variance_delong = get_auc_delong_var(control, case) | ||||||
|  |             auc_delong_dict = {"auc_delong": auc_value_delong, "auc_variance_delong": auc_variance_delong} | ||||||
|  |         else: | ||||||
|  |             auc_delong_dict = {} | ||||||
|  |  | ||||||
|  |         if n_bootstrap > 1: | ||||||
|  |             aucs_bootstrapped = optimized_bootstrapped_auc_gpu(case, control, n_bootstrap) | ||||||
|  |  | ||||||
|  |         for bootstrap_idx in range(n_bootstrap): | ||||||
|  |             y = auc_value_delong if n_bootstrap == 1 else aucs_bootstrapped[bootstrap_idx] | ||||||
|  |             out_item = { | ||||||
|  |                 "token": k, | ||||||
|  |                 "auc": y, | ||||||
|  |                 "age": aa, | ||||||
|  |                 "n_healthy": len(control), | ||||||
|  |                 "n_diseased": len(case), | ||||||
|  |             } | ||||||
|  |             out.append(out_item | auc_delong_dict) | ||||||
|  |             if n_bootstrap > 1: | ||||||
|  |                 out_item["bootstrap_idx"] = bootstrap_idx | ||||||
|  |     return out | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap): | ||||||
|  |     all_aucs = [] | ||||||
|  |     for sex, sex_idx in [("female", 2), ("male", 3)]: | ||||||
|  |         sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy() | ||||||
|  |         p_sex = p100k[sex_mask] | ||||||
|  |         d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k] | ||||||
|  |         precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy() | ||||||
|  |         for j, k in tqdm( | ||||||
|  |             list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}" | ||||||
|  |         ): | ||||||
|  |             out = get_calibration_auc( | ||||||
|  |                 j, | ||||||
|  |                 k, | ||||||
|  |                 d100k_sex, | ||||||
|  |                 p_sex, | ||||||
|  |                 age_groups=age_groups, | ||||||
|  |                 offset=offset, | ||||||
|  |                 precomputed_idx=precomputed_idx_subset, | ||||||
|  |                 n_bootstrap=n_bootstrap, | ||||||
|  |                 use_delong=True, | ||||||
|  |             ) | ||||||
|  |             if out is None: | ||||||
|  |                 continue | ||||||
|  |             for out_item in out: | ||||||
|  |                 out_item["sex"] = sex | ||||||
|  |                 all_aucs.append(out_item) | ||||||
|  |     return all_aucs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # New internal function that performs the AUC evaluation pipeline. | ||||||
|  | def evaluate_auc_pipeline( | ||||||
|  |     model, | ||||||
|  |     d100k, | ||||||
|  |     output_path, | ||||||
|  |     delphi_labels, | ||||||
|  |     diseases_of_interest=None, | ||||||
|  |     filter_min_total=100, | ||||||
|  |     disease_chunk_size=200, | ||||||
|  |     age_groups=np.arange(40, 80, 5), | ||||||
|  |     offset=0.1, | ||||||
|  |     batch_size=256, | ||||||
|  |     device="cpu", | ||||||
|  |     seed=1337, | ||||||
|  |     n_bootstrap=1, | ||||||
|  |     meta_info={}, | ||||||
|  |     n_jobs=-1, | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     Runs the AUC evaluation pipeline. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         model (torch.nn.Module): The loaded model set to eval(). | ||||||
|  |         d100k (tuple): Data batch from get_batch. | ||||||
|  |         delphi_labels (pd.DataFrame): DataFrame with label info (token names, etc. "delphi_labels_chapters_colours_icd.csv"). | ||||||
|  |         output_path (str | None): Directory where CSV files will be written. If None, files will not be saved. | ||||||
|  |         diseases_of_interest (np.ndarray or list, optional): If provided, these disease indices are used. | ||||||
|  |         filter_min_total (int): Minimum total token count to include a token. | ||||||
|  |         disease_chunk_size (int): Maximum chunk size for processing diseases. | ||||||
|  |         age_groups (np.ndarray): Age groups to use in calibration. | ||||||
|  |         offset (float): Offset used in get_calibration_auc. | ||||||
|  |         batch_size (int): Batch size for model forwarding. | ||||||
|  |         device (str): Device identifier. | ||||||
|  |         seed (int): Random seed for reproducibility. | ||||||
|  |         n_bootstrap (int): Number of bootstrap samples. (1 for no bootstrap) | ||||||
|  |         n_jobs (int): Number of parallel jobs to run. | ||||||
|  |     Returns: | ||||||
|  |         tuple: (df_auc_unpooled, df_auc, df_both) DataFrames. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     assert n_bootstrap > 0, "n_bootstrap must be greater than 0" | ||||||
|  |  | ||||||
|  |     # Set random seeds | ||||||
|  |     torch.manual_seed(seed) | ||||||
|  |     torch.cuda.manual_seed(seed) | ||||||
|  |  | ||||||
|  |     # Get common diseases | ||||||
|  |     if diseases_of_interest is None: | ||||||
|  |         diseases_of_interest = get_common_diseases(delphi_labels, filter_min_total) | ||||||
|  |  | ||||||
|  |     # Split diseases into chunks for processing | ||||||
|  |     num_chunks = (len(diseases_of_interest) + disease_chunk_size - 1) // disease_chunk_size | ||||||
|  |     diseases_chunks = np.array_split(diseases_of_interest, num_chunks) | ||||||
|  |  | ||||||
|  |     # Precompute prediction indices for calibration | ||||||
|  |     pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1 | ||||||
|  |  | ||||||
|  |     p100k = [] | ||||||
|  |     model.to(device) | ||||||
|  |     with torch.no_grad(): | ||||||
|  |         for dd in tqdm( | ||||||
|  |             zip(*[torch.split(x, batch_size) for x in d100k]), | ||||||
|  |             desc=f"Model inference", | ||||||
|  |             total=d100k[0].shape[0] // batch_size + 1, | ||||||
|  |         ): | ||||||
|  |             dd = [x.to(device) for x in dd] | ||||||
|  |             outputs = model(dd[0], dd[1]).cpu().detach().numpy() | ||||||
|  |             p100k.append(outputs.astype("float16")) | ||||||
|  |     p100k = np.vstack(p100k) | ||||||
|  |      | ||||||
|  |     results = Parallel(n_jobs=n_jobs)( | ||||||
|  |         delayed(process_chunk)( | ||||||
|  |             disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk], pred_idx_precompute, age_groups, offset, n_bootstrap | ||||||
|  |         ) | ||||||
|  |         for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     all_aucs = [item for sublist in results for item in sublist] | ||||||
|  |  | ||||||
|  |     df_auc_unpooled = pd.DataFrame(all_aucs) | ||||||
|  |  | ||||||
|  |     for key, value in meta_info.items(): | ||||||
|  |         df_auc_unpooled[key] = value | ||||||
|  |  | ||||||
|  |     delphi_labels_subset = delphi_labels[['index', 'ICD-10 Chapter (short)', 'name', 'color', 'count']] | ||||||
|  |     df_auc_unpooled_merged = df_auc_unpooled.merge(delphi_labels_subset, left_on="token", right_on="index", how="inner") | ||||||
|  |  | ||||||
|  |     def aggregate_age_brackets_delong(group): | ||||||
|  |         # For normal distributions, when averaging n of them: | ||||||
|  |         # The variance of the sum is the sum of variances | ||||||
|  |         # The variance of the average is the sum of variances divided by n^2 | ||||||
|  |         n = len(group) | ||||||
|  |         mean = group['auc_delong'].mean() | ||||||
|  |         # Since we're taking the average, divide combined variance by n^2 | ||||||
|  |         var = group['auc_variance_delong'].sum() / (n**2) | ||||||
|  |         return pd.Series({ | ||||||
|  |             'auc': mean, | ||||||
|  |             'auc_variance_delong': var, | ||||||
|  |             'n_samples': n,  | ||||||
|  |             'n_diseased': group['n_diseased'].sum(), | ||||||
|  |             'n_healthy': group['n_healthy'].sum(), | ||||||
|  |         }) | ||||||
|  |  | ||||||
|  |     print('Using DeLong method to calculate AUC confidence intervals..') | ||||||
|  |      | ||||||
|  |     df_auc = df_auc_unpooled.groupby(["token"]).apply(aggregate_age_brackets_delong).reset_index() | ||||||
|  |     df_auc_merged = df_auc.merge(delphi_labels, left_on="token", right_on="index", how="inner") | ||||||
|  |      | ||||||
|  |     if output_path is not None: | ||||||
|  |         Path(output_path).mkdir(exist_ok=True, parents=True) | ||||||
|  |         df_auc_merged.to_csv(f"{output_path}/df_both.csv", index=False) | ||||||
|  |         df_auc_unpooled_merged.to_csv(f"{output_path}/df_auc_unpooled.csv", index=False) | ||||||
|  |  | ||||||
|  |     return df_auc_unpooled_merged, df_auc_merged | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |     parser = argparse.ArgumentParser(description="Evaluate AUC") | ||||||
|  |     parser.add_argument("--model_name", type=str, default="n_embd_256_n_layer_16_n_head_16", help="Model checkpoint name") | ||||||
|  |     parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation") | ||||||
|  |     parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples") | ||||||
|  |     parser.add_argument("--offset", type=float, default=365.25, help="Offset in days for prediction") | ||||||
|  |     # Optional filtering/chunking parameters: | ||||||
|  |     parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens") | ||||||
|  |     parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases") | ||||||
|  |     parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run") | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     model_name = args.model_name | ||||||
|  |     output_path = f'auc_evaluation_{model_name}' | ||||||
|  |     dataset_subset_size = args.dataset_subset_size | ||||||
|  |     offset = args.offset | ||||||
|  |  | ||||||
|  |     # Create output folder if it doesn't exist. | ||||||
|  |     Path(output_path).mkdir(exist_ok=True, parents=True) | ||||||
|  |  | ||||||
|  |     device = "cuda" if torch.cuda.is_available() else "cpu" | ||||||
|  |     seed = 1337 | ||||||
|  |  | ||||||
|  |     # Load model checkpoint and initialize model. | ||||||
|  |     model = load_model( | ||||||
|  |         config_path=f'config_{model_name}.json', | ||||||
|  |         device=device, | ||||||
|  |     ) | ||||||
|  |     model.eval() | ||||||
|  |     model = model.to(device) | ||||||
|  |  | ||||||
|  |     # Load training and validation data. | ||||||
|  |  | ||||||
|  |      | ||||||
|  |     val_data_path = 'ukb_real_val.bin' | ||||||
|  |      | ||||||
|  |     val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) | ||||||
|  |     block_length = 128     | ||||||
|  |     val_dataset = PatientEventDataset(val_data_arr, block_length) | ||||||
|  |  | ||||||
|  |     if dataset_subset_size == -1: | ||||||
|  |         dataset_subset_size = len(val_dataset) | ||||||
|  |  | ||||||
|  |     # Get a subset batch for evaluation. | ||||||
|  |     d100k = get_batch(val_dataset, slice(dataset_subset_size)) | ||||||
|  |  | ||||||
|  |     # Load labels (external) to be passed in. | ||||||
|  |     delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv") | ||||||
|  |  | ||||||
|  |     # Call the internal evaluation function. | ||||||
|  |     df_auc_unpooled, df_auc_merged = evaluate_auc_pipeline( | ||||||
|  |         model, | ||||||
|  |         d100k, | ||||||
|  |         output_path, | ||||||
|  |         delphi_labels, | ||||||
|  |         diseases_of_interest=None, | ||||||
|  |         filter_min_total=args.filter_min_total, | ||||||
|  |         disease_chunk_size=args.disease_chunk_size, | ||||||
|  |         device=device, | ||||||
|  |         seed=seed, | ||||||
|  |         offset=offset, | ||||||
|  |         n_bootstrap=args.n_bootstrap, | ||||||
|  |         n_jobs=args.n_jobs, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main() | ||||||
							
								
								
									
										718
									
								
								evaluate_models.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										718
									
								
								evaluate_models.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										565
									
								
								models.py
									
									
									
									
									
								
							
							
						
						
									
										565
									
								
								models.py
									
									
									
									
									
								
							| @@ -1,7 +1,83 @@ | |||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from torch.nn import functional as F | from torch.nn import functional as F | ||||||
| from typing import Tuple | from typing import Tuple, Optional | ||||||
|  | import math | ||||||
|  |  | ||||||
|  | # ============================================================================= | ||||||
|  | # 1. Component Modules (Building Blocks) | ||||||
|  | # ============================================================================= | ||||||
|  | class CausalConv1d(nn.Module): | ||||||
|  |     def __init__(self, channels, kernel_size, groups=1): | ||||||
|  |         super().__init__() | ||||||
|  |         self.pad = kernel_size - 1 | ||||||
|  |         self.conv = nn.Conv1d( | ||||||
|  |             channels, channels, kernel_size, | ||||||
|  |             padding=0, groups=groups | ||||||
|  |         ) | ||||||
|  |     def forward(self, x):  # x: (B, C, L) | ||||||
|  |         x = F.pad(x, (self.pad, 0))  # pad only on the left to ensure causality | ||||||
|  |         x = x.contiguous() | ||||||
|  |         return self.conv(x) | ||||||
|  |      | ||||||
|  | class DepthwiseSeparableCausalConvBlock(nn.Module): | ||||||
|  |     def __init__(self, d_model, kernel_size=5, dropout=0.1): | ||||||
|  |         super().__init__() | ||||||
|  |         self.dw = CausalConv1d(d_model, kernel_size, groups=d_model)   # depthwise | ||||||
|  |         self.pw = nn.Conv1d(d_model, d_model, 1)                       # pointwise | ||||||
|  |         self.act = nn.GELU() | ||||||
|  |         self.ln = nn.LayerNorm(d_model) | ||||||
|  |         self.dropout = nn.Dropout(dropout) | ||||||
|  |  | ||||||
|  |     def forward(self, x):  # x: (B, L, D) | ||||||
|  |         y = x.transpose(1, 2).contiguous()  # (B, D, L) | ||||||
|  |         y = self.dw(y)                      # (B, D, L) | ||||||
|  |         y = self.pw(y.contiguous())         # (B, D, L) | ||||||
|  |         y = y.transpose(1, 2).contiguous()  # (B, L, D) | ||||||
|  |         y = self.act(y) | ||||||
|  |         y = self.dropout(y) | ||||||
|  |         return self.ln(x + y)              # residual connection + layer norm (LN) | ||||||
|  |      | ||||||
|  | class TimeFeatureProjector(nn.Module): | ||||||
|  |     """ | ||||||
|  |     Projects scalar time t and its increment Δt into d_model dimensions. | ||||||
|  |     Combines: linear-scale features + fixed-frequency sin/cos (Fourier time features). | ||||||
|  |     """ | ||||||
|  |     def __init__(self, d_model, fourier_dim=32, dt_clip=1e6): | ||||||
|  |         super().__init__() | ||||||
|  |         self.dt_clip = dt_clip | ||||||
|  |         self.scalar_proj = nn.Linear(2, d_model, bias=False)  # [t_scaled, dt_scaled] -> D | ||||||
|  |  | ||||||
|  |         # Predefine a set of logarithmically spaced frequencies (tune for your time units if needed) | ||||||
|  |         k = fourier_dim // 2 | ||||||
|  |         freqs = torch.logspace(-4, 2, steps=k) * 2 * math.pi  # frequency coverage ~1e-4 to 1e2 | ||||||
|  |         self.register_buffer("freqs", freqs, persistent=False) | ||||||
|  |  | ||||||
|  |         self.fourier_proj = nn.Linear(2*k, d_model, bias=False)  # [sin, cos] -> D | ||||||
|  |         self.gate = nn.Parameter(torch.zeros(1))                 # learnable gate to smoothly introduce Fourier features | ||||||
|  |         self.ln = nn.LayerNorm(d_model) | ||||||
|  |  | ||||||
|  |     def forward(self, t):  # t: (B, L)  continuous timestamps/steps | ||||||
|  |         # compute increments Δt and stabilize | ||||||
|  |         dt = t - F.pad(t, (1, 0), value=0.)[:, :-1] | ||||||
|  |         dt = torch.clamp(dt, min=0.)  # ensure non-negative | ||||||
|  |         # normalize/stabilize with log compression | ||||||
|  |         t_scaled  = torch.log1p(torch.clamp(torch.abs(t),  max=self.dt_clip)) | ||||||
|  |         dt_scaled = torch.log1p(torch.clamp(dt,            max=self.dt_clip)) | ||||||
|  |  | ||||||
|  |         scal = torch.stack([t_scaled, dt_scaled], dim=-1)  # (B, L, 2) | ||||||
|  |         scal_feat = self.scalar_proj(scal)                 # (B, L, D) | ||||||
|  |  | ||||||
|  |         # Fixed-frequency sin/cos to capture absolute/relative periodicity | ||||||
|  |         # If t is in steps, use directly; if in seconds, ensure units are consistent (e.g., divide by a time constant) | ||||||
|  |         # (B, L, K) | ||||||
|  |         wt = t[..., None] * self.freqs | ||||||
|  |         sincos = torch.cat([torch.sin(wt), torch.cos(wt)], dim=-1)  # (B, L, 2K) | ||||||
|  |         fourier_feat = self.fourier_proj(sincos)                    # (B, L, D) | ||||||
|  |  | ||||||
|  |         # gated fusion + layer norm | ||||||
|  |         h = scal_feat + torch.tanh(self.gate) * fourier_feat | ||||||
|  |         return self.ln(h)  # (B, L, D) | ||||||
|  |  | ||||||
| class Block(nn.Module): | class Block(nn.Module): | ||||||
|     """ an unassuming Transformer block """ |     """ an unassuming Transformer block """ | ||||||
| @@ -25,8 +101,10 @@ class Block(nn.Module): | |||||||
|     def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor: |     def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor: | ||||||
|         normed_x = self.ln_1(x) |         normed_x = self.ln_1(x) | ||||||
|          |          | ||||||
|         attn_mask = ~custom_mask |         # Build an additive attention mask to avoid backend issues with boolean masks on some GPUs | ||||||
|         attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0) |         # custom_mask: True means allowed, False means masked. We convert to 0 for allowed and -large for masked. | ||||||
|  |         mask_bool = (~custom_mask).repeat_interleave(self.n_head, dim=0)  # True where we want to mask | ||||||
|  |         attn_mask = mask_bool.to(dtype=normed_x.dtype) * (-1e9) | ||||||
|  |  | ||||||
|         attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False) |         attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False) | ||||||
|         x = x + self.resid_dropout(attn_output) |         x = x + self.resid_dropout(attn_output) | ||||||
| @@ -58,14 +136,8 @@ class AgeSinusoidalEncoding(nn.Module): | |||||||
|         self.embedding_dim = embedding_dim |         self.embedding_dim = embedding_dim | ||||||
|  |  | ||||||
|         # Pre-calculate the divisor term for the sinusoidal formula. |         # Pre-calculate the divisor term for the sinusoidal formula. | ||||||
|         # The formula for the divisor is 10000^(2i/D), where D is the |  | ||||||
|         # embedding_dim and i is the index for each pair of dimensions. |  | ||||||
|         # i ranges from 0 to D/2 - 1. |  | ||||||
|         i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) |         i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) | ||||||
|         divisor = torch.pow(10000, i / self.embedding_dim) |         divisor = torch.pow(10000, i / self.embedding_dim) | ||||||
|  |  | ||||||
|         # Register the divisor as a non-trainable buffer. This ensures it is |  | ||||||
|         # moved to the correct device (e.g., GPU) along with the model. |  | ||||||
|         self.register_buffer('divisor', divisor) |         self.register_buffer('divisor', divisor) | ||||||
|  |  | ||||||
|     def forward(self, t: torch.Tensor) -> torch.Tensor: |     def forward(self, t: torch.Tensor) -> torch.Tensor: | ||||||
| @@ -80,49 +152,204 @@ class AgeSinusoidalEncoding(nn.Module): | |||||||
|             torch.Tensor: The encoded age tensor of shape |             torch.Tensor: The encoded age tensor of shape | ||||||
|                 (batch_size, sequence_length, embedding_dim). |                 (batch_size, sequence_length, embedding_dim). | ||||||
|         """ |         """ | ||||||
|         # 1. Unit Conversion: Convert age from days to years. |  | ||||||
|         # We use 365.25 to account for leap years. |  | ||||||
|         t_years = t / 365.25 |         t_years = t / 365.25 | ||||||
|  |  | ||||||
|         # 2. Argument Calculation: Calculate the arguments for the sin/cos functions. |  | ||||||
|         # The shapes are broadcast to (B, L, D/2). |  | ||||||
|         # Input t_years: (B, L) -> unsqueezed to (B, L, 1) |  | ||||||
|         # Divisor: (D/2) -> viewed as (1, 1, D/2) |  | ||||||
|         args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1) |         args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1) | ||||||
|  |  | ||||||
|         # 3. Sinusoidal Application: Create the final output tensor. |  | ||||||
|         # Initialize an empty tensor to store the embeddings. |  | ||||||
|         output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device) |         output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device) | ||||||
|  |  | ||||||
|         # Assign cosine of the arguments to the even indices. |  | ||||||
|         output[:, :, 0::2] = torch.cos(args) |         output[:, :, 0::2] = torch.cos(args) | ||||||
|          |  | ||||||
|         # Assign sine of the arguments to the odd indices. |  | ||||||
|         output[:, :, 1::2] = torch.sin(args) |         output[:, :, 1::2] = torch.sin(args) | ||||||
|  |  | ||||||
|         return output |         return output | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LearnableAgeEncoding(nn.Module): | ||||||
|  |     """Combines fixed sinusoidal age encodings with a learnable MLP projection.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, base_dim: int, hidden_dim: Optional[int] = None, final_dim: Optional[int] = None, dropout: float = 0.0): | ||||||
|  |         super().__init__() | ||||||
|  |         self.base_dim = base_dim | ||||||
|  |         self.final_dim = final_dim or base_dim | ||||||
|  |  | ||||||
|  |         hidden_dim = hidden_dim or base_dim | ||||||
|  |         if hidden_dim <= 0: | ||||||
|  |             raise ValueError("hidden_dim must be a positive integer.") | ||||||
|  |         if self.final_dim <= 0: | ||||||
|  |             raise ValueError("final_dim must be a positive integer.") | ||||||
|  |  | ||||||
|  |         self.sinusoidal = AgeSinusoidalEncoding(base_dim) | ||||||
|  |  | ||||||
|  |         mlp_layers = [ | ||||||
|  |             nn.Linear(base_dim, hidden_dim), | ||||||
|  |             nn.GELU(), | ||||||
|  |         ] | ||||||
|  |         if dropout > 0.0: | ||||||
|  |             mlp_layers.append(nn.Dropout(dropout)) | ||||||
|  |         mlp_layers.append(nn.Linear(hidden_dim, self.final_dim)) | ||||||
|  |  | ||||||
|  |         self.mlp = nn.Sequential(*mlp_layers) | ||||||
|  |  | ||||||
|  |     def forward(self, t: torch.Tensor) -> torch.Tensor: | ||||||
|  |         sin_embed = self.sinusoidal(t) | ||||||
|  |         flat_embed = sin_embed.reshape(-1, self.base_dim) | ||||||
|  |         projected = self.mlp(flat_embed) | ||||||
|  |         return projected.reshape(*sin_embed.shape[:-1], self.final_dim) | ||||||
|  |  | ||||||
|  | class PiecewiseLinearEncoder(nn.Module): | ||||||
|  |     """ | ||||||
|  |     Encodes continuous variables using piecewise linear encoding. | ||||||
|  |  | ||||||
|  |     This module defines bins based on standard normal distribution quantiles, | ||||||
|  |     encodes an input by finding its bin, and calculates its position as a | ||||||
|  |     linear interpolation between boundaries. The result is projected to the | ||||||
|  |     final embedding dimension by a shared linear layer. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, num_bins: int, embedding_dim: int): | ||||||
|  |         """ | ||||||
|  |         Initializes the PiecewiseLinearEncoder module. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             num_bins (int): The number of bins for the encoding. | ||||||
|  |             embedding_dim (int): The dimensionality of the output embedding (D). | ||||||
|  |         """ | ||||||
|  |         super().__init__() | ||||||
|  |         if num_bins <= 0: | ||||||
|  |             raise ValueError("num_bins must be a positive integer.") | ||||||
|  |         self.num_bins = num_bins | ||||||
|  |         self.embedding_dim = embedding_dim | ||||||
|  |  | ||||||
|  |         if num_bins > 1: | ||||||
|  |             quantiles = torch.linspace(1.0 / num_bins, (num_bins - 1.0) / num_bins, num_bins - 1) | ||||||
|  |             normal_dist = torch.distributions.normal.Normal(0, 1) | ||||||
|  |             boundaries = normal_dist.icdf(quantiles) | ||||||
|  |         else: | ||||||
|  |             boundaries = torch.tensor([]) | ||||||
|  |  | ||||||
|  |         boundaries = torch.cat([ | ||||||
|  |             torch.tensor([float('-inf')]), | ||||||
|  |             boundaries, | ||||||
|  |             torch.tensor([float('inf')]) | ||||||
|  |         ]) | ||||||
|  |         self.register_buffer('boundaries', boundaries) | ||||||
|  |  | ||||||
|  |         self.linear = nn.Linear(num_bins, embedding_dim) | ||||||
|  |  | ||||||
|  |     def forward(self, x: torch.Tensor) -> torch.Tensor: | ||||||
|  |         """ | ||||||
|  |         Forward pass for the piecewise linear encoding. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             x (torch.Tensor): Input tensor of shape (*, N), where * is any | ||||||
|  |                 number of batch dimensions and N is the number of continuous | ||||||
|  |                 features. Assumed to be pre-scaled. | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             torch.Tensor: Encoded tensor of shape (*, N, D). | ||||||
|  |         """ | ||||||
|  |         original_shape = x.shape | ||||||
|  |         x = x.reshape(-1, original_shape[-1]) | ||||||
|  |  | ||||||
|  |         bin_indices = torch.searchsorted(self.boundaries, x, right=True) - 1 | ||||||
|  |         bin_indices = bin_indices.clamp(0, self.num_bins - 1) | ||||||
|  |  | ||||||
|  |         lower_bounds = self.boundaries[bin_indices] | ||||||
|  |         upper_bounds = self.boundaries[bin_indices + 1] | ||||||
|  |         delta = upper_bounds - lower_bounds + 1e-8 | ||||||
|  |  | ||||||
|  |         weight_upper = (x - lower_bounds) / delta | ||||||
|  |         weight_lower = 1.0 - weight_upper | ||||||
|  |  | ||||||
|  |         is_first_bin = (bin_indices == 0) | ||||||
|  |         is_last_bin = (bin_indices == self.num_bins - 1) | ||||||
|  |          | ||||||
|  |         weight_lower[is_first_bin] = 1.0 | ||||||
|  |         weight_upper[is_first_bin] = 0.0 | ||||||
|  |         weight_lower[is_last_bin] = 0.0 | ||||||
|  |         weight_upper[is_last_bin] = 1.0 | ||||||
|  |  | ||||||
|  |         encoded = torch.zeros(*x.shape, self.num_bins, device=x.device, dtype=x.dtype) | ||||||
|  |         encoded.scatter_(-1, bin_indices.unsqueeze(-1), weight_lower.unsqueeze(-1)) | ||||||
|  |          | ||||||
|  |         upper_indices = (bin_indices + 1).clamp(max=self.num_bins - 1) | ||||||
|  |         encoded.scatter_add_(-1, upper_indices.unsqueeze(-1), weight_upper.unsqueeze(-1)) | ||||||
|  |  | ||||||
|  |         encoded = encoded.view(*original_shape, self.num_bins) | ||||||
|  |         output = self.linear(encoded) | ||||||
|  |         return output | ||||||
|  |      | ||||||
|  | class TemporalConvEncoder(nn.Module): | ||||||
|  |     """ | ||||||
|  |         Inputs: | ||||||
|  |             x: (B, L)   - event/token ids | ||||||
|  |             t: (B, L)   - timestamps (real-valued) or step indices | ||||||
|  |         Output: | ||||||
|  |             h: (B, L, D) - can be fed directly as Transformer/GPT-2 inputs_embeds | ||||||
|  |     """ | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         vocab_size: int, | ||||||
|  |         d_model: int = 768, | ||||||
|  |         n_layers: int = 2, | ||||||
|  |         kernel_size: int = 5, | ||||||
|  |         dropout: float = 0.1, | ||||||
|  |         fourier_dim: int = 32, | ||||||
|  |         pad_id: int = 0 | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id) | ||||||
|  |         self.time_proj = TimeFeatureProjector(d_model, fourier_dim=fourier_dim) | ||||||
|  |         self.fuse = nn.Linear(2*d_model, d_model, bias=False)  # fuse token and time features | ||||||
|  |         self.ln_in = nn.LayerNorm(d_model) | ||||||
|  |         self.dropout = nn.Dropout(dropout) | ||||||
|  |  | ||||||
|  |         blocks = [] | ||||||
|  |         for _ in range(n_layers): | ||||||
|  |             blocks.append(DepthwiseSeparableCausalConvBlock(d_model, kernel_size, dropout)) | ||||||
|  |         self.blocks = nn.ModuleList(blocks) | ||||||
|  |  | ||||||
|  |     def forward(self, x, t, attention_mask=None): | ||||||
|  |         """ | ||||||
|  |         attention_mask: (B, L)  1=keep, 0=padding | ||||||
|  |         """ | ||||||
|  |         tok = self.token_emb(x)          # (B, L, D) | ||||||
|  |         tim = self.time_proj(t)          # (B, L, D) | ||||||
|  |  | ||||||
|  |         h = torch.cat([tok, tim], dim=-1)  # (B, L, 2D) | ||||||
|  |         h = self.fuse(h)                   # (B, L, D) | ||||||
|  |         h = self.ln_in(h) | ||||||
|  |         h = self.dropout(h) | ||||||
|  |  | ||||||
|  |         # Optional: zero-out padding positions before convolutions to avoid leakage | ||||||
|  |         if attention_mask is not None: | ||||||
|  |             h = h * attention_mask.unsqueeze(-1).type_as(h) | ||||||
|  |  | ||||||
|  |         # Multi-layer causal temporal convolutions (no look-ahead) to form relative position-aware context | ||||||
|  |         for blk in self.blocks: | ||||||
|  |             h = blk(h)  # (B, L, D) | ||||||
|  |  | ||||||
|  |             if attention_mask is not None: | ||||||
|  |                 h = h * attention_mask.unsqueeze(-1).type_as(h) | ||||||
|  |  | ||||||
|  |         return h  # (B, L, D), directly usable as attention layer input | ||||||
|  |  | ||||||
|  | # ============================================================================= | ||||||
|  | # 2. Main Model Architectures | ||||||
|  | # ============================================================================= | ||||||
|  |  | ||||||
| class TimeAwareGPT2(nn.Module): | class TimeAwareGPT2(nn.Module): | ||||||
|     """ |     """ | ||||||
|     A time-aware GPT-2 model with custom temporal features. |     A time-aware GPT-2 model with custom temporal features. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float): |     def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.token_pdrop = token_pdrop |         self.token_pdrop = token_pdrop | ||||||
|  |         self.ignore_tokens = ignore_tokens if ignore_tokens is not None else [] | ||||||
|  |  | ||||||
|         # Token and positional embeddings |  | ||||||
|         self.wte = nn.Embedding(vocab_size, n_embd) |         self.wte = nn.Embedding(vocab_size, n_embd) | ||||||
|         self.age_encoder = AgeSinusoidalEncoding(n_embd) |         self.age_encoder = AgeSinusoidalEncoding(n_embd) | ||||||
|         self.drop = nn.Dropout(pdrop) |         self.drop = nn.Dropout(pdrop) | ||||||
|  |  | ||||||
|         # Transformer blocks |  | ||||||
|         self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)]) |         self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)]) | ||||||
|          |  | ||||||
|         # Final layer norm and linear head |  | ||||||
|         self.ln_f = nn.LayerNorm(n_embd) |         self.ln_f = nn.LayerNorm(n_embd) | ||||||
|         self.head = nn.Linear(n_embd, vocab_size, bias=False) |         self.head = nn.Linear(n_embd, vocab_size, bias=False) | ||||||
|  |  | ||||||
|         self.n_embd = n_embd |         self.n_embd = n_embd | ||||||
|  |  | ||||||
|     def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor: |     def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor: | ||||||
| @@ -138,46 +365,30 @@ class TimeAwareGPT2(nn.Module): | |||||||
|         """ |         """ | ||||||
|         B, L = event_seq.size() |         B, L = event_seq.size() | ||||||
|  |  | ||||||
|         # 1. Get token embeddings |  | ||||||
|         token_embeddings = self.wte(event_seq) |         token_embeddings = self.wte(event_seq) | ||||||
|  |  | ||||||
|         # 2. Apply token dropout (only during training) |  | ||||||
|         if self.training and self.token_pdrop > 0: |         if self.training and self.token_pdrop > 0: | ||||||
|             # Create a mask to randomly zero out entire token embedding vectors |  | ||||||
|             drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop |             drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop | ||||||
|             token_embeddings[drop_mask] = 0.0 |             token_embeddings[drop_mask] = 0.0 | ||||||
|  |  | ||||||
|         # 3. Get positional embeddings from time sequence |  | ||||||
|         pos_embeddings = self.age_encoder(time_seq.float()) |         pos_embeddings = self.age_encoder(time_seq.float()) | ||||||
|  |  | ||||||
|         # 4. Combine embeddings and apply dropout |  | ||||||
|         x = self.drop(token_embeddings + pos_embeddings) |         x = self.drop(token_embeddings + pos_embeddings) | ||||||
|  |  | ||||||
|         # 5. Generate attention mask |         t_i = time_seq.unsqueeze(-1) | ||||||
|         # The attention mask combines two conditions: |         t_j = time_seq.unsqueeze(1) | ||||||
|         # a) Time-based causality: A token i can attend to a token j only if time_seq[j] <= time_seq[i]. |         time_mask = (t_j < t_i) | ||||||
|         # b) Padding mask: Do not attend to positions where the event token is 0. |         padding_mask = (event_seq != 0).unsqueeze(1) | ||||||
|  |  | ||||||
|         # a) Time-based causal mask |  | ||||||
|         t_i = time_seq.unsqueeze(-1)  # (B, L, 1) |  | ||||||
|         t_j = time_seq.unsqueeze(1)   # (B, 1, L) |  | ||||||
|         time_mask = (t_j <= t_i) |  | ||||||
|  |  | ||||||
|         # b) Padding mask (prevents attending to key positions that are padding) |  | ||||||
|         padding_mask = (event_seq != 0).unsqueeze(1) # Shape: (B, 1, L) |  | ||||||
|          |  | ||||||
|         # Combine the masks. A position (j) can be attended to by a query (i) only if |  | ||||||
|         # it's in the past (time_mask) AND it's not a padding token (padding_mask). |  | ||||||
|         combined_mask = time_mask & padding_mask |         combined_mask = time_mask & padding_mask | ||||||
|  |  | ||||||
|         # 6. Pass through transformer blocks |         is_row_all_zero = ~combined_mask.any(dim=-1) | ||||||
|  |         is_not_padding = (event_seq != 0) | ||||||
|  |         force_self_attention = is_row_all_zero & is_not_padding | ||||||
|  |         combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True | ||||||
|  |  | ||||||
|         for block in self.blocks: |         for block in self.blocks: | ||||||
|             x = block(x, custom_mask=combined_mask) |             x = block(x, custom_mask=combined_mask) | ||||||
|  |  | ||||||
|         # 7. Final layer norm and projection to vocab size |  | ||||||
|         x = self.ln_f(x) |         x = self.ln_f(x) | ||||||
|         logits = self.head(x) |         logits = self.head(x) | ||||||
|  |  | ||||||
|         return logits |         return logits | ||||||
|  |  | ||||||
|     def get_num_params(self) -> float: |     def get_num_params(self) -> float: | ||||||
| @@ -186,6 +397,222 @@ class TimeAwareGPT2(nn.Module): | |||||||
|         """ |         """ | ||||||
|         return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 |         return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 | ||||||
|  |  | ||||||
|  |     @torch.no_grad() | ||||||
|  |     def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None): | ||||||
|  |         """ | ||||||
|  |         Take a conditioning sequence of indices x (LongTensor of shape (b,t)) and complete | ||||||
|  |         the sequence max_new_tokens times, feeding the predictions back into the model each time. | ||||||
|  |         Most likely you'll want to make sure to be in model.eval() mode of operation for this. | ||||||
|  |         """ | ||||||
|  |         self.eval() | ||||||
|  |  | ||||||
|  |         if termination_tokens is None: | ||||||
|  |             termination_tokens = [1269] | ||||||
|  |          | ||||||
|  |         termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device) | ||||||
|  |         mask_time = -10000 | ||||||
|  |  | ||||||
|  |         for _ in range(max_new_tokens): | ||||||
|  |             logits = self(x, t) | ||||||
|  |             logits = logits[:, -1, :] | ||||||
|  |              | ||||||
|  |             if self.ignore_tokens: | ||||||
|  |                 logits[:, self.ignore_tokens] = -torch.inf | ||||||
|  |  | ||||||
|  |             if no_repeat: | ||||||
|  |                 fill = x.clone() | ||||||
|  |                 fill[fill == 1] = 0 | ||||||
|  |                 logits = logits.scatter(1, fill, -torch.inf) | ||||||
|  |              | ||||||
|  |             t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80) | ||||||
|  |             t_next_val, idx_next = t_next_dist.min(1) | ||||||
|  |              | ||||||
|  |             idx_next = idx_next.unsqueeze(1) | ||||||
|  |             age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1) | ||||||
|  |              | ||||||
|  |             x = torch.cat((x, idx_next), dim=1) | ||||||
|  |             t = torch.cat((t, age_next), dim=1) | ||||||
|  |              | ||||||
|  |             if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all(): | ||||||
|  |                 break | ||||||
|  |          | ||||||
|  |         pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age) | ||||||
|  |  | ||||||
|  |         final_logits = self(x, t) | ||||||
|  |         x[pad] = 0 | ||||||
|  |         t[pad] = mask_time | ||||||
|  |  | ||||||
|  |         if no_repeat: | ||||||
|  |             fill = x.clone() | ||||||
|  |             fill[fill == 1] = 0 | ||||||
|  |             final_logits = torch.stack([final_logits[:,j].scatter(1, fill[:,:j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0,1) | ||||||
|  |  | ||||||
|  |         return x, t, final_logits | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TimeAwareGPT2Learnable(TimeAwareGPT2): | ||||||
|  |     """Variant of TimeAwareGPT2 that uses LearnableAgeEncoding for temporal features.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.age_encoder = LearnableAgeEncoding( | ||||||
|  |             base_dim=self.n_embd, | ||||||
|  |             hidden_dim=2 * self.n_embd, | ||||||
|  |             final_dim=self.n_embd, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ============================================================================= | ||||||
|  | # 3. Loss Function | ||||||
|  | # ============================================================================= | ||||||
|  |  | ||||||
|  | class TimeAwareGPT2TemporalConv(nn.Module): | ||||||
|  |     """ | ||||||
|  |     A TimeAware GPT-2 variant that uses TemporalConvEncoder to encode | ||||||
|  |     event and time sequences before Transformer attention blocks. | ||||||
|  |  | ||||||
|  |     Inputs: | ||||||
|  |       - event_seq: (B, L) token ids (0 treated as padding) | ||||||
|  |       - time_seq:  (B, L) timestamps or step indices (float) | ||||||
|  |  | ||||||
|  |     Output: | ||||||
|  |       - logits: (B, L, vocab_size) | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         vocab_size: int, | ||||||
|  |         n_embd: int, | ||||||
|  |         n_layer: int, | ||||||
|  |         n_head: int, | ||||||
|  |         pdrop: float, | ||||||
|  |         token_pdrop: float, | ||||||
|  |         ignore_tokens: Optional[list[int]] = None, | ||||||
|  |         *, | ||||||
|  |         conv_layers: int = 2, | ||||||
|  |         kernel_size: int = 5, | ||||||
|  |         conv_dropout: float = 0.1, | ||||||
|  |         fourier_dim: int = 32, | ||||||
|  |         pad_id: int = 0, | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.token_pdrop = token_pdrop | ||||||
|  |         self.ignore_tokens = ignore_tokens if ignore_tokens is not None else [] | ||||||
|  |         self.n_embd = n_embd | ||||||
|  |  | ||||||
|  |         # Temporal convolutional encoder to build inputs_embeds | ||||||
|  |         self.temporal_encoder = TemporalConvEncoder( | ||||||
|  |             vocab_size=vocab_size, | ||||||
|  |             d_model=n_embd, | ||||||
|  |             n_layers=conv_layers, | ||||||
|  |             kernel_size=kernel_size, | ||||||
|  |             dropout=conv_dropout, | ||||||
|  |             fourier_dim=fourier_dim, | ||||||
|  |             pad_id=pad_id, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # Transformer stack on top of temporal features | ||||||
|  |         self.drop = nn.Dropout(pdrop) | ||||||
|  |         self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)]) | ||||||
|  |         self.ln_f = nn.LayerNorm(n_embd) | ||||||
|  |         self.head = nn.Linear(n_embd, vocab_size, bias=False) | ||||||
|  |  | ||||||
|  |     def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor: | ||||||
|  |         B, L = event_seq.size() | ||||||
|  |  | ||||||
|  |         # Encoder features as inputs_embeds | ||||||
|  |         attention_mask = (event_seq != 0) | ||||||
|  |         x = self.temporal_encoder(event_seq, time_seq.float(), attention_mask=attention_mask) | ||||||
|  |         x = self.drop(x) | ||||||
|  |  | ||||||
|  |         # Time-aware causal mask as before | ||||||
|  |         t_i = time_seq.unsqueeze(-1) | ||||||
|  |         t_j = time_seq.unsqueeze(1) | ||||||
|  |         time_mask = (t_j < t_i) | ||||||
|  |         padding_mask = (event_seq != 0).unsqueeze(1) | ||||||
|  |         combined_mask = time_mask & padding_mask | ||||||
|  |  | ||||||
|  |         # Ensure at least self-attention on non-padding rows | ||||||
|  |         is_row_all_zero = ~combined_mask.any(dim=-1) | ||||||
|  |         is_not_padding = (event_seq != 0) | ||||||
|  |         force_self_attention = is_row_all_zero & is_not_padding | ||||||
|  |         combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True | ||||||
|  |  | ||||||
|  |         for block in self.blocks: | ||||||
|  |             x = block(x, custom_mask=combined_mask) | ||||||
|  |  | ||||||
|  |         x = self.ln_f(x) | ||||||
|  |         logits = self.head(x) | ||||||
|  |         return logits | ||||||
|  |  | ||||||
|  |     def get_num_params(self) -> float: | ||||||
|  |         return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 | ||||||
|  |  | ||||||
|  |     @torch.no_grad() | ||||||
|  |     def generate( | ||||||
|  |         self, | ||||||
|  |         x: torch.Tensor, | ||||||
|  |         t: torch.Tensor, | ||||||
|  |         max_new_tokens: int = 100, | ||||||
|  |         max_age: float = 85 * 365.25, | ||||||
|  |         no_repeat: bool = True, | ||||||
|  |         termination_tokens: Optional[list[int]] = None, | ||||||
|  |         top_k: Optional[int] = None, | ||||||
|  |     ): | ||||||
|  |         """Greedy-like generation with optional no-repeat and termination tokens.""" | ||||||
|  |         self.eval() | ||||||
|  |  | ||||||
|  |         if termination_tokens is None: | ||||||
|  |             termination_tokens = [1269] | ||||||
|  |  | ||||||
|  |         termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device) | ||||||
|  |         mask_time = -10000 | ||||||
|  |  | ||||||
|  |         for _ in range(max_new_tokens): | ||||||
|  |             logits = self(x, t) | ||||||
|  |             logits = logits[:, -1, :] | ||||||
|  |  | ||||||
|  |             if self.ignore_tokens: | ||||||
|  |                 logits[:, self.ignore_tokens] = -torch.inf | ||||||
|  |  | ||||||
|  |             if no_repeat: | ||||||
|  |                 fill = x.clone() | ||||||
|  |                 fill[fill == 1] = 0 | ||||||
|  |                 logits = logits.scatter(1, fill, -torch.inf) | ||||||
|  |  | ||||||
|  |             # Sample a time increment proxy as in original implementation | ||||||
|  |             t_next_dist = torch.clamp( | ||||||
|  |                 -torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), | ||||||
|  |                 min=0, | ||||||
|  |                 max=365 * 80, | ||||||
|  |             ) | ||||||
|  |             t_next_val, idx_next = t_next_dist.min(1) | ||||||
|  |  | ||||||
|  |             idx_next = idx_next.unsqueeze(1) | ||||||
|  |             age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1) | ||||||
|  |  | ||||||
|  |             x = torch.cat((x, idx_next), dim=1) | ||||||
|  |             t = torch.cat((t, age_next), dim=1) | ||||||
|  |  | ||||||
|  |             if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all(): | ||||||
|  |                 break | ||||||
|  |  | ||||||
|  |         pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age) | ||||||
|  |  | ||||||
|  |         final_logits = self(x, t) | ||||||
|  |         x[pad] = 0 | ||||||
|  |         t[pad] = mask_time | ||||||
|  |  | ||||||
|  |         if no_repeat: | ||||||
|  |             fill = x.clone() | ||||||
|  |             fill[fill == 1] = 0 | ||||||
|  |             final_logits = torch.stack( | ||||||
|  |                 [final_logits[:, j].scatter(1, fill[:, : j + 1], -torch.inf) for j in range(fill.shape[1])] | ||||||
|  |             ).transpose(0, 1) | ||||||
|  |  | ||||||
|  |         return x, t, final_logits | ||||||
|  |  | ||||||
| class CombinedLoss(nn.Module): | class CombinedLoss(nn.Module): | ||||||
|     """ |     """ | ||||||
|     Computes a two-part loss: a standard cross-entropy loss for event type |     Computes a two-part loss: a standard cross-entropy loss for event type | ||||||
| @@ -215,35 +642,23 @@ class CombinedLoss(nn.Module): | |||||||
|         Returns: |         Returns: | ||||||
|             A tuple containing the two scalar loss tensors: (loss_ce, loss_survival). |             A tuple containing the two scalar loss tensors: (loss_ce, loss_survival). | ||||||
|         """ |         """ | ||||||
|         # 1. Create a mask to filter out ignored token IDs from loss calculation. |  | ||||||
|         # An element is True if the corresponding label in x is NOT in the ignored list. |  | ||||||
|         mask = torch.ones_like(x, dtype=torch.bool) |         mask = torch.ones_like(x, dtype=torch.bool) | ||||||
|         for token_id in self.ignored_token_ids: |         for token_id in self.ignored_token_ids: | ||||||
|             mask = mask & (x != token_id) |             mask = mask & (x != token_id) | ||||||
|  |  | ||||||
|         # If the mask is all False (all tokens are ignored), return zero for both losses. |  | ||||||
|         if not mask.any(): |         if not mask.any(): | ||||||
|             return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device) |             return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device) | ||||||
|  |  | ||||||
|         # 2. Part 1: Cross-Entropy Loss (loss_ce) |  | ||||||
|         # Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy. |  | ||||||
|         logits_for_ce = logits.permute(0, 2, 1) |         logits_for_ce = logits.permute(0, 2, 1) | ||||||
|          |  | ||||||
|         # Calculate per-element loss without reduction. |  | ||||||
|         per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') |         per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') | ||||||
|          |  | ||||||
|         # Apply the mask and compute the mean of valid elements. |  | ||||||
|         loss_ce = per_element_ce[mask].mean() |         loss_ce = per_element_ce[mask].mean() | ||||||
|  |  | ||||||
|         # 3. Part 2: Survival Loss (loss_survival) |         # Survival loss based on exponential log-likelihood | ||||||
|         # Calculate event intensity (lambda) as the sum of exponentiated logits. |         t_min = 0.1 | ||||||
|         intensity = torch.sum(torch.exp(logits), dim=2) |         lse = torch.logsumexp(logits, dim=-1) | ||||||
|  |         lse = -torch.log(torch.exp(-lse) + t_min) | ||||||
|         # Calculate per-element survival loss (negative log-likelihood of exponential dist). |         ldt = -torch.log(t + t_min) | ||||||
|         # We add a small epsilon for numerical stability with the log. |         loss_dt = -(lse - torch.exp(lse - ldt)) | ||||||
|         per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t) |         loss_survival = loss_dt[mask].mean() | ||||||
|          |  | ||||||
|         # Apply the mask and compute the mean of valid elements. |  | ||||||
|         loss_survival = per_element_survival[mask].mean() |  | ||||||
|  |  | ||||||
|         return loss_ce, loss_survival |         return loss_ce, loss_survival | ||||||
|   | |||||||
							
								
								
									
										160
									
								
								plot_auc_boxplots_by_chapter.R
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										160
									
								
								plot_auc_boxplots_by_chapter.R
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,160 @@ | |||||||
|  | # Compare AUC distributions between models by ICD-10 chapter (1-year and no-gap) | ||||||
|  | # Usage: | ||||||
|  | #   Rscript plot_auc_boxplots_by_chapter.R [one_year_csv] [no_gap_csv] [output_dir] | ||||||
|  | # Defaults: | ||||||
|  | #   one_year_csv = "model_comparison_auc_1year.csv" | ||||||
|  | #   no_gap_csv   = "model_comparison_auc_no_gap.csv" | ||||||
|  | #   output_dir   = current working directory (".") | ||||||
|  |  | ||||||
|  | suppressPackageStartupMessages({ | ||||||
|  |   library(ggplot2) | ||||||
|  |   library(cowplot) | ||||||
|  | }) | ||||||
|  |  | ||||||
|  | args <- commandArgs(trailingOnly = TRUE) | ||||||
|  | one_year_csv <- if (length(args) >= 1) args[1] else "model_comparison_auc_1year.csv" | ||||||
|  | no_gap_csv   <- if (length(args) >= 2) args[2] else "model_comparison_auc_no_gap.csv" | ||||||
|  | out_dir      <- if (length(args) >= 3) args[3] else "." | ||||||
|  | orientation  <- if (length(args) >= 4) tolower(args[4]) else "vertical"  # "horizontal" (flipped) or "vertical" | ||||||
|  |  | ||||||
|  | if (!dir.exists(out_dir)) { | ||||||
|  |   dir.create(out_dir, recursive = TRUE, showWarnings = FALSE) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | read_csv_safe <- function(path) { | ||||||
|  |   tryCatch({ | ||||||
|  |     read.csv(path, check.names = FALSE) | ||||||
|  |   }, error = function(e) { | ||||||
|  |     stop(sprintf("Failed to read CSV at '%s': %s", path, e$message)) | ||||||
|  |   }) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Determine a chapter column name robustly | ||||||
|  | get_chapter_col <- function(df) { | ||||||
|  |   candidates <- c("ICD-10 Chapter (short)", "ICD-10 Chapter", "ICD10_chapter", "chapter", "ICD_chapter") | ||||||
|  |   for (c in candidates) { | ||||||
|  |     if (c %in% names(df)) return(c) | ||||||
|  |   } | ||||||
|  |   return(NA_character_) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Compute a deterministic chapter ordering using the ICD-10 chapter numeral prefix | ||||||
|  | # e.g., "I. Infectious Diseases", "II. Neoplasms", ..., "XVII. ...", with a fallback for "Death" and unknowns | ||||||
|  | compute_chapter_levels <- function(chapters) { | ||||||
|  |   ch <- as.character(chapters) | ||||||
|  |   roman_levels <- c( | ||||||
|  |     "I","II","III","IV","V","VI","VII","VIII","IX","X", | ||||||
|  |     "XI","XII","XIII","XIV","XV","XVI","XVII","XVIII","XIX","XX" | ||||||
|  |   ) | ||||||
|  |   roman_map <- setNames(seq_along(roman_levels), roman_levels) | ||||||
|  |   # Extract leading Roman numeral before a dot, like "XVI." -> "XVI" | ||||||
|  |   roman <- toupper(gsub("^\\s*([IVXLCDM]+)\\..*$", "\\1", ch)) | ||||||
|  |   idx <- rep(NA_integer_, length(ch)) | ||||||
|  |   hit <- roman %in% names(roman_map) | ||||||
|  |   idx[hit] <- roman_map[roman[hit]] | ||||||
|  |   # Special-case Death at the end | ||||||
|  |   idx[grepl("^\\s*Death\\b", ch, ignore.case = TRUE)] <- 99L | ||||||
|  |   # Unknowns to the very end | ||||||
|  |   idx[is.na(idx)] <- 100L | ||||||
|  |   # Order chapters by idx, stable within same idx by appearance | ||||||
|  |   o <- order(idx, match(ch, unique(ch))) | ||||||
|  |   unique(ch[o]) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Build long-format data.frame with columns: chapter, model, auc | ||||||
|  | # It will include any of the known model columns that exist in the input df | ||||||
|  | build_long_df <- function(df) { | ||||||
|  |   model_cols <- c( | ||||||
|  |     auc_120     = "auc_120", | ||||||
|  |     auc_120_l   = "auc_120_l", | ||||||
|  |     auc_256     = "auc_256", | ||||||
|  |     auc_256_l   = "auc_256_l", | ||||||
|  |     auc_delphi  = "auc_delphi" | ||||||
|  |   ) | ||||||
|  |   pretty_names <- c( | ||||||
|  |     auc_120    = "GPT-2 120", | ||||||
|  |     auc_120_l  = "GPT-2 120_L", | ||||||
|  |     auc_256    = "GPT-2 256", | ||||||
|  |     auc_256_l  = "GPT-2 256_L", | ||||||
|  |     auc_delphi = "Delphi" | ||||||
|  |   ) | ||||||
|  |   present <- model_cols[names(model_cols) %in% names(df)] | ||||||
|  |   if (length(present) == 0) stop("No known AUC columns found in input data.") | ||||||
|  |   chap_col <- get_chapter_col(df) | ||||||
|  |   if (is.na(chap_col)) { | ||||||
|  |     warning("No chapter column found; using a single 'All' group.") | ||||||
|  |     chapters <- rep("All", nrow(df)) | ||||||
|  |   } else { | ||||||
|  |     chapters <- df[[chap_col]] | ||||||
|  |   } | ||||||
|  |   out_list <- list() | ||||||
|  |   for (key in names(model_cols)) { | ||||||
|  |     col <- model_cols[[key]] | ||||||
|  |     if (col %in% names(df)) { | ||||||
|  |       out_list[[length(out_list) + 1]] <- data.frame( | ||||||
|  |         chapter = chapters, | ||||||
|  |         model   = pretty_names[[key]], | ||||||
|  |         auc     = as.numeric(df[[col]]), | ||||||
|  |         stringsAsFactors = FALSE | ||||||
|  |       ) | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   long_df <- do.call(rbind, out_list) | ||||||
|  |   # Filter out-of-range or NA | ||||||
|  |   long_df <- long_df[is.finite(long_df$auc) & long_df$auc >= 0 & long_df$auc <= 1, ] | ||||||
|  |   long_df$model <- factor(long_df$model, levels = c("GPT-2 120", "GPT-2 120_L", "GPT-2 256", "GPT-2 256_L", "Delphi")) | ||||||
|  |   return(long_df) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Make the boxplot grouped by chapter | ||||||
|  | make_boxplot <- function(long_df, title_text, flip = TRUE) { | ||||||
|  |   # Order chapters by their ICD-10 chapter number prefix (Roman numerals) | ||||||
|  |   chap_levels <- compute_chapter_levels(long_df$chapter) | ||||||
|  |   long_df$chapter <- factor(long_df$chapter, levels = chap_levels) | ||||||
|  |  | ||||||
|  |   p <- ggplot(long_df, aes(x = chapter, y = auc, fill = model)) + | ||||||
|  |     geom_boxplot(outlier.shape = 19, outlier.size = 0.7, width = 0.75, alpha = 0.95) + | ||||||
|  |     scale_y_continuous(limits = c(0.3, 1.0), breaks = seq(0.3, 1.0, by = 0.1)) + | ||||||
|  |     labs(title = title_text, x = "ICD-10 Chapter", y = "AUC") + | ||||||
|  |     theme_minimal(base_size = 11) + | ||||||
|  |     theme( | ||||||
|  |       plot.title = element_text(hjust = 0.5), | ||||||
|  |       panel.grid.minor = element_blank(), | ||||||
|  |       legend.position = "bottom" | ||||||
|  |     ) + | ||||||
|  |     guides(fill = guide_legend(nrow = 1)) | ||||||
|  |   if (flip) { | ||||||
|  |     p <- p + coord_flip() | ||||||
|  |   } else { | ||||||
|  |     # For vertical plots, angle x-axis labels for readability | ||||||
|  |     p <- p + theme(axis.text.x = element_text(angle = 45, hjust = 1)) | ||||||
|  |   } | ||||||
|  |   p | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Build plots for 1-year and no-gap | ||||||
|  | one_year_df <- read_csv_safe(one_year_csv) | ||||||
|  | no_gap_df   <- read_csv_safe(no_gap_csv) | ||||||
|  |  | ||||||
|  | one_year_long <- build_long_df(one_year_df) | ||||||
|  | no_gap_long   <- build_long_df(no_gap_df) | ||||||
|  |  | ||||||
|  | flip_flag <- ifelse(orientation %in% c("horizontal", "flip", "flipped"), TRUE, FALSE) | ||||||
|  |  | ||||||
|  | p1 <- make_boxplot(one_year_long, "AUC by ICD-10 Chapter (1-year gap)", flip = flip_flag) | ||||||
|  | p2 <- make_boxplot(no_gap_long,   "AUC by ICD-10 Chapter (no gap)",   flip = flip_flag) | ||||||
|  |  | ||||||
|  | # Save individual plots | ||||||
|  | out_1year <- file.path(out_dir, "auc_boxplot_by_chapter_1year.png") | ||||||
|  | ggsave(out_1year, p1, width = 12, height = 10, dpi = 300, bg = "white") | ||||||
|  | cat(sprintf("Saved: %s\n", out_1year)) | ||||||
|  |  | ||||||
|  | out_nogap <- file.path(out_dir, "auc_boxplot_by_chapter_no_gap.png") | ||||||
|  | ggsave(out_nogap, p2, width = 12, height = 10, dpi = 300, bg = "white") | ||||||
|  | cat(sprintf("Saved: %s\n", out_nogap)) | ||||||
|  |  | ||||||
|  | # Save a side-by-side grid for quick comparison | ||||||
|  | grid <- plot_grid(p1, p2, labels = c("A", "B"), ncol = 2, align = "hv") | ||||||
|  | out_grid <- file.path(out_dir, "auc_boxplot_by_chapter_grid.png") | ||||||
|  | ggsave(out_grid, grid, width = 18, height = 10, dpi = 250, bg = "white") | ||||||
|  | cat(sprintf("Saved grid: %s\n", out_grid)) | ||||||
							
								
								
									
										125
									
								
								plot_model_comparison_1year.R
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								plot_model_comparison_1year.R
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,125 @@ | |||||||
|  | # Plot AUC comparisons (1-year gap) between models and Delphi using ggplot2 | ||||||
|  | # Usage: | ||||||
|  | #   Rscript plot_model_comparison_1year.R [path_to_csv] [output_dir] | ||||||
|  | # Defaults: | ||||||
|  | #   path_to_csv = "model_comparison_auc_1year.csv" | ||||||
|  | #   output_dir  = current working directory (".") | ||||||
|  |  | ||||||
|  | suppressPackageStartupMessages({ | ||||||
|  |   library(ggplot2) | ||||||
|  |   library(cowplot) | ||||||
|  | }) | ||||||
|  |  | ||||||
|  | args <- commandArgs(trailingOnly = TRUE) | ||||||
|  | csv_path <- if (length(args) >= 1) args[1] else "model_comparison_auc_1year.csv" | ||||||
|  | out_dir  <- if (length(args) >= 2) args[2] else "." | ||||||
|  |  | ||||||
|  | if (!dir.exists(out_dir)) { | ||||||
|  |   dir.create(out_dir, recursive = TRUE, showWarnings = FALSE) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Read data | ||||||
|  | # Expect columns including: auc_delphi, auc_256, auc_120, Colour (hex color), name, etc. | ||||||
|  | df <- tryCatch({ | ||||||
|  |   read.csv(csv_path, check.names = FALSE) | ||||||
|  | }, error = function(e) { | ||||||
|  |   stop(sprintf("Failed to read CSV at '%s': %s", csv_path, e$message)) | ||||||
|  | }) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Helper to compare any two AUC columns (x vs y) | ||||||
|  | make_xy_plot <- function(data, x_col, y_col, title_text, x_label, y_label) { | ||||||
|  |   ggplot(data, aes(x = .data[[x_col]], y = .data[[y_col]])) + | ||||||
|  |     geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) + | ||||||
|  |     geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) + | ||||||
|  |     geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) + | ||||||
|  |     geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) + | ||||||
|  |     scale_fill_identity() + | ||||||
|  |     coord_cartesian(xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) + | ||||||
|  |     coord_fixed(ratio = 1) + | ||||||
|  |     labs(title = title_text, x = x_label, y = y_label) + | ||||||
|  |     theme_minimal(base_size = 10) + | ||||||
|  |     theme( | ||||||
|  |       plot.title = element_text(hjust = 0.5), | ||||||
|  |       panel.grid.minor = element_blank() | ||||||
|  |     ) | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Helper to compare model AUC vs Delphi AUC (x = auc_delphi) | ||||||
|  | make_delphi_plot <- function(data, y_col, title_text, y_label) { | ||||||
|  |   ggplot(data, aes(x = auc_delphi, y = .data[[y_col]])) + | ||||||
|  |     geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) + | ||||||
|  |     geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) + | ||||||
|  |     geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) + | ||||||
|  |     geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) + | ||||||
|  |     scale_fill_identity() + | ||||||
|  |     coord_cartesian(xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) + | ||||||
|  |     coord_fixed(ratio = 1) + | ||||||
|  |     labs(title = title_text, x = "AUC_Delphi", y = y_label) + | ||||||
|  |     theme_minimal(base_size = 10) + | ||||||
|  |     theme( | ||||||
|  |       plot.title = element_text(hjust = 0.5), | ||||||
|  |       panel.grid.minor = element_blank() | ||||||
|  |     ) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Placeholder empty plot if a required column is missing | ||||||
|  | empty_plot <- function(msg) { | ||||||
|  |   ggplot() + theme_void() + ggtitle(msg) + theme(plot.title = element_text(hjust = 0.5)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Plot: AUC_120 vs AUC_120_L (1 year gap) | ||||||
|  | if (!all(c("auc_120", "auc_120_l") %in% names(df))) { | ||||||
|  |   warning("Columns 'auc_120' and/or 'auc_120_l' not found in CSV; skipping AUC_120 vs AUC_120_L plot.") | ||||||
|  | } else { | ||||||
|  |   p120_vs_120l <- make_xy_plot( | ||||||
|  |     data = df, | ||||||
|  |     x_col = "auc_120", | ||||||
|  |     y_col = "auc_120_l", | ||||||
|  |     title_text = "AUC_120 vs AUC_120_L 1 year gap", | ||||||
|  |     x_label = "AUC_120", | ||||||
|  |     y_label = "AUC_120_L" | ||||||
|  |   ) | ||||||
|  |   out_120_vs_120l <- file.path(out_dir, "model_comparison_auc_120_vs_120_l_1year.png") | ||||||
|  |   ggsave(filename = out_120_vs_120l, plot = p120_vs_120l, width = 7, height = 4, dpi = 600, bg = "white") | ||||||
|  |   cat(sprintf("Saved: %s\n", out_120_vs_120l)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Plot: AUC_256 vs AUC_256_L (1 year gap) | ||||||
|  | if (!all(c("auc_256", "auc_256_l") %in% names(df))) { | ||||||
|  |   warning("Columns 'auc_256' and/or 'auc_256_l' not found in CSV; skipping AUC_256 vs AUC_256_L plot.") | ||||||
|  | } else { | ||||||
|  |   p256_vs_256l <- make_xy_plot( | ||||||
|  |     data = df, | ||||||
|  |     x_col = "auc_256", | ||||||
|  |     y_col = "auc_256_l", | ||||||
|  |     title_text = "AUC_256 vs AUC_256_L 1 year gap", | ||||||
|  |     x_label = "AUC_256", | ||||||
|  |     y_label = "AUC_256_L" | ||||||
|  |   ) | ||||||
|  |   out_256_vs_256l <- file.path(out_dir, "model_comparison_auc_256_vs_256_l_1year.png") | ||||||
|  |   ggsave(filename = out_256_vs_256l, plot = p256_vs_256l, width = 7, height = 4, dpi = 600, bg = "white") | ||||||
|  |   cat(sprintf("Saved: %s\n", out_256_vs_256l)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # ---- Combined 2x2 grid: (auc_120 vs delphi), (auc_256 vs delphi), (auc_120_l vs delphi), (auc_256_l vs delphi) ---- | ||||||
|  |  | ||||||
|  | has_cols <- function(cols) all(cols %in% names(df)) | ||||||
|  |  | ||||||
|  | p_120_vs_delphi   <- if (has_cols(c("auc_delphi", "auc_120")))   make_delphi_plot(df, "auc_120",   "AUC_120 vs Delphi (1 year)",   "AUC_120")   else empty_plot("Missing auc_120 or auc_delphi") | ||||||
|  | p_256_vs_delphi   <- if (has_cols(c("auc_delphi", "auc_256")))   make_delphi_plot(df, "auc_256",   "AUC_256 vs Delphi (1 year)",   "AUC_256")   else empty_plot("Missing auc_256 or auc_delphi") | ||||||
|  | p_120l_vs_delphi  <- if (has_cols(c("auc_delphi", "auc_120_l"))) make_delphi_plot(df, "auc_120_l", "AUC_120_L vs Delphi (1 year)", "AUC_120_L") else empty_plot("Missing auc_120_l or auc_delphi") | ||||||
|  | p_256l_vs_delphi  <- if (has_cols(c("auc_delphi", "auc_256_l"))) make_delphi_plot(df, "auc_256_l", "AUC_256_L vs Delphi (1 year)", "AUC_256_L") else empty_plot("Missing auc_256_l or auc_delphi") | ||||||
|  |  | ||||||
|  | grid_plot <- plot_grid( | ||||||
|  |   p_120_vs_delphi, p_256_vs_delphi, | ||||||
|  |   p_120l_vs_delphi, p_256l_vs_delphi, | ||||||
|  |   labels = c("A", "B", "C", "D"), | ||||||
|  |   ncol = 2, align = "hv" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | out_grid <- file.path(out_dir, "model_comparison_auc_vs_delphi_1year_grid.png") | ||||||
|  | ggsave(filename = out_grid, plot = grid_plot, width = 12, height = 8, dpi = 300, bg = "white") | ||||||
|  | cat(sprintf("Saved grid: %s\n", out_grid)) | ||||||
							
								
								
									
										128
									
								
								plot_model_comparison_no_gap.R
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								plot_model_comparison_no_gap.R
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,128 @@ | |||||||
|  | # Plot AUC comparisons (no gap) between models and Delphi using ggplot2 | ||||||
|  | # Usage: | ||||||
|  | #   Rscript plot_model_comparison_no_gap.R [path_to_csv] [output_dir] | ||||||
|  | # Defaults: | ||||||
|  | #   path_to_csv = "model_comparison_auc_no_gap.csv" | ||||||
|  | #   output_dir  = current working directory (".") | ||||||
|  |  | ||||||
|  | suppressPackageStartupMessages({ | ||||||
|  |   library(ggplot2) | ||||||
|  |   library(cowplot) | ||||||
|  | }) | ||||||
|  |  | ||||||
|  | args <- commandArgs(trailingOnly = TRUE) | ||||||
|  | csv_path <- if (length(args) >= 1) args[1] else "model_comparison_auc_no_gap.csv" | ||||||
|  | out_dir  <- if (length(args) >= 2) args[2] else "." | ||||||
|  |  | ||||||
|  | if (!dir.exists(out_dir)) { | ||||||
|  |   dir.create(out_dir, recursive = TRUE, showWarnings = FALSE) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Read data | ||||||
|  | # Expect columns including: auc_delphi, auc_256, auc_120, auc_256_l, auc_120_l, Colour (hex color), name, etc. | ||||||
|  | df <- tryCatch({ | ||||||
|  |   read.csv(csv_path, check.names = FALSE) | ||||||
|  | }, error = function(e) { | ||||||
|  |   stop(sprintf("Failed to read CSV at '%s': %s", csv_path, e$message)) | ||||||
|  | }) | ||||||
|  |  | ||||||
|  | # Helper to compare any two AUC columns (x vs y) | ||||||
|  | make_xy_plot <- function(data, x_col, y_col, title_text, x_label, y_label) { | ||||||
|  |   ggplot(data, aes(x = .data[[x_col]], y = .data[[y_col]])) + | ||||||
|  |     geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) + | ||||||
|  |     geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) + | ||||||
|  |     geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) + | ||||||
|  |     geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) + | ||||||
|  |     scale_fill_identity() + | ||||||
|  |     coord_cartesian(xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) + | ||||||
|  |     coord_fixed(ratio = 1) + | ||||||
|  |     labs(title = title_text, x = x_label, y = y_label) + | ||||||
|  |     theme_minimal(base_size = 10) + | ||||||
|  |     theme( | ||||||
|  |       plot.title = element_text(hjust = 0.5), | ||||||
|  |       panel.grid.minor = element_blank() | ||||||
|  |     ) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Helper to compare model AUC vs Delphi AUC (x = auc_delphi) | ||||||
|  | make_delphi_plot <- function(data, y_col, title_text, y_label) { | ||||||
|  |   ggplot(data, aes(x = auc_delphi, y = .data[[y_col]])) + | ||||||
|  |     geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) + | ||||||
|  |     geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) + | ||||||
|  |     geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) + | ||||||
|  |     geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) + | ||||||
|  |     scale_fill_identity() + | ||||||
|  |     coord_cartesian(xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) + | ||||||
|  |     coord_fixed(ratio = 1) + | ||||||
|  |     labs(title = title_text, x = "AUC_Delphi", y = y_label) + | ||||||
|  |     theme_minimal(base_size = 10) + | ||||||
|  |     theme( | ||||||
|  |       plot.title = element_text(hjust = 0.5), | ||||||
|  |       panel.grid.minor = element_blank() | ||||||
|  |     ) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Placeholder empty plot if a required column is missing | ||||||
|  | empty_plot <- function(msg) { | ||||||
|  |   ggplot() + theme_void() + ggtitle(msg) + theme(plot.title = element_text(hjust = 0.5)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # Individual Delphi comparison plots | ||||||
|  | has_cols <- function(cols) all(cols %in% names(df)) | ||||||
|  |  | ||||||
|  | # AUC_120 vs AUC_Delphi (no gap) | ||||||
|  | if (has_cols(c("auc_delphi", "auc_120"))) { | ||||||
|  |   p120 <- make_delphi_plot(df, "auc_120", "AUC_120 vs AUC_Delphi (no gap)", "AUC_120") | ||||||
|  |   out_120 <- file.path(out_dir, "fig_auc_120_vs_delphi_no_gap.png") | ||||||
|  |   ggsave(filename = out_120, plot = p120, width = 7, height = 4, dpi = 600, bg = "white") | ||||||
|  |   cat(sprintf("Saved: %s\n", out_120)) | ||||||
|  | } else { | ||||||
|  |   warning("Missing columns for AUC_120 vs Delphi plot.") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # AUC_256 vs AUC_Delphi (no gap) | ||||||
|  | if (has_cols(c("auc_delphi", "auc_256"))) { | ||||||
|  |   p256 <- make_delphi_plot(df, "auc_256", "AUC_256 vs AUC_Delphi (no gap)", "AUC_256") | ||||||
|  |   out_256 <- file.path(out_dir, "model_comparison_auc_256_vs_delphi_no_gap.png") | ||||||
|  |   ggsave(filename = out_256, plot = p256, width = 7, height = 4, dpi = 600, bg = "white") | ||||||
|  |   cat(sprintf("Saved: %s\n", out_256)) | ||||||
|  | } else { | ||||||
|  |   warning("Missing columns for AUC_256 vs Delphi plot.") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # AUC_120_L vs AUC_Delphi (no gap) | ||||||
|  | if (has_cols(c("auc_delphi", "auc_120_l"))) { | ||||||
|  |   p120l <- make_delphi_plot(df, "auc_120_l", "AUC_120_L vs AUC_Delphi (no gap)", "AUC_120_L") | ||||||
|  |   out_120l <- file.path(out_dir, "fig_auc_120_l_vs_delphi_no_gap.png") | ||||||
|  |   ggsave(filename = out_120l, plot = p120l, width = 7, height = 4, dpi = 600, bg = "white") | ||||||
|  |   cat(sprintf("Saved: %s\n", out_120l)) | ||||||
|  | } else { | ||||||
|  |   warning("Missing columns for AUC_120_L vs Delphi plot.") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # AUC_256_L vs AUC_Delphi (no gap) | ||||||
|  | if (has_cols(c("auc_delphi", "auc_256_l"))) { | ||||||
|  |   p256l <- make_delphi_plot(df, "auc_256_l", "AUC_256_L vs AUC_Delphi (no gap)", "AUC_256_L") | ||||||
|  |   out_256l <- file.path(out_dir, "model_comparison_auc_256_l_vs_delphi_no_gap.png") | ||||||
|  |   ggsave(filename = out_256l, plot = p256l, width = 7, height = 4, dpi = 600, bg = "white") | ||||||
|  |   cat(sprintf("Saved: %s\n", out_256l)) | ||||||
|  | } else { | ||||||
|  |   warning("Missing columns for AUC_256_L vs Delphi plot.") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | # 2x2 grid of Delphi comparisons | ||||||
|  | p_120_vs_delphi   <- if (has_cols(c("auc_delphi", "auc_120")))   make_delphi_plot(df, "auc_120",   "AUC_120 vs Delphi (no gap)",   "AUC_120")   else empty_plot("Missing auc_120 or auc_delphi") | ||||||
|  | p_256_vs_delphi   <- if (has_cols(c("auc_delphi", "auc_256")))   make_delphi_plot(df, "auc_256",   "AUC_256 vs Delphi (no gap)",   "AUC_256")   else empty_plot("Missing auc_256 or auc_delphi") | ||||||
|  | p_120l_vs_delphi  <- if (has_cols(c("auc_delphi", "auc_120_l"))) make_delphi_plot(df, "auc_120_l", "AUC_120_L vs Delphi (no gap)", "AUC_120_L") else empty_plot("Missing auc_120_l or auc_delphi") | ||||||
|  | p_256l_vs_delphi  <- if (has_cols(c("auc_delphi", "auc_256_l"))) make_delphi_plot(df, "auc_256_l", "AUC_256_L vs Delphi (no gap)", "AUC_256_L") else empty_plot("Missing auc_256_l or auc_delphi") | ||||||
|  |  | ||||||
|  | grid_plot <- plot_grid( | ||||||
|  |   p_120_vs_delphi, p_256_vs_delphi, | ||||||
|  |   p_120l_vs_delphi, p_256l_vs_delphi, | ||||||
|  |   labels = c("A", "B", "C", "D"), | ||||||
|  |   ncol = 2, align = "hv" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | out_grid <- file.path(out_dir, "model_comparison_auc_vs_delphi_no_gap_grid.png") | ||||||
|  | ggsave(filename = out_grid, plot = grid_plot, width = 12, height = 8, dpi = 300, bg = "white") | ||||||
|  | cat(sprintf("Saved grid: %s\n", out_grid)) | ||||||
							
								
								
									
										5
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | torch | ||||||
|  | numpy | ||||||
|  | tqdm | ||||||
|  | matplotlib | ||||||
|  | joblib | ||||||
							
								
								
									
										91
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										91
									
								
								train.py
									
									
									
									
									
								
							| @@ -1,13 +1,15 @@ | |||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from torch.optim import Adam | from torch.optim import AdamW | ||||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||||
| import numpy as np | import numpy as np | ||||||
| import math | import math | ||||||
| import tqdm | import tqdm | ||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
|  | import json | ||||||
|  | import argparse | ||||||
|  |  | ||||||
| from models import TimeAwareGPT2, CombinedLoss | from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv, CombinedLoss | ||||||
| from utils import PatientEventDataset | from utils import PatientEventDataset | ||||||
|  |  | ||||||
| # --- Configuration --- | # --- Configuration --- | ||||||
| @@ -15,33 +17,80 @@ class TrainConfig: | |||||||
|     # Data parameters |     # Data parameters | ||||||
|     train_data_path = 'ukb_real_train.bin' |     train_data_path = 'ukb_real_train.bin' | ||||||
|     val_data_path = 'ukb_real_val.bin' |     val_data_path = 'ukb_real_val.bin' | ||||||
|     block_length = 24  # Sequence length |     block_length = 48  # Sequence length | ||||||
|  |  | ||||||
|     # Model parameters |     # Model parameters | ||||||
|     n_embd = 256 |     n_embd = 120 | ||||||
|     n_layer = 8 |     n_layer = 12 | ||||||
|     n_head = 8 |     n_head = 12 | ||||||
|     pdrop = 0.1 |     pdrop = 0.1 | ||||||
|     token_pdrop = 0.1 |     token_pdrop = 0.1 | ||||||
|  |     model_name = 'TimeAwareGPT2' | ||||||
|  |  | ||||||
|     # Training parameters |     # Training parameters | ||||||
|     max_epoch = 200 |     max_epoch = 200 | ||||||
|     batch_size = 128 |     batch_size = 128 | ||||||
|     lr_initial = 6e-4 |     lr_initial = 6e-4 | ||||||
|     lr_final = 6e-5 |     lr_final = 6e-5 | ||||||
|  |     weight_decay = 2e-1 | ||||||
|     warmup_epochs = 10 |     warmup_epochs = 10 | ||||||
|     early_stopping_patience = 5 |     early_stopping_patience = 10 | ||||||
|  |     betas = (0.9, 0.99) | ||||||
|      |      | ||||||
|     # Loss parameters |     # Loss parameters | ||||||
|     # 0 = padding, 1 = "no event" |     # 0 = padding, 1 = "no event" | ||||||
|     ignored_token_ids = [0, 1] |     ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # Example ignored token IDs | ||||||
|  |  | ||||||
|     # System parameters |     # System parameters | ||||||
|     device = 'cuda' if torch.cuda.is_available() else 'cpu' |     device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||||||
|  |  | ||||||
| # --- Main Training Script --- | # --- Main Training Script --- | ||||||
| def main(): | def main(): | ||||||
|  |     parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model.') | ||||||
|  |     parser.add_argument('--n_layer', type=int, default=12, help='Number of transformer layers.') | ||||||
|  |     parser.add_argument('--n_embd', type=int, default=120, help='Embedding dimension.') | ||||||
|  |     parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads.') | ||||||
|  |     parser.add_argument('--max_epoch', type=int, default=200, help='Maximum number of training epochs.') | ||||||
|  |     parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.') | ||||||
|  |     parser.add_argument('--lr_initial', type=float, default=6e-4, help='Initial learning rate.') | ||||||
|  |     parser.add_argument('--lr_final', type=float, default=6e-5, help='Final learning rate.') | ||||||
|  |     parser.add_argument('--weight_decay', type=float, default=2e-1, help='Weight decay for the optimizer.') | ||||||
|  |     parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs.') | ||||||
|  |     parser.add_argument('--early_stopping_patience', type=int, default=10, help='Patience for early stopping.') | ||||||
|  |     parser.add_argument('--pdrop', type=float, default=0.1, help='Dropout probability.') | ||||||
|  |     parser.add_argument('--token_pdrop', type=float, default=0.1, help='Token dropout probability.') | ||||||
|  |     parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99], help='AdamW betas.') | ||||||
|  |     parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable', 'TimeAwareGPT2TemporalConv'], default='TimeAwareGPT2', help='Model architecture to train.') | ||||||
|  |  | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     config = TrainConfig() |     config = TrainConfig() | ||||||
|  |     config.n_layer = args.n_layer | ||||||
|  |     config.n_embd = args.n_embd | ||||||
|  |     config.n_head = args.n_head | ||||||
|  |     config.max_epoch = args.max_epoch | ||||||
|  |     config.batch_size = args.batch_size | ||||||
|  |     config.lr_initial = args.lr_initial | ||||||
|  |     config.lr_final = args.lr_final | ||||||
|  |     config.weight_decay = args.weight_decay | ||||||
|  |     config.warmup_epochs = args.warmup_epochs | ||||||
|  |     config.early_stopping_patience = args.early_stopping_patience | ||||||
|  |     config.pdrop = args.pdrop | ||||||
|  |     config.token_pdrop = args.token_pdrop | ||||||
|  |     config.betas = tuple(args.betas) | ||||||
|  |     config.model_name = args.model | ||||||
|  |  | ||||||
|  |     model_suffix = f"{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}" | ||||||
|  |     model_filename = f"best_model_{model_suffix}.pt" | ||||||
|  |     checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt" | ||||||
|  |  | ||||||
|  |     # --- 0. Save Configuration --- | ||||||
|  |     # Include model class in config filename for clarity/distinction across architectures | ||||||
|  |     config_filename = f"config_{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json" | ||||||
|  |     config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')} | ||||||
|  |     with open(config_filename, 'w') as f: | ||||||
|  |         json.dump(config_dict, f, indent=4) | ||||||
|  |     print(f"Configuration saved to {config_filename}") | ||||||
|  |  | ||||||
|     # --- 1. Data Loading --- |     # --- 1. Data Loading --- | ||||||
|     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") |     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") | ||||||
| @@ -60,7 +109,13 @@ def main(): | |||||||
|  |  | ||||||
|     # --- 2. Model, Optimizer, and Loss Initialization --- |     # --- 2. Model, Optimizer, and Loss Initialization --- | ||||||
|     print(f"Initializing model on {config.device}...") |     print(f"Initializing model on {config.device}...") | ||||||
|     model = TimeAwareGPT2( |     model_cls = { | ||||||
|  |         'TimeAwareGPT2': TimeAwareGPT2, | ||||||
|  |         'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable, | ||||||
|  |         'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv, | ||||||
|  |     }[config.model_name] | ||||||
|  |  | ||||||
|  |     model = model_cls( | ||||||
|         vocab_size=vocab_size, |         vocab_size=vocab_size, | ||||||
|         n_embd=config.n_embd, |         n_embd=config.n_embd, | ||||||
|         n_layer=config.n_layer, |         n_layer=config.n_layer, | ||||||
| @@ -72,7 +127,7 @@ def main(): | |||||||
|     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.") |     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.") | ||||||
|  |  | ||||||
|     loss_fn = CombinedLoss(config.ignored_token_ids) |     loss_fn = CombinedLoss(config.ignored_token_ids) | ||||||
|     optimizer = Adam(model.parameters(), lr=config.lr_initial) |     optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=config.betas) | ||||||
|  |  | ||||||
|     # --- 3. Training Loop --- |     # --- 3. Training Loop --- | ||||||
|     best_val_loss = float('inf') |     best_val_loss = float('inf') | ||||||
| @@ -170,7 +225,7 @@ def main(): | |||||||
|             best_val_loss = total_val_loss |             best_val_loss = total_val_loss | ||||||
|             patience_counter = 0 |             patience_counter = 0 | ||||||
|             print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...") |             print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...") | ||||||
|             torch.save(model.state_dict(), 'best_model_checkpoint.pt') |             torch.save(model.state_dict(), checkpoint_filename) | ||||||
|         else: |         else: | ||||||
|             if epoch >= config.warmup_epochs: |             if epoch >= config.warmup_epochs: | ||||||
|                 patience_counter += 1 |                 patience_counter += 1 | ||||||
| @@ -183,12 +238,20 @@ def main(): | |||||||
|     # --- Save Best Model at the End --- |     # --- Save Best Model at the End --- | ||||||
|     if best_val_loss != float('inf'): |     if best_val_loss != float('inf'): | ||||||
|         print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") |         print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") | ||||||
|         model.load_state_dict(torch.load('best_model_checkpoint.pt')) |         model.load_state_dict(torch.load(checkpoint_filename)) | ||||||
|         print("Saving final best model to best_model.pt") |         print(f"Saving final best model to {model_filename}") | ||||||
|         torch.save(model.state_dict(), 'best_model.pt') |         torch.save(model.state_dict(), model_filename) | ||||||
|     else: |     else: | ||||||
|         print("\nTraining finished. No best model to save as validation loss never improved.") |         print("\nTraining finished. No best model to save as validation loss never improved.") | ||||||
|  |  | ||||||
|  |     # --- Save losses to a txt file --- | ||||||
|  |     losses_filename = f"losses_{model_suffix}.txt" | ||||||
|  |     with open(losses_filename, 'w') as f: | ||||||
|  |         f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n") | ||||||
|  |         for i in range(len(train_losses_total)): | ||||||
|  |             f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n") | ||||||
|  |     print(f"\nLosses saved to {losses_filename}") | ||||||
|  |  | ||||||
|     # --- Plot and Save Loss Curves --- |     # --- Plot and Save Loss Curves --- | ||||||
|     num_epochs = len(train_losses_total) |     num_epochs = len(train_losses_total) | ||||||
|     epochs = range(1, num_epochs + 1) |     epochs = range(1, num_epochs + 1) | ||||||
|   | |||||||
							
								
								
									
										364
									
								
								train_ddp.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										364
									
								
								train_ddp.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,364 @@ | |||||||
|  | import os | ||||||
|  | import json | ||||||
|  | import math | ||||||
|  | import argparse | ||||||
|  | from typing import Tuple | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.distributed as dist | ||||||
|  | from torch.optim import AdamW | ||||||
|  | from torch.nn.parallel import DistributedDataParallel as DDP | ||||||
|  | from torch.utils.data import DataLoader, DistributedSampler | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | import tqdm | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  |  | ||||||
|  | from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss | ||||||
|  | from utils import PatientEventDataset | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TrainConfig: | ||||||
|  |     # Data parameters | ||||||
|  |     train_data_path = 'ukb_real_train.bin' | ||||||
|  |     val_data_path = 'ukb_real_val.bin' | ||||||
|  |     block_length = 48 | ||||||
|  |  | ||||||
|  |     # Model parameters | ||||||
|  |     n_embd = 120 | ||||||
|  |     n_layer = 12 | ||||||
|  |     n_head = 12 | ||||||
|  |     pdrop = 0.1 | ||||||
|  |     token_pdrop = 0.1 | ||||||
|  |     model_name = 'TimeAwareGPT2' | ||||||
|  |  | ||||||
|  |     # Training parameters | ||||||
|  |     max_epoch = 200 | ||||||
|  |     batch_size = 128 | ||||||
|  |     lr_initial = 6e-4 | ||||||
|  |     lr_final = 6e-5 | ||||||
|  |     weight_decay = 2e-1 | ||||||
|  |     warmup_epochs = 10 | ||||||
|  |     early_stopping_patience = 10 | ||||||
|  |     betas = (0.9, 0.99) | ||||||
|  |  | ||||||
|  |     # Loss parameters (ignored tokens) | ||||||
|  |     ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def setup_ddp(backend: str | None = None): | ||||||
|  |     """Initialize torch.distributed from environment variables set by torchrun.""" | ||||||
|  |     if backend is None: | ||||||
|  |         if torch.cuda.is_available() and os.name != 'nt': | ||||||
|  |             backend = 'nccl' | ||||||
|  |         else: | ||||||
|  |             backend = 'gloo' | ||||||
|  |     dist.init_process_group(backend=backend) | ||||||
|  |  | ||||||
|  |     local_rank = int(os.environ.get('LOCAL_RANK', 0)) | ||||||
|  |     rank = int(os.environ.get('RANK', 0)) | ||||||
|  |     world_size = int(os.environ.get('WORLD_SIZE', 1)) | ||||||
|  |  | ||||||
|  |     if torch.cuda.is_available(): | ||||||
|  |         torch.cuda.set_device(local_rank) | ||||||
|  |         device = torch.device(f'cuda:{local_rank}') | ||||||
|  |     else: | ||||||
|  |         device = torch.device('cpu') | ||||||
|  |  | ||||||
|  |     return rank, world_size, local_rank, device | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def cleanup_ddp(): | ||||||
|  |     if dist.is_initialized(): | ||||||
|  |         dist.destroy_process_group() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def cosine_lr(epoch: int, cfg: TrainConfig) -> float: | ||||||
|  |     if epoch < cfg.warmup_epochs: | ||||||
|  |         return cfg.lr_initial | ||||||
|  |     progress = (epoch - cfg.warmup_epochs) / max(1, (cfg.max_epoch - cfg.warmup_epochs)) | ||||||
|  |     return cfg.lr_final + 0.5 * (cfg.lr_initial - cfg.lr_final) * (1 + math.cos(math.pi * progress)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def allreduce_avg(value: torch.Tensor, world_size: int) -> torch.Tensor: | ||||||
|  |     """All-reduce sum then divide by world_size.""" | ||||||
|  |     value = value.clone().to(torch.float64) | ||||||
|  |     dist.all_reduce(value, op=dist.ReduceOp.SUM) | ||||||
|  |     value /= world_size | ||||||
|  |     return value.to(torch.float32) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |     parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model (DDP). Use torchrun to launch.') | ||||||
|  |     parser.add_argument('--n_layer', type=int, default=12) | ||||||
|  |     parser.add_argument('--n_embd', type=int, default=120) | ||||||
|  |     parser.add_argument('--n_head', type=int, default=12) | ||||||
|  |     parser.add_argument('--max_epoch', type=int, default=200) | ||||||
|  |     parser.add_argument('--batch_size', type=int, default=128) | ||||||
|  |     parser.add_argument('--lr_initial', type=float, default=6e-4) | ||||||
|  |     parser.add_argument('--lr_final', type=float, default=6e-5) | ||||||
|  |     parser.add_argument('--weight_decay', type=float, default=2e-1) | ||||||
|  |     parser.add_argument('--warmup_epochs', type=int, default=10) | ||||||
|  |     parser.add_argument('--early_stopping_patience', type=int, default=10) | ||||||
|  |     parser.add_argument('--pdrop', type=float, default=0.1) | ||||||
|  |     parser.add_argument('--token_pdrop', type=float, default=0.1) | ||||||
|  |     parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99]) | ||||||
|  |     parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2') | ||||||
|  |     parser.add_argument('--backend', type=str, default=None, help='DDP backend (nccl/gloo). Default auto-selects.') | ||||||
|  |  | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     rank, world_size, local_rank, device = setup_ddp(args.backend) | ||||||
|  |  | ||||||
|  |     # Build config | ||||||
|  |     cfg = TrainConfig() | ||||||
|  |     cfg.n_layer = args.n_layer | ||||||
|  |     cfg.n_embd = args.n_embd | ||||||
|  |     cfg.n_head = args.n_head | ||||||
|  |     cfg.max_epoch = args.max_epoch | ||||||
|  |     cfg.batch_size = args.batch_size | ||||||
|  |     cfg.lr_initial = args.lr_initial | ||||||
|  |     cfg.lr_final = args.lr_final | ||||||
|  |     cfg.weight_decay = args.weight_decay | ||||||
|  |     cfg.warmup_epochs = args.warmup_epochs | ||||||
|  |     cfg.early_stopping_patience = args.early_stopping_patience | ||||||
|  |     cfg.pdrop = args.pdrop | ||||||
|  |     cfg.token_pdrop = args.token_pdrop | ||||||
|  |     cfg.betas = tuple(args.betas) | ||||||
|  |     cfg.model_name = args.model | ||||||
|  |  | ||||||
|  |     # Filenames (shared across ranks) | ||||||
|  |     model_suffix = f"{cfg.model_name}_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}" | ||||||
|  |     model_filename = f"best_model_{model_suffix}.pt" | ||||||
|  |     checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt" | ||||||
|  |     config_filename = f"config_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}.json" | ||||||
|  |  | ||||||
|  |     # Save config only on rank 0 | ||||||
|  |     if rank == 0: | ||||||
|  |         with open(config_filename, 'w') as f: | ||||||
|  |             json.dump({k: v for k, v in vars(cfg).items() if not k.startswith('__')}, f, indent=4) | ||||||
|  |         print(f"[rank 0] Configuration saved to {config_filename}") | ||||||
|  |  | ||||||
|  |     # Load data (all ranks) | ||||||
|  |     if rank == 0: | ||||||
|  |         print(f"Loading data from {cfg.train_data_path} and {cfg.val_data_path}...") | ||||||
|  |     train_data_arr = np.memmap(cfg.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) | ||||||
|  |     val_data_arr = np.memmap(cfg.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) | ||||||
|  |  | ||||||
|  |     vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1 | ||||||
|  |     if rank == 0: | ||||||
|  |         print(f"Inferred vocabulary size: {vocab_size}") | ||||||
|  |  | ||||||
|  |     train_dataset = PatientEventDataset(train_data_arr, cfg.block_length) | ||||||
|  |     val_dataset = PatientEventDataset(val_data_arr, cfg.block_length) | ||||||
|  |  | ||||||
|  |     train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False) | ||||||
|  |     val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False) | ||||||
|  |  | ||||||
|  |     train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=4, pin_memory=torch.cuda.is_available()) | ||||||
|  |     val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, sampler=val_sampler, num_workers=4, pin_memory=torch.cuda.is_available()) | ||||||
|  |  | ||||||
|  |     # Model, loss, optimizer | ||||||
|  |     model_cls = { | ||||||
|  |         'TimeAwareGPT2': TimeAwareGPT2, | ||||||
|  |         'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable, | ||||||
|  |     }[cfg.model_name] | ||||||
|  |  | ||||||
|  |     model = model_cls( | ||||||
|  |         vocab_size=vocab_size, | ||||||
|  |         n_embd=cfg.n_embd, | ||||||
|  |         n_layer=cfg.n_layer, | ||||||
|  |         n_head=cfg.n_head, | ||||||
|  |         pdrop=cfg.pdrop, | ||||||
|  |         token_pdrop=cfg.token_pdrop, | ||||||
|  |     ).to(device) | ||||||
|  |  | ||||||
|  |     ddp_model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None, output_device=local_rank if torch.cuda.is_available() else None) | ||||||
|  |  | ||||||
|  |     loss_fn = CombinedLoss(cfg.ignored_token_ids) | ||||||
|  |     optimizer = AdamW(ddp_model.parameters(), lr=cfg.lr_initial, weight_decay=cfg.weight_decay, betas=cfg.betas) | ||||||
|  |  | ||||||
|  |     best_val_loss = float('inf') | ||||||
|  |     patience_counter = 0 | ||||||
|  |  | ||||||
|  |     train_losses_ce, train_losses_surv, train_losses_total = [], [], [] | ||||||
|  |     val_losses_ce, val_losses_surv, val_losses_total = [], [], [] | ||||||
|  |  | ||||||
|  |     if rank == 0: | ||||||
|  |         print("Starting DDP training...") | ||||||
|  |  | ||||||
|  |     for epoch in range(cfg.max_epoch): | ||||||
|  |         # Update sampler epoch for shuffling | ||||||
|  |         train_sampler.set_epoch(epoch) | ||||||
|  |         val_sampler.set_epoch(epoch) | ||||||
|  |  | ||||||
|  |         # Set LR | ||||||
|  |         lr = cosine_lr(epoch, cfg) | ||||||
|  |         for pg in optimizer.param_groups: | ||||||
|  |             pg['lr'] = lr | ||||||
|  |  | ||||||
|  |         # Train | ||||||
|  |         ddp_model.train() | ||||||
|  |         train_loss_ce_acc = torch.zeros(1, device=device) | ||||||
|  |         train_loss_surv_acc = torch.zeros(1, device=device) | ||||||
|  |         train_steps = 0 | ||||||
|  |  | ||||||
|  |         pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Train]", disable=(rank != 0)) | ||||||
|  |         for event_seq, time_seq in pbar: | ||||||
|  |             event_seq = event_seq.to(device, non_blocking=True) | ||||||
|  |             time_seq = time_seq.to(device, non_blocking=True) | ||||||
|  |  | ||||||
|  |             input_events = event_seq[:, :-1] | ||||||
|  |             input_times = time_seq[:, :-1] | ||||||
|  |             target_events = event_seq[:, 1:] | ||||||
|  |             target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float() | ||||||
|  |  | ||||||
|  |             logits = ddp_model(input_events, input_times) | ||||||
|  |             loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) | ||||||
|  |             loss = loss_ce + loss_survival | ||||||
|  |  | ||||||
|  |             optimizer.zero_grad(set_to_none=True) | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  |  | ||||||
|  |             train_loss_ce_acc += loss_ce.detach() | ||||||
|  |             train_loss_surv_acc += loss_survival.detach() | ||||||
|  |             train_steps += 1 | ||||||
|  |  | ||||||
|  |             if rank == 0: | ||||||
|  |                 pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'}) | ||||||
|  |  | ||||||
|  |         # Aggregate train losses across ranks | ||||||
|  |         if train_steps == 0: | ||||||
|  |             train_steps = 1 | ||||||
|  |         steps_tensor = torch.tensor([train_steps], device=device, dtype=torch.float64) | ||||||
|  |         dist.all_reduce(steps_tensor, op=dist.ReduceOp.SUM) | ||||||
|  |         train_loss_ce_mean = allreduce_avg(train_loss_ce_acc, world_size) / (steps_tensor.item() / world_size) | ||||||
|  |         train_loss_surv_mean = allreduce_avg(train_loss_surv_acc, world_size) / (steps_tensor.item() / world_size) | ||||||
|  |  | ||||||
|  |         if rank == 0: | ||||||
|  |             train_losses_ce.append(train_loss_ce_mean.item()) | ||||||
|  |             train_losses_surv.append(train_loss_surv_mean.item()) | ||||||
|  |             train_losses_total.append(train_loss_ce_mean.item() + train_loss_surv_mean.item()) | ||||||
|  |  | ||||||
|  |         # Validation | ||||||
|  |         ddp_model.eval() | ||||||
|  |         val_loss_ce_acc = torch.zeros(1, device=device) | ||||||
|  |         val_loss_surv_acc = torch.zeros(1, device=device) | ||||||
|  |         val_steps = 0 | ||||||
|  |  | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Val]", disable=(rank != 0)) | ||||||
|  |             for event_seq, time_seq in pbar_val: | ||||||
|  |                 event_seq = event_seq.to(device, non_blocking=True) | ||||||
|  |                 time_seq = time_seq.to(device, non_blocking=True) | ||||||
|  |  | ||||||
|  |                 input_events = event_seq[:, :-1] | ||||||
|  |                 input_times = time_seq[:, :-1] | ||||||
|  |                 target_events = event_seq[:, 1:] | ||||||
|  |                 target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float() | ||||||
|  |  | ||||||
|  |                 logits = ddp_model(input_events, input_times) | ||||||
|  |                 loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) | ||||||
|  |  | ||||||
|  |                 val_loss_ce_acc += loss_ce.detach() | ||||||
|  |                 val_loss_surv_acc += loss_survival.detach() | ||||||
|  |                 val_steps += 1 | ||||||
|  |  | ||||||
|  |         if val_steps == 0: | ||||||
|  |             val_steps = 1 | ||||||
|  |         vsteps_tensor = torch.tensor([val_steps], device=device, dtype=torch.float64) | ||||||
|  |         dist.all_reduce(vsteps_tensor, op=dist.ReduceOp.SUM) | ||||||
|  |         val_loss_ce_mean = allreduce_avg(val_loss_ce_acc, world_size) / (vsteps_tensor.item() / world_size) | ||||||
|  |         val_loss_surv_mean = allreduce_avg(val_loss_surv_acc, world_size) / (vsteps_tensor.item() / world_size) | ||||||
|  |         total_val_loss = (val_loss_ce_mean + val_loss_surv_mean).item() | ||||||
|  |  | ||||||
|  |         if rank == 0: | ||||||
|  |             val_losses_ce.append(val_loss_ce_mean.item()) | ||||||
|  |             val_losses_surv.append(val_loss_surv_mean.item()) | ||||||
|  |             val_losses_total.append(total_val_loss) | ||||||
|  |  | ||||||
|  |             print( | ||||||
|  |                 f"Epoch {epoch+1} Summary:\n" | ||||||
|  |                 f"  Train Loss: {train_losses_total[-1]:.4f} (CE: {train_losses_ce[-1]:.4f}, Surv: {train_losses_surv[-1]:.4f})\n" | ||||||
|  |                 f"  Val Loss:   {total_val_loss:.4f} (CE: {val_losses_ce[-1]:.4f}, Surv: {val_losses_surv[-1]:.4f})\n" | ||||||
|  |                 f"  Learning Rate: {lr:.6f}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             # Early stopping on rank 0; broadcast decision | ||||||
|  |             improved = total_val_loss < best_val_loss | ||||||
|  |             if improved: | ||||||
|  |                 best_val_loss = total_val_loss | ||||||
|  |                 patience_counter = 0 | ||||||
|  |                 print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...") | ||||||
|  |                 torch.save(ddp_model.module.state_dict(), checkpoint_filename) | ||||||
|  |             else: | ||||||
|  |                 if epoch >= cfg.warmup_epochs: | ||||||
|  |                     patience_counter += 1 | ||||||
|  |                     print(f"Validation loss did not improve. Patience: {patience_counter}/{cfg.early_stopping_patience}") | ||||||
|  |  | ||||||
|  |             stop_flag = torch.tensor([1 if patience_counter >= cfg.early_stopping_patience else 0], device=device) | ||||||
|  |         else: | ||||||
|  |             stop_flag = torch.zeros(1, device=device) | ||||||
|  |  | ||||||
|  |         # Broadcast stop flag and best loss to all ranks | ||||||
|  |         dist.broadcast(stop_flag, src=0) | ||||||
|  |         if stop_flag.item() > 0: | ||||||
|  |             if rank == 0: | ||||||
|  |                 print("\nEarly stopping triggered due to no improvement in validation loss.") | ||||||
|  |             break | ||||||
|  |  | ||||||
|  |     # Save best model at the end (rank 0) | ||||||
|  |     if rank == 0 and best_val_loss != float('inf'): | ||||||
|  |         print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") | ||||||
|  |         state = torch.load(checkpoint_filename, map_location='cpu') | ||||||
|  |         ddp_model.module.load_state_dict(state) | ||||||
|  |         print(f"Saving final best model to {model_filename}") | ||||||
|  |         torch.save(ddp_model.module.state_dict(), model_filename) | ||||||
|  |  | ||||||
|  |         # Save losses to file | ||||||
|  |         losses_filename = f"losses_{model_suffix}.txt" | ||||||
|  |         with open(losses_filename, 'w') as f: | ||||||
|  |             f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n") | ||||||
|  |             for i in range(len(train_losses_total)): | ||||||
|  |                 f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n") | ||||||
|  |         print(f"\nLosses saved to {losses_filename}") | ||||||
|  |  | ||||||
|  |         # Plot curves | ||||||
|  |         num_epochs = len(train_losses_total) | ||||||
|  |         epochs = range(1, num_epochs + 1) | ||||||
|  |         plt.figure(figsize=(18, 5)) | ||||||
|  |  | ||||||
|  |         plt.subplot(1, 3, 1) | ||||||
|  |         plt.plot(epochs, train_losses_ce, label='Train CE') | ||||||
|  |         plt.plot(epochs, val_losses_ce, label='Val CE') | ||||||
|  |         plt.title('Cross-Entropy Loss') | ||||||
|  |         plt.xlabel('Epochs') | ||||||
|  |         plt.ylabel('Loss') | ||||||
|  |         plt.legend(); plt.grid(True) | ||||||
|  |  | ||||||
|  |         plt.subplot(1, 3, 2) | ||||||
|  |         plt.plot(epochs, train_losses_surv, label='Train Survival') | ||||||
|  |         plt.plot(epochs, val_losses_surv, label='Val Survival') | ||||||
|  |         plt.title('Survival Loss') | ||||||
|  |         plt.xlabel('Epochs') | ||||||
|  |         plt.ylabel('Loss') | ||||||
|  |         plt.legend(); plt.grid(True) | ||||||
|  |  | ||||||
|  |         plt.subplot(1, 3, 3) | ||||||
|  |         plt.plot(epochs, train_losses_total, label='Train Total') | ||||||
|  |         plt.plot(epochs, val_losses_total, label='Val Total') | ||||||
|  |         plt.title('Total Loss') | ||||||
|  |         plt.xlabel('Epochs') | ||||||
|  |         plt.ylabel('Loss') | ||||||
|  |         plt.legend(); plt.grid(True) | ||||||
|  |  | ||||||
|  |         plt.tight_layout() | ||||||
|  |         plt.savefig('loss_curves.png') | ||||||
|  |         print("\nLoss curves saved to loss_curves.png") | ||||||
|  |  | ||||||
|  |     cleanup_ddp() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     main() | ||||||
							
								
								
									
										218
									
								
								train_iter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										218
									
								
								train_iter.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,218 @@ | |||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from torch.optim import AdamW | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  | import numpy as np | ||||||
|  | import math | ||||||
|  | import tqdm | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  | import json | ||||||
|  | import itertools | ||||||
|  |  | ||||||
|  | from models import TimeAwareGPT2, CombinedLoss | ||||||
|  | from utils import PatientEventDataset | ||||||
|  |  | ||||||
|  | # --- Configuration --- | ||||||
|  | class TrainConfig: | ||||||
|  |     # Data parameters | ||||||
|  |     train_data_path = 'ukb_real_train.bin' | ||||||
|  |     val_data_path = 'ukb_real_val.bin' | ||||||
|  |     block_length = 48  # Sequence length | ||||||
|  |  | ||||||
|  |     # Model parameters | ||||||
|  |     n_embd = 120 | ||||||
|  |     n_layer = 12 | ||||||
|  |     n_head = 12 | ||||||
|  |     pdrop = 0.0 | ||||||
|  |     token_pdrop = 0.0 | ||||||
|  |  | ||||||
|  |     # Training parameters | ||||||
|  |     max_iter = 200000 | ||||||
|  |     batch_size = 128 | ||||||
|  |     lr_initial = 6e-4 | ||||||
|  |     lr_final = 6e-5 | ||||||
|  |     weight_decay = 2e-1 | ||||||
|  |     warmup_iter = 1000 | ||||||
|  |      | ||||||
|  |     # Loss parameters | ||||||
|  |     # 0 = padding, 1 = "no event" | ||||||
|  |     ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # Example ignored token IDs | ||||||
|  |  | ||||||
|  |     # System parameters | ||||||
|  |     device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||||||
|  |  | ||||||
|  | # --- Main Training Script --- | ||||||
|  | def main(): | ||||||
|  |     config = TrainConfig() | ||||||
|  |  | ||||||
|  |     model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.pt" | ||||||
|  |  | ||||||
|  |     # --- 0. Save Configuration --- | ||||||
|  |     config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.json" | ||||||
|  |     config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')} | ||||||
|  |     with open(config_filename, 'w') as f: | ||||||
|  |         json.dump(config_dict, f, indent=4) | ||||||
|  |     print(f"Configuration saved to {config_filename}") | ||||||
|  |  | ||||||
|  |     # --- 1. Data Loading --- | ||||||
|  |     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") | ||||||
|  |     train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) | ||||||
|  |     val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) | ||||||
|  |  | ||||||
|  |     # Infer vocab_size from the data (max label + 1) | ||||||
|  |     vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1 | ||||||
|  |     print(f"Inferred vocabulary size: {vocab_size}") | ||||||
|  |  | ||||||
|  |     train_dataset = PatientEventDataset(train_data_arr, config.block_length) | ||||||
|  |     val_dataset = PatientEventDataset(val_data_arr, config.block_length) | ||||||
|  |  | ||||||
|  |     train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True) | ||||||
|  |     val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True) | ||||||
|  |     train_iter_loader = iter(itertools.cycle(train_loader)) | ||||||
|  |  | ||||||
|  |     # --- 2. Model, Optimizer, and Loss Initialization --- | ||||||
|  |     print(f"Initializing model on {config.device}...") | ||||||
|  |     model = TimeAwareGPT2( | ||||||
|  |         vocab_size=vocab_size, | ||||||
|  |         n_embd=config.n_embd, | ||||||
|  |         n_layer=config.n_layer, | ||||||
|  |         n_head=config.n_head, | ||||||
|  |         pdrop=config.pdrop, | ||||||
|  |         token_pdrop=config.token_pdrop | ||||||
|  |     ).to(config.device) | ||||||
|  |  | ||||||
|  |     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.") | ||||||
|  |  | ||||||
|  |     loss_fn = CombinedLoss(config.ignored_token_ids) | ||||||
|  |     optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99)) | ||||||
|  |  | ||||||
|  |     # --- 3. Training Loop --- | ||||||
|  |      | ||||||
|  |     # Lists to store losses | ||||||
|  |     train_losses_ce, train_losses_surv, train_losses_total = [], [], [] | ||||||
|  |      | ||||||
|  |     print("Starting training...") | ||||||
|  |     pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training") | ||||||
|  |     for iter_num in pbar: | ||||||
|  |         # --- Learning Rate Scheduling --- | ||||||
|  |         if iter_num < config.warmup_iter: | ||||||
|  |             lr = config.lr_initial | ||||||
|  |         else: | ||||||
|  |             progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter) | ||||||
|  |             lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress)) | ||||||
|  |          | ||||||
|  |         for param_group in optimizer.param_groups: | ||||||
|  |             param_group['lr'] = lr | ||||||
|  |  | ||||||
|  |         # --- Training Step --- | ||||||
|  |         model.train() | ||||||
|  |          | ||||||
|  |         event_seq, time_seq = next(train_iter_loader) | ||||||
|  |         event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device) | ||||||
|  |  | ||||||
|  |         # Prepare inputs and targets | ||||||
|  |         input_events = event_seq[:, :-1] | ||||||
|  |         input_times = time_seq[:, :-1] | ||||||
|  |         target_events = event_seq[:, 1:] | ||||||
|  |         target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float() | ||||||
|  |  | ||||||
|  |         # Forward pass | ||||||
|  |         logits = model(input_events, input_times) | ||||||
|  |         loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) | ||||||
|  |         loss = loss_ce + loss_survival | ||||||
|  |  | ||||||
|  |         # Backward pass and optimization | ||||||
|  |         optimizer.zero_grad() | ||||||
|  |         loss.backward() | ||||||
|  |         optimizer.step() | ||||||
|  |  | ||||||
|  |         train_losses_ce.append(loss_ce.item()) | ||||||
|  |         train_losses_surv.append(loss_survival.item()) | ||||||
|  |         train_losses_total.append(loss.item()) | ||||||
|  |          | ||||||
|  |         pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'}) | ||||||
|  |  | ||||||
|  |     print("\nTraining finished.") | ||||||
|  |  | ||||||
|  |     # --- 4. Final Validation --- | ||||||
|  |     print("Running final validation...") | ||||||
|  |     model.eval() | ||||||
|  |     val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0 | ||||||
|  |     val_steps = 0 | ||||||
|  |  | ||||||
|  |     with torch.no_grad(): | ||||||
|  |         pbar_val = tqdm.tqdm(val_loader, desc="Final Validation") | ||||||
|  |         for event_seq, time_seq in pbar_val: | ||||||
|  |             event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device) | ||||||
|  |  | ||||||
|  |             input_events = event_seq[:, :-1] | ||||||
|  |             input_times = time_seq[:, :-1] | ||||||
|  |             target_events = event_seq[:, 1:] | ||||||
|  |             target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float() | ||||||
|  |  | ||||||
|  |             logits = model(input_events, input_times) | ||||||
|  |             loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) | ||||||
|  |              | ||||||
|  |             val_loss_ce_acc += loss_ce.item() | ||||||
|  |             val_loss_surv_acc += loss_survival.item() | ||||||
|  |             val_steps += 1 | ||||||
|  |             pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'}) | ||||||
|  |  | ||||||
|  |     avg_val_loss_ce = val_loss_ce_acc / val_steps | ||||||
|  |     avg_val_loss_surv = val_loss_surv_acc / val_steps | ||||||
|  |     total_val_loss = avg_val_loss_ce + avg_val_loss_surv | ||||||
|  |  | ||||||
|  |     print(f"Final Validation Summary: \n"  | ||||||
|  |           f"  Val Loss:   {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})") | ||||||
|  |  | ||||||
|  |     # --- 5. Save Model --- | ||||||
|  |     print(f"Saving final model to {model_filename}") | ||||||
|  |     torch.save(model.state_dict(), model_filename) | ||||||
|  |  | ||||||
|  |     # --- 6. Save and Plot Losses --- | ||||||
|  |     losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.txt" | ||||||
|  |     with open(losses_filename, 'w') as f: | ||||||
|  |         f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n") | ||||||
|  |         for i in range(len(train_losses_total)): | ||||||
|  |             f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n") | ||||||
|  |     print(f"\nLosses saved to {losses_filename}") | ||||||
|  |  | ||||||
|  |     # Plot and Save Loss Curves | ||||||
|  |     iterations = range(1, len(train_losses_total) + 1) | ||||||
|  |  | ||||||
|  |     plt.figure(figsize=(18, 5)) | ||||||
|  |  | ||||||
|  |     # Plot CE Loss | ||||||
|  |     plt.subplot(1, 3, 1) | ||||||
|  |     plt.plot(iterations, train_losses_ce, label='Train CE') | ||||||
|  |     plt.title('Cross-Entropy Loss') | ||||||
|  |     plt.xlabel('Iterations') | ||||||
|  |     plt.ylabel('Loss') | ||||||
|  |     plt.legend() | ||||||
|  |     plt.grid(True) | ||||||
|  |  | ||||||
|  |     # Plot Survival Loss | ||||||
|  |     plt.subplot(1, 3, 2) | ||||||
|  |     plt.plot(iterations, train_losses_surv, label='Train Survival') | ||||||
|  |     plt.title('Survival Loss') | ||||||
|  |     plt.xlabel('Iterations') | ||||||
|  |     plt.ylabel('Loss') | ||||||
|  |     plt.legend() | ||||||
|  |     plt.grid(True) | ||||||
|  |  | ||||||
|  |     # Plot Total Loss | ||||||
|  |     plt.subplot(1, 3, 3) | ||||||
|  |     plt.plot(iterations, train_losses_total, label='Train Total') | ||||||
|  |     plt.title('Total Loss') | ||||||
|  |     plt.xlabel('Iterations') | ||||||
|  |     plt.ylabel('Loss') | ||||||
|  |     plt.legend() | ||||||
|  |     plt.grid(True) | ||||||
|  |  | ||||||
|  |     plt.tight_layout() | ||||||
|  |     plt.savefig('loss_curves_iter.png') | ||||||
|  |     print("\nLoss curves saved to loss_curves_iter.png") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     main() | ||||||
							
								
								
									
										163
									
								
								utils.py
									
									
									
									
									
								
							
							
						
						
									
										163
									
								
								utils.py
									
									
									
									
									
								
							| @@ -1,7 +1,11 @@ | |||||||
|  | import os | ||||||
| import torch | import torch | ||||||
| import numpy as np | import numpy as np | ||||||
| import random | import random | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
|  | import json | ||||||
|  | from models import TimeAwareGPT2, TimeAwareGPT2Learnable, TimeAwareGPT2TemporalConv | ||||||
|  |  | ||||||
|  |  | ||||||
| class PatientEventDataset(torch.utils.data.Dataset): | class PatientEventDataset(torch.utils.data.Dataset): | ||||||
|     """ |     """ | ||||||
| @@ -39,17 +43,22 @@ class PatientEventDataset(torch.utils.data.Dataset): | |||||||
|         """ |         """ | ||||||
|         return len(self.patient_ids) |         return len(self.patient_ids) | ||||||
|  |  | ||||||
|     def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: |     def __getitem__(self, idx): | ||||||
|         """ |         """ | ||||||
|         Retrieves, processes, and returns a single patient's event sequence. |         Retrieves, processes, and returns a single patient's event sequence, | ||||||
|  |         or a list of sequences if a slice is provided. | ||||||
|  |  | ||||||
|         Args: |         Args: | ||||||
|             idx (int): The index of the patient to retrieve. |             idx (int or slice): The index or slice of the patient(s) to retrieve. | ||||||
|  |  | ||||||
|         Returns: |         Returns: | ||||||
|             A tuple of two torch.long tensors: (event_sequence, time_sequence), |             If idx is an int, a tuple of two torch.long tensors: | ||||||
|             both of shape (block_length,). |             (event_sequence, time_sequence), both of shape (block_length,). | ||||||
|  |             If idx is a slice, a list of such tuples. | ||||||
|         """ |         """ | ||||||
|  |         if isinstance(idx, slice): | ||||||
|  |             return [self[i] for i in range(*idx.indices(len(self)))] | ||||||
|  |  | ||||||
|         # 1. Retrieve and Sort |         # 1. Retrieve and Sort | ||||||
|         patient_id = self.patient_ids[idx] |         patient_id = self.patient_ids[idx] | ||||||
|         records = sorted(self.patient_events[patient_id], key=lambda x: x[0]) |         records = sorted(self.patient_events[patient_id], key=lambda x: x[0]) | ||||||
| @@ -102,3 +111,147 @@ class PatientEventDataset(torch.utils.data.Dataset): | |||||||
|         time_tensor = torch.tensor(time_stamps, dtype=torch.long) |         time_tensor = torch.tensor(time_stamps, dtype=torch.long) | ||||||
|  |  | ||||||
|         return event_tensor, time_tensor |         return event_tensor, time_tensor | ||||||
|  |  | ||||||
|  | def load_model(config_path: str, device: str = 'cpu'): | ||||||
|  |     """ | ||||||
|  |         Load a trained model based on the training configuration, inferring the | ||||||
|  |         checkpoint filename from the configuration. | ||||||
|  |  | ||||||
|  |     According to train.py, models may be either 'TimeAwareGPT2' or | ||||||
|  |     'TimeAwareGPT2Learnable'. This function: | ||||||
|  |       - Reads the config JSON to get architecture hyperparameters | ||||||
|  |       - Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent) | ||||||
|  |             - Infers the checkpoint path from the config values | ||||||
|  |             - Infers vocab_size from the checkpoint | ||||||
|  |       - Loads weights and returns the model in eval mode on the requested device | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         config_path: Path to the JSON configuration file saved during training. | ||||||
|  |         device: 'cpu' or 'cuda'. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         torch.nn.Module: Loaded model ready for inference. | ||||||
|  |     """ | ||||||
|  |     # 1) Read config | ||||||
|  |     with open(config_path, 'r') as f: | ||||||
|  |         config_dict = json.load(f) | ||||||
|  |  | ||||||
|  |     # Access config entries with attribute-style access while staying tolerant to missing keys | ||||||
|  |     class AttrDict(dict): | ||||||
|  |         def __getattr__(self, item): | ||||||
|  |             try: | ||||||
|  |                 return self[item] | ||||||
|  |             except KeyError: | ||||||
|  |                 raise AttributeError(item) | ||||||
|  |  | ||||||
|  |     config = AttrDict(config_dict) | ||||||
|  |  | ||||||
|  |     # 2) Decide model class (train.py supports two variants) | ||||||
|  |     model_name = getattr(config, 'model_name', 'TimeAwareGPT2') | ||||||
|  |     model_cls = { | ||||||
|  |         'TimeAwareGPT2': TimeAwareGPT2, | ||||||
|  |         'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable, | ||||||
|  |         'TimeAwareGPT2TemporalConv': TimeAwareGPT2TemporalConv, | ||||||
|  |     }.get(model_name, TimeAwareGPT2) | ||||||
|  |  | ||||||
|  |     # 3) Infer checkpoint filename from config | ||||||
|  |     n_embd = getattr(config, 'n_embd') | ||||||
|  |     n_layer = getattr(config, 'n_layer') | ||||||
|  |     n_head = getattr(config, 'n_head') | ||||||
|  |  | ||||||
|  |     # Newer naming (includes model_name) used by train.py when model_name is present | ||||||
|  |     suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}" | ||||||
|  |     ckpt_with_model = f"best_model_{suffix_with_model}.pt" | ||||||
|  |  | ||||||
|  |     # Older naming (without model_name) matches existing repo files | ||||||
|  |     suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}" | ||||||
|  |     ckpt_legacy = f"best_model_{suffix_legacy}.pt" | ||||||
|  |  | ||||||
|  |     # Prefer file that exists on disk | ||||||
|  |     if os.path.exists(ckpt_with_model): | ||||||
|  |         model_path = ckpt_with_model | ||||||
|  |     elif os.path.exists(ckpt_legacy): | ||||||
|  |         model_path = ckpt_legacy | ||||||
|  |     else: | ||||||
|  |         # Fall back to including model_name; if not present in config earlier, user may still have saved this way | ||||||
|  |         model_path = ckpt_with_model | ||||||
|  |         print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}") | ||||||
|  |  | ||||||
|  |     # 4) Infer vocab_size from checkpoint | ||||||
|  |     state_preview = torch.load(model_path, map_location='cpu') | ||||||
|  |     if 'wte.weight' in state_preview: | ||||||
|  |         vocab_size = state_preview['wte.weight'].shape[0] | ||||||
|  |     elif 'head.weight' in state_preview: | ||||||
|  |         vocab_size = state_preview['head.weight'].shape[0] | ||||||
|  |     else: | ||||||
|  |         candidate = None | ||||||
|  |         for k, v in state_preview.items(): | ||||||
|  |             if isinstance(v, torch.Tensor) and v.ndim == 2: | ||||||
|  |                 V = max(v.shape) | ||||||
|  |                 if candidate is None or V > candidate: | ||||||
|  |                     candidate = V | ||||||
|  |         if candidate is None: | ||||||
|  |             raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.") | ||||||
|  |         vocab_size = candidate | ||||||
|  |  | ||||||
|  |     # 5) Build model from config (tolerant to missing fields) | ||||||
|  |     pdrop = getattr(config, 'pdrop', 0.1) | ||||||
|  |     token_pdrop = getattr(config, 'token_pdrop', 0.1) | ||||||
|  |  | ||||||
|  |     model = model_cls( | ||||||
|  |         vocab_size=vocab_size, | ||||||
|  |         n_embd=n_embd, | ||||||
|  |         n_layer=n_layer, | ||||||
|  |         n_head=n_head, | ||||||
|  |         pdrop=pdrop, | ||||||
|  |         token_pdrop=token_pdrop, | ||||||
|  |     ).to(device) | ||||||
|  |  | ||||||
|  |     # 6) Load weights | ||||||
|  |     state_dict = torch.load(model_path, map_location=device) | ||||||
|  |     missing, unexpected = model.load_state_dict(state_dict, strict=False) | ||||||
|  |  | ||||||
|  |     if missing: | ||||||
|  |         print(f"Warning: Missing keys when loading state_dict: {missing}") | ||||||
|  |     if unexpected: | ||||||
|  |         print(f"Warning: Unexpected keys when loading state_dict: {unexpected}") | ||||||
|  |  | ||||||
|  |     model.eval() | ||||||
|  |     try: | ||||||
|  |         num_params_m = model.get_num_params() | ||||||
|  |         print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.") | ||||||
|  |     except Exception: | ||||||
|  |         pass | ||||||
|  |     return model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||||||
|  |     """ | ||||||
|  |     Retrieves a batch of data from a PatientEventDataset and prepares it for model training. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         dataset (PatientEventDataset): The dataset to retrieve data from. | ||||||
|  |         batch_slice (slice): The slice defining the batch of patients to retrieve. | ||||||
|  |         ignore_tokens (list, optional): A list of token IDs to be ignored in the target events. | ||||||
|  |                                         These tokens will be replaced with -100. Defaults to None. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         A tuple containing four tensors: | ||||||
|  |         - input_events: (batch_size, sequence_length - 1) | ||||||
|  |         - input_tims: (batch_size, sequence_length - 1) | ||||||
|  |         - target_events: (batch_size, sequence_length - 1) | ||||||
|  |         - target_times: (batch_size, sequence_length - 1) | ||||||
|  |     """ | ||||||
|  |     batch_data = dataset[batch_slice] | ||||||
|  |      | ||||||
|  |     input_events = [item[0][:-1] for item in batch_data] | ||||||
|  |     input_tims = [item[1][:-1] for item in batch_data] | ||||||
|  |     target_events = [item[0][1:] for item in batch_data] | ||||||
|  |     target_times = [item[1][1:] for item in batch_data] | ||||||
|  |  | ||||||
|  |     input_events = torch.stack(input_events) | ||||||
|  |     input_tims = torch.stack(input_tims) | ||||||
|  |     target_events = torch.stack(target_events) | ||||||
|  |     target_times = torch.stack(target_times) | ||||||
|  |  | ||||||
|  |     return input_events, input_tims, target_events, target_times | ||||||
		Reference in New Issue
	
	Block a user