Files
DeepHealth/prepare_data.py
Jiarui Li 589d4d0bd2 feat: Implement time-aware GPT-2 for patient event prediction
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data.

Key components include:

- `models.py`:
  - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings.
  - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing.

- `utils.py`:
  - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation.

- `train.py`:
  - A comprehensive training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss.

- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.

- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
2025-10-16 14:21:36 +08:00

134 lines
5.4 KiB
Python

import pandas as pd
import tqdm
import numpy as np
# --- Configuration ----------------------------------------------------------
label_files = 'labels.csv'
ukb_field_to_icd10_file = 'icd10_codes_mod.tsv'
ukb_basket_file = 'ukb_delphi.txt'
train_proportion = 0.8
output_prefix = 'ukb_real'

# --- Lookup tables -----------------------------------------------------------
# icdict maps a raw UK Biobank column name (e.g. 'f.41270.0.0') to a readable
# event code; icdcodes collects the ICD-10 codes in file order.
icdict = {}
icdcodes = []
with open(ukb_field_to_icd10_file) as f:
    for raw_line in f:
        fields = raw_line.strip().split()
        # field 0: UKB column id, field 5: ICD-10 code (whitespace-separated row)
        icdict[fields[0]] = fields[5]
        icdcodes.append(fields[5])

# Map each event code to its integer label = row index in labels.csv.
label_dict = {}
with open(label_files) as f:
    for idx, raw_line in enumerate(f):
        label_dict[raw_line.strip().split(' ')[0]] = idx

# Non-ICD fields of interest: demographics, death, cancer registry
# (17 occurrence slots), assessment date and lifestyle measurements.
icdict['f.31.0.0'] = "sex"
icdict['f.34.0.0'] = "YEAR"
icdict['f.52.0.0'] = "MONTH"
icdict['f.40000.0.0'] = "Death"
for j in range(17):
    icdict[f'f.40005.{j}.0'] = f'cancer_date_{j}'
    icdict[f'f.40006.{j}.0'] = f'cancer_type_{j}'
icdict['f.53.0.0'] = "assessment_date"
icdict['f.21001.0.0'] = "BMI"
icdict['f.1239.0.0'] = "smoking"
icdict['f.1558.0.0'] = "alcohol"

