589d4d0bd265300056a9bd20d3a1c2e9bd2ffe6e
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data. Key components include:

- `models.py`:
  - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings.
  - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing.
- `utils.py`:
  - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation.
- `train.py`:
  - A comprehensive training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss.
- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.
- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
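The sinusoidal age encoding can be pictured as the standard Transformer sin/cos position encoding with a real-valued age substituted for the integer token position. The sketch below assumes interleaved sin/cos pairs, a base of 10000, and ages in whatever unit the data uses (days or years); `age_sinusoidal_encoding` is a hypothetical helper, not the repository's `AgeSinusoidalEncoding` module.

```python
import math
from typing import List


def age_sinusoidal_encoding(
    ages: List[float], d_model: int, base: float = 10000.0
) -> List[List[float]]:
    """Map each continuous age to a d_model-dimensional sin/cos vector.

    Hypothetical sketch: uses the usual Transformer frequency schedule,
    but feeds in the (real-valued) age instead of an integer position.
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    encodings = []
    for age in ages:
        row = []
        for i in range(d_model // 2):
            # Frequencies shrink geometrically across dimension pairs,
            # so low dimensions vary quickly and high ones slowly with age.
            freq = 1.0 / (base ** (2 * i / d_model))
            row.append(math.sin(age * freq))
            row.append(math.cos(age * freq))
        encodings.append(row)
    return encodings
```

Because the input is a continuous age rather than a slot index, two events recorded far apart in time receive clearly different encodings even if they are adjacent in the sequence.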
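A time-based causal mask restricts attention by timestamp as well as by sequence order. The rule used below is an assumption chosen for illustration: a token may attend to itself and to any earlier token whose timestamp is strictly earlier, so events recorded at the same time cannot see one another. `time_causal_mask` is a hypothetical name.

```python
from typing import List


def time_causal_mask(times: List[float]) -> List[List[bool]]:
    """Return mask[i][j] == True if query position i may attend to key j.

    Assumed rule (hypothetical): self-attention is always allowed, and an
    earlier position j is visible only if its timestamp is strictly
    earlier than that of position i.
    """
    n = len(times)
    return [
        [j == i or (j < i and times[j] < times[i]) for j in range(n)]
        for i in range(n)
    ]
```

In a batched model, the same comparison would typically be done by broadcasting over tensors rather than with Python loops, and disallowed positions are usually set to `-inf` before the softmax.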
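A two-part loss of this kind can be sketched as cross-entropy on the next event plus a negative log-likelihood for the waiting time until that event. The survival term here assumes an exponential waiting-time model whose rate is the sum of the exponentiated logits, which is one common choice; the repository's `CombinedLoss` may use a different survival formulation. `combined_loss` and `time_weight` are hypothetical names.

```python
import math
from typing import Sequence, Tuple


def combined_loss(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    dts: Sequence[float],
    time_weight: float = 1.0,
) -> Tuple[float, float, float]:
    """Mean cross-entropy plus exponential time-to-event NLL over a sequence.

    Hypothetical sketch: the total event rate at each step is
    lambda = sum(exp(logits)), and the observed gap dt to the next event is
    scored under an exponential distribution with that rate.
    """
    ce_total = 0.0
    time_total = 0.0
    for step_logits, target, dt in zip(logits, targets, dts):
        # Numerically stable log-sum-exp, i.e. log(lambda).
        m = max(step_logits)
        lse = m + math.log(sum(math.exp(x - m) for x in step_logits))
        # Cross-entropy: -log softmax(logits)[target].
        ce_total += lse - step_logits[target]
        # Exponential NLL: -log(lambda * exp(-lambda * dt)) = lambda * dt - log(lambda).
        rate = math.exp(lse)
        time_total += rate * dt - lse
    n = len(targets)
    ce, nll_time = ce_total / n, time_total / n
    return ce + time_weight * nll_time, ce, nll_time
```

The `time_weight` factor lets the event-type and event-timing terms be balanced; its value here is purely illustrative.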
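The "no event" imputation and padding/truncation steps of `PatientEventDataset` can be illustrated on a single patient. The sketch assumes events arrive as `(time, token)` pairs, that a hypothetical `NO_EVENT_TOKEN` is inserted at regular intervals inside long gaps, that truncation keeps the most recent entries, and that padded slots use hypothetical `PAD_TOKEN`/`PAD_TIME` values; none of these constants come from the repository.

```python
from typing import List, Tuple

PAD_TOKEN = 0        # hypothetical padding token id
NO_EVENT_TOKEN = 1   # hypothetical "no event" token id
PAD_TIME = -1.0      # hypothetical sentinel time for padded slots


def build_sequence(
    events: List[Tuple[float, int]],
    max_len: int,
    gap: float = 5 * 365.25,
) -> Tuple[List[int], List[float]]:
    """Impute "no event" tokens into long gaps, then truncate/pad to max_len."""
    events = sorted(events)  # order by time
    tokens: List[int] = []
    times: List[float] = []
    for k, (t, tok) in enumerate(events):
        if k > 0:
            # Fill the gap since the previous event with evenly spaced
            # "no event" tokens, stopping before the next real event.
            filler = events[k - 1][0] + gap
            while filler < t:
                tokens.append(NO_EVENT_TOKEN)
                times.append(filler)
                filler += gap
        tokens.append(tok)
        times.append(t)
    # Keep the most recent max_len entries (truncation policy is an assumption).
    tokens, times = tokens[-max_len:], times[-max_len:]
    # Right-pad to a fixed length so sequences can be stacked into a batch.
    n_pad = max_len - len(tokens)
    return tokens + [PAD_TOKEN] * n_pad, times + [PAD_TIME] * n_pad
```

For example, two events 4000 days apart with a 1825-day gap threshold gain two "no event" tokens between them before padding is applied.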
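The training-loop controls in `train.py` combine a cosine annealing schedule with early stopping on validation loss. The sketch below shows the closed-form cosine schedule and a patience-based stopper; `cosine_annealing_lr`, `EarlyStopping`, `patience`, and `min_delta` are hypothetical names, and the repository's actual scheduler settings are not known from the description.

```python
import math


def cosine_annealing_lr(
    step: int, total_steps: int, lr_max: float, lr_min: float = 0.0
) -> float:
    """Learning rate decaying from lr_max at step 0 to lr_min at total_steps."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


class EarlyStopping:
    """Signal a stop after `patience` epochs without validation-loss improvement."""

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        if val_loss < self.best - self.min_delta:
            # New best validation loss: remember it and reset the counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

PyTorch provides the same schedule as `torch.optim.lr_scheduler.CosineAnnealingLR`; the explicit formula is shown here only to make the decay shape visible.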
DeepHealth
Languages: Python 60.9%, Jupyter Notebook 39.1%