This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data. Key components include:

- `models.py`:
  - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings.
  - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing.
- `utils.py`:
  - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation.
- `train.py`:
  - A comprehensive training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss.
- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.
- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
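
As a rough illustration of the sinusoidal age encoding described above, here is a minimal, hypothetical sketch of such a module; the actual `AgeSinusoidalEncoding` in `models.py` may use a different signature, frequency scaling, or padding handling.

```python
import math

import torch
import torch.nn as nn


class AgeSinusoidalEncoding(nn.Module):
    """Sinusoidal embedding of a continuous age value (illustrative sketch only)."""

    def __init__(self, embed_dim: int, max_period: float = 10000.0):
        super().__init__()
        if embed_dim % 2 != 0:
            raise ValueError("embed_dim must be even")
        self.embed_dim = embed_dim
        self.max_period = max_period

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        # ages: (batch, seq_len) continuous ages, e.g. in days.
        half = self.embed_dim // 2
        # Geometrically spaced frequencies, as in transformer positional encodings.
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(half, dtype=torch.float32, device=ages.device)
            / half
        )
        angles = ages.unsqueeze(-1).float() * freqs  # (batch, seq_len, half)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```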
# DeepHealth Project
This is a deep learning project built on PyTorch. It adheres to specific code style and file structure conventions to ensure clarity, maintainability, and reproducibility.
## 1. Project Structure
To maintain a clean and modular project, we adopt the following file organization:
```
DeepHealth/
|- train.py
|- models.py
|- utils.py
|- data/
|- requirements.txt
|- README.md
```
### File Descriptions
* **`train.py`**:
    * **Core training script**. It contains the control flow for the entire training process (a minimal skeleton is sketched after this file list).
    * Responsible for initializing the model, optimizer, DataLoader, etc.
    * Executes the training and validation loops.
    * Handles saving and loading checkpoints, logging, and other related tasks.
* **`models.py`**:
    * **Model and Loss Function Definitions**. This file stores the architecture for all neural network models.
    * All subclasses of `torch.nn.Module` should be defined in this file.
    * Custom loss functions should also be implemented here (see the loss sketch after this file list).
* **`utils.py`**:
    * **Utility Functions Module**. It contains reusable helper functions for the project.
    * Primarily responsible for data I/O, data preprocessing, performance metric calculation, logger configuration, and any other logic that doesn't belong in the core model or training code (a small Dataset example follows the file list).
* **`data/`**:
    * **Data Storage Directory**. Used to store the datasets required for the project.
    * `data/raw/` stores the original, unprocessed data.
    * `data/processed/` stores data after it has been preprocessed.
* **`requirements.txt`**:
    * **Project Dependencies**. Lists all the Python packages and their versions required to run this project.
* **`README.md`**:
    * **Project Documentation**. Provides a high-level overview of the project, setup instructions, and usage guidelines.
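
To make the responsibilities of `train.py` listed above concrete, here is a minimal, self-contained training-loop skeleton with a cosine annealing scheduler, a validation pass, checkpointing, and early stopping. A toy linear model and random tensors stand in for the project's real model, loss, and datasets; in the actual script these would come from `models.py` and `utils.py`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for the real model, loss, and data (assumptions for this sketch).
model = nn.Linear(16, 4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

train_loader = DataLoader(
    TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,))),
    batch_size=32, shuffle=True)
val_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,))),
    batch_size=32)

best_val, patience, bad_epochs = float("inf"), 3, 0
for epoch in range(10):
    # Training loop.
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    scheduler.step()

    # Validation loop.
    model.eval()
    with torch.no_grad():
        val_loss = sum(criterion(model(x), y).item() for x, y in val_loader) / len(val_loader)

    # Checkpoint the best model and stop early if validation stops improving.
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break
```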
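
As an example of a custom loss that would live in `models.py`, here is a hedged sketch in the spirit of the two-part `CombinedLoss` mentioned in the commit message: cross-entropy over the next event plus a time-related term. The real survival loss is more involved; a simple L1 penalty on the predicted event time stands in for it here, and the class and argument names are assumptions.

```python
import torch
import torch.nn as nn


class TwoPartEventLoss(nn.Module):
    """Cross-entropy for event type plus a weighted penalty on event timing (sketch)."""

    def __init__(self, time_weight: float = 1.0):
        super().__init__()
        self.time_weight = time_weight
        self.event_loss = nn.CrossEntropyLoss()
        self.time_loss = nn.L1Loss()  # placeholder for the real survival loss

    def forward(self, event_logits, time_pred, event_target, time_target):
        # event_logits: (batch, num_event_types); event_target: (batch,) class indices.
        # time_pred / time_target: (batch,) waiting times until the next event.
        return (self.event_loss(event_logits, event_target)
                + self.time_weight * self.time_loss(time_pred, time_target))
```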
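
And as an example of the kind of data helper that belongs in `utils.py`, here is a toy `Dataset` that truncates or pads variable-length event sequences to a fixed length, loosely in the spirit of the `PatientEventDataset` described in the commit message (whose real interface may differ):

```python
import torch
from torch.utils.data import Dataset


class ToyEventDataset(Dataset):
    """Pads or truncates integer event sequences to a fixed length (sketch)."""

    def __init__(self, sequences, max_len: int = 8, pad_id: int = 0):
        self.sequences = sequences  # list of lists of event codes
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = list(self.sequences[idx])[: self.max_len]          # truncate
        seq = seq + [self.pad_id] * (self.max_len - len(seq))    # pad
        return torch.tensor(seq, dtype=torch.long)


# Usage: ToyEventDataset([[3, 1, 4], [1, 5, 9, 2, 6, 5, 3, 5, 9]])[0]
```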
## 2. Core Framework
* **Deep Learning Framework**: `PyTorch`
## 3. Coding Style
This project uniformly adopts the **Google Python Style Guide**. All submitted code should adhere to this standard to ensure consistency and readability.
Key features include:
* Using `yapf` or `black` for automatic code formatting.
* Following detailed naming conventions (`module_name`, `package_name`, `ClassName`, `method_name`, `ExceptionName`, `function_name`, `GLOBAL_CONSTANT_NAME`).
* Using Google-style docstrings (see the example at the end of this section).
Please refer to the official documentation: [Google Python Style Guide](http://google.github.io/styleguide/pyguide.html)
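
As a concrete illustration of the docstring convention, here is a hypothetical helper (not part of the codebase) written in the Google style:

```python
def days_to_years(ages_in_days: list[float], days_per_year: float = 365.25) -> list[float]:
    """Converts patient ages from days to years.

    Args:
        ages_in_days: Patient ages expressed in days.
        days_per_year: Number of days per year used for the conversion.

    Returns:
        The ages expressed in years.

    Raises:
        ValueError: If `days_per_year` is not positive.
    """
    if days_per_year <= 0:
        raise ValueError("days_per_year must be positive")
    return [age / days_per_year for age in ages_in_days]
```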