# Remember where the pure ICD codes end: the date-like columns appended below
# get converted to day offsets too, but are not event labels themselves.
len_icd = len(icdcodes)
icdcodes.extend(['Death', 'assessment_date'] + [f'cancer_date_{j}' for j in range(17)])
data_list = []
# Stream the basket file in chunks; the full table is too large for memory.
ukb_iterator = pd.read_csv(ukb_basket_file, sep=',', chunksize=10000, index_col=0, low_memory=False)
for _, chunk in tqdm.tqdm(enumerate(ukb_iterator)):
    chunk = chunk.rename(columns=icdict)
    # Participants without a recorded sex are dropped entirely.
    chunk.dropna(subset=['sex'], inplace=True)
    chunk['sex'] += 1  # shift the sex code up by one so token values start at 1
    # Keep only the columns we mapped to readable names above.
    chunk = chunk[[col for col in chunk.columns if not col.startswith('f.')]]
    # Date of birth: UKB provides year and month only, so the day is fixed to 1.
    chunk['dob'] = pd.to_datetime(chunk[['YEAR', 'MONTH']].assign(DAY=1))

    # Convert every known date-bearing column to integer days since birth.
    present_cols = [c for c in icdcodes if c in chunk.columns]
    if present_cols:
        parsed = chunk[present_cols].apply(pd.to_datetime, format="%Y-%m-%d", errors='coerce')
        offsets = parsed.sub(chunk['dob'], axis=0)
        chunk[present_cols] = offsets.apply(lambda s: s.dt.days)

    # ICD-10 diagnoses plus Death: melt the wide columns into (eid, days, label).
    event_cols = [col for col in icdcodes[:len_icd + 1] if col in chunk.columns]
    if event_cols:
        events = chunk.reset_index().melt(
            id_vars=['f.eid'],
            value_vars=event_cols,
            var_name='event_code',
            value_name='days'
        )
        events.dropna(subset=['days'], inplace=True)
        if not events.empty:
            events['label'] = events['event_code'].map(label_dict)
            data_list.append(events[['f.eid', 'days', 'label']].dropna().astype(int).to_numpy())

    # Sex is emitted as an event at day 0, with the (shifted) code as label.
    sex_arr = chunk['sex'].reset_index().to_numpy().astype(int)
    data_list.append(np.c_[sex_arr[:, 0], np.zeros(sex_arr.shape[0]), sex_arr[:, 1]].astype(int))

    # Cancer registry: pair cancer_date_j with cancer_type_j using wide_to_long,
    # which needs digit-suffixed stub names (hence the underscore-free rename).
    flat = chunk.reset_index()
    stub_map = {f'cancer_date_{j}': f'cancerdate{j}' for j in range(17)}
    stub_map.update({f'cancer_type_{j}': f'cancertype{j}' for j in range(17)})
    renamed = flat.rename(columns=stub_map)
    stubs = []
    if any('cancerdate' in col for col in renamed.columns):
        stubs.append('cancerdate')
    if any('cancertype' in col for col in renamed.columns):
        stubs.append('cancertype')
    if len(stubs) == 2:
        cancer_long = pd.wide_to_long(renamed,
                                      stubnames=stubs,
                                      i=['f.eid'],
                                      j='cancer_num'
                                      ).dropna()
        if not cancer_long.empty:
            # Label is looked up from the 3-character ICD-10 prefix of the type.
            cancer_long['cancer'] = cancer_long['cancertype'].str.slice(0, 3)
            cancer_long['cancer_label'] = cancer_long['cancer'].map(label_dict)
            cancer_rows = cancer_long.reset_index()[['f.eid', 'cancerdate', 'cancer_label']].dropna().astype(int).to_numpy()
            if cancer_rows.size > 0:
                data_list.append(cancer_rows)

    # Lifestyle tokens, all stamped with the assessment date:
    # BMI buckets -> 3/4/5, smoking -> 6/7/8, alcohol -> 9/10/11.
    bmi_rows = chunk[['assessment_date', 'BMI']].dropna().reset_index()
    if not bmi_rows.empty:
        bmi_rows['bmi_status'] = np.select([bmi_rows['BMI'] > 28, bmi_rows['BMI'] > 22], [5, 4], default=3)
        data_list.append(bmi_rows[['f.eid', 'assessment_date', 'bmi_status']].astype(int).to_numpy())

    smoke_rows = chunk[['assessment_date', 'smoking']].dropna().reset_index()
    # -3 answers are excluded (presumably UKB "prefer not to answer" — confirm).
    smoke_rows = smoke_rows[smoke_rows['smoking'] != -3]
    if not smoke_rows.empty:
        smoke_rows['smoking_status'] = np.select([smoke_rows['smoking'] == 1, smoke_rows['smoking'] == 2], [8, 7], default=6)
        data_list.append(smoke_rows[['f.eid', 'assessment_date', 'smoking_status']].astype(int).to_numpy())

    alc_rows = chunk[['assessment_date', 'alcohol']].dropna().reset_index()
    alc_rows = alc_rows[alc_rows['alcohol'] != -3]
    if not alc_rows.empty:
        alc_rows['alcohol_status'] = np.select([alc_rows['alcohol'] == 1, alc_rows['alcohol'] < 4], [11, 10], default=9)
        data_list.append(alc_rows[['f.eid', 'assessment_date', 'alcohol_status']].astype(int).to_numpy())
# Stack every per-chunk array into one (eid, days, label) event table.
data = np.vstack(data_list)
# np.lexsort uses its LAST key as the primary one: sort by patient id, then
# push rows whose label equals the global maximum to the end of each patient's
# record (presumably the Death token — confirm), then order by day.
data = data[np.lexsort((data[:, 1], data[:, 2] == data[:, 2].max(), data[:, 0]))]
# Drop events dated before birth (negative day offsets).
data = data[data[:, 1] >= 0]
# Keep only the first occurrence of each (patient, label) pair.
data = pd.DataFrame(data).drop_duplicates([0, 2]).values
data = data.astype(np.uint32)
data.tofile(output_prefix + '.bin')

# Split on participant id so no patient appears in both train and validation.
unique_ids = np.unique(data[:, 0])
split_id = unique_ids[int(len(unique_ids) * train_proportion)]
is_train = data[:, 0] <= split_id
data[is_train].tofile(output_prefix + '_train.bin')
data[~is_train].tofile(output_prefix + '_val.bin')