This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data. Key components include:

- `models.py`:
  - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings.
  - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing.
- `utils.py`:
  - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation.
- `train.py`:
  - A comprehensive training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss.
- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.
- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
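The age-based positional encoding described above can be sketched as follows. This is an illustrative reconstruction, not the committed implementation: the embedding dimension, frequency base, and tensor shapes are assumptions.

```python
import torch


class AgeSinusoidalEncoding(torch.nn.Module):
    """Sinusoidal positional encoding driven by patient age, not token index.

    Sketch of the idea only; dimension and frequency base are assumptions.
    """

    def __init__(self, d_model: int, base: float = 10000.0):
        super().__init__()
        # Standard transformer-style geometric frequency schedule.
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        # ages: (batch, seq_len), e.g. in days; returns (batch, seq_len, d_model).
        angles = ages.unsqueeze(-1) * self.inv_freq  # broadcast over frequencies
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```

Because the encoding depends on age rather than position, two events at the same age receive the same embedding regardless of where they fall in the sequence.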
DeepHealth Project
This is a deep learning project based on PyTorch. It adheres to specific code style and file structure conventions to ensure clarity, maintainability, and reproducibility.
1. Project Structure
To maintain a clean and modular project, we adopt the following file organization:
```
DeepHealth/
|- train.py
|- models.py
|- utils.py
|- data/
|- requirements.txt
|- README.md
```
File Descriptions
- `train.py`: Core training script. It contains the control flow for the entire training process.
  - Responsible for initializing the model, optimizer, DataLoader, etc.
  - Executes the training and validation loops.
  - Handles saving and loading checkpoints, logging, and other related tasks.
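These responsibilities can be sketched as a minimal `train.py` skeleton. This is an illustrative outline under assumed names: the scheduler choice, patience value, and checkpoint path are examples, not the project's actual configuration.

```python
import torch


def train(model, train_loader, val_loader, epochs=10, patience=3, lr=1e-3):
    """Minimal training loop with cosine annealing LR and early stopping (sketch)."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    best_val, stale = float("inf"), 0
    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(x), y)
            loss.backward()
            optimizer.step()
        scheduler.step()
        # Validation pass.
        model.eval()
        with torch.no_grad():
            val = sum(torch.nn.functional.cross_entropy(model(x), y).item()
                      for x, y in val_loader) / len(val_loader)
        if val < best_val:
            best_val, stale = val, 0
            torch.save(model.state_dict(), "checkpoint.pt")  # hypothetical path
        else:
            stale += 1
            if stale >= patience:  # early stopping on stalled validation loss
                break
    return best_val
```

Keeping the loop in one function makes checkpointing and early stopping easy to test in isolation from the model definition.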
- `models.py`: Model and loss function definitions. This file stores the architecture for all neural network models.
  - All subclasses of `torch.nn.Module` should be defined in this file.
  - Custom loss functions should also be implemented here.
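A custom loss implemented as an `nn.Module` subclass might look like the sketch below, loosely in the spirit of a two-part event/timing loss. The exponential-survival term and the `alpha` weighting are illustrative assumptions, not the project's actual formulation.

```python
import torch
import torch.nn as nn


class CombinedLoss(nn.Module):
    """Sketch: cross-entropy on the event type plus an exponential-survival
    negative log-likelihood on the waiting time. Parameterization is assumed."""

    def __init__(self, alpha: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, event_logits, log_rate, targets, delta_t):
        # event_logits: (N, num_events); targets: (N,) class indices.
        # log_rate: (N,) predicted log hazard; delta_t: (N,) waiting times.
        ce_loss = self.ce(event_logits, targets)
        # Exponential NLL: -log(lambda) + lambda * dt
        surv_loss = (-log_rate + torch.exp(log_rate) * delta_t).mean()
        return ce_loss + self.alpha * surv_loss
```

Defining losses as modules (rather than bare functions) lets them carry configurable state such as `alpha` and compose with the rest of the model code.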
- `utils.py`: Utility functions module. It contains reusable helper functions for the project.
  - Primarily responsible for data I/O operations, data preprocessing, performance metric calculations, logger configuration, and other logic that doesn't belong in the core model or training framework.
- `data/`: Data storage directory. Used to store the datasets required for the project.
  - `data/raw/` stores the original, unprocessed data.
  - `data/processed/` stores data after it has been preprocessed.
- `requirements.txt`: Project dependencies. Lists all the Python packages and their versions required to run this project.
- `README.md`: Project documentation. Provides a high-level overview of the project, setup instructions, and usage guidelines.
2. Core Framework
- Deep Learning Framework: PyTorch
3. Coding Style
This project uniformly adopts the Google Python Style Guide. All submitted code should adhere to this standard to ensure consistency and readability.
Key features include:
- Using `yapf` or `black` for automatic code formatting.
- Following detailed naming conventions (`module_name`, `package_name`, `ClassName`, `method_name`, `ExceptionName`, `function_name`, `GLOBAL_CONSTANT_NAME`).
- Using Google-style docstrings.
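A Google-style docstring combines the naming conventions above with `Args`/`Returns`/`Raises` sections. The function below is a hypothetical example for illustration, not part of the codebase:

```python
def survival_rate(events: int, at_risk: int) -> float:
    """Computes the fraction of at-risk patients without an event.

    Args:
        events: Number of observed events in the interval.
        at_risk: Number of patients at risk at the start of the interval.

    Returns:
        The conditional survival probability for the interval.

    Raises:
        ValueError: If at_risk is not positive.
    """
    if at_risk <= 0:
        raise ValueError("at_risk must be positive")
    return 1.0 - events / at_risk
```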
Please refer to the official documentation: Google Python Style Guide