feat: Implement time-aware GPT-2 for patient event prediction
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data. Key components include: - `models.py`: - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information. - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings. - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing. - `utils.py`: - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation. - `train.py`: - A comprehensive training script that initializes the model, data loaders, and loss function. - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss. - `prepare_data.py`: - Script for preprocessing raw UK Biobank data into a format suitable for the model. - `GEMINI.md`: - Project documentation outlining the structure, coding style, and framework.
This commit is contained in:
59
GEMINI.md
Normal file
59
GEMINI.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# DeepHealth Project
|
||||
|
||||
This is a deep learning project based on PyTorch. This project adheres to specific code style and file structure conventions to ensure clarity, maintainability, and reproducibility.
|
||||
|
||||
## 1. Project Structure
|
||||
|
||||
To maintain a clean and modular project, we adopt the following file organization:
|
||||
|
||||
DeepHealth/
|
||||
|-tain.py
|
||||
|-models.py
|
||||
|-utils.py
|
||||
|-data/
|
||||
|-requirements.txt
|
||||
|-README.md
|
||||
|
||||
|
||||
### File Descriptions
|
||||
|
||||
* **`train.py`**:
|
||||
* **Core training script**. It contains the control flow for the entire training process.
|
||||
* Responsible for initializing the model, optimizer, DataLoader, etc.
|
||||
* Executes the training and validation loops.
|
||||
* Handles saving and loading checkpoints, logging, and other related tasks.
|
||||
|
||||
* **`models.py`**:
|
||||
* **Model and Loss Function Definitions**. This file stores the architecture for all neural network models.
|
||||
* All subclasses of `torch.nn.Module` should be defined in this file.
|
||||
* Custom loss functions should also be implemented here.
|
||||
|
||||
* **`utils.py`**:
|
||||
* **Utility Functions Module**. It contains reusable helper functions for the project.
|
||||
* Primarily responsible for data I/O operations, data preprocessing, performance metric calculations, logger configuration, or other logic that doesn't belong in the core model or training framework.
|
||||
|
||||
* **`data/`**:
|
||||
* **Data Storage Directory**. Used to store the datasets required for the project.
|
||||
* `data/raw/` stores the original, unprocessed data.
|
||||
* `data/processed/` stores data after it has been preprocessed.
|
||||
|
||||
* **`requirements.txt`**:
|
||||
* **Project Dependencies**. Lists all the Python packages and their versions required to run this project.
|
||||
|
||||
* **`README.md`**:
|
||||
* **Project Documentation**. Provides a high-level overview of the project, setup instructions, and usage guidelines.
|
||||
|
||||
## 2. Core Framework
|
||||
|
||||
* **Deep Learning Framework**: `PyTorch`
|
||||
|
||||
## 3. Coding Style
|
||||
|
||||
This project uniformly adopts the **Google Python Style Guide**. All submitted code should adhere to this standard to ensure consistency and readability.
|
||||
|
||||
Key features include:
|
||||
* Using `yapf` or `black` for automatic code formatting.
|
||||
* Following detailed naming conventions (`module_name`, `package_name`, `ClassName`, `method_name`, `ExceptionName`, `function_name`, `GLOBAL_CONSTANT_NAME`).
|
||||
* Using Google-style docstrings.
|
||||
|
||||
Please refer to the official documentation: [Google Python Style Guide](http://google.github.io/styleguide/pyguide.html)
|
Reference in New Issue
Block a user