Jiarui Li 589d4d0bd2 feat: Implement time-aware GPT-2 for patient event prediction
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data.

Key components include:

- `models.py`:
  - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings.
  - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing.
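
A minimal, dependency-free sketch of the two temporal ideas named above, not the code in `models.py`. The function names, the `max_period` constant, and the exact masking rule (a position attends to itself and to events at a strictly earlier age) are assumptions made for illustration; the repository's PyTorch implementation may differ.

```python
import math


def age_sinusoidal_encoding(ages, d_model, max_period=10000.0):
    """Encode each age as a d_model-dimensional vector of sin/cos pairs.

    Assumption: same frequency schedule as the classic Transformer positional
    encoding, but driven by patient age instead of token index.
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    half = d_model // 2
    # Geometrically decreasing frequencies: 1, ..., approaching 1 / max_period.
    freqs = [max_period ** (-2.0 * i / d_model) for i in range(half)]
    encoded = []
    for age in ages:
        vec = []
        for f in freqs:
            vec.append(math.sin(age * f))
            vec.append(math.cos(age * f))
        encoded.append(vec)
    return encoded


def time_causal_mask(ages):
    """Return mask[i][j] == True when position i may attend to position j.

    Assumption: attention is allowed to the position itself and to any event
    recorded at a strictly earlier age, so events sharing the same age cannot
    see each other.
    """
    n = len(ages)
    return [[j == i or ages[j] < ages[i] for j in range(n)] for i in range(n)]


# Usage: four events, two of them recorded at the same age.
ages = [10, 20, 20, 35]
encoding = age_sinusoidal_encoding(ages, d_model=4)
mask = time_causal_mask(ages)
```

Defining the mask on age rather than on sequence index means that events recorded at the same time do not leak information to one another, which an index-based causal mask cannot guarantee.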

- `utils.py`:
  - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation.
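
The preprocessing described above could look roughly like the following plain-Python sketch. The token ids `PAD_ID` and `NO_EVENT_ID`, the `max_gap` threshold, and the choice to keep the most recent events when truncating are all hypothetical; the commit message does not state how `PatientEventDataset` handles these details.

```python
PAD_ID = 0       # hypothetical padding token id
NO_EVENT_ID = 1  # hypothetical "no event" token id


def impute_no_event(events, max_gap):
    """Fill long gaps between (age, token) events with NO_EVENT tokens.

    After imputation, no two consecutive entries are more than max_gap apart.
    """
    if max_gap <= 0:
        raise ValueError("max_gap must be positive")
    if not events:
        return []
    filled = [events[0]]
    for age, token in events[1:]:
        prev_age = filled[-1][0]
        while age - prev_age > max_gap:
            prev_age += max_gap
            filled.append((prev_age, NO_EVENT_ID))
        filled.append((age, token))
    return filled


def pad_or_truncate(events, max_len):
    """Return exactly max_len entries: keep the latest events, right-pad the rest."""
    if max_len <= 0:
        raise ValueError("max_len must be positive")
    kept = list(events[-max_len:])  # assumption: the most recent events matter most
    padding = [(0, PAD_ID)] * (max_len - len(kept))
    return kept + padding


# Usage: ages in arbitrary time units, tokens are illustrative event codes.
raw = [(100, 5), (130, 7), (400, 9)]
filled = impute_no_event(raw, max_gap=100)
fixed = pad_or_truncate(filled, max_len=8)
```

Padding to a fixed length lets sequences of different lengths be stacked into one batch; the padded positions would normally also be excluded from attention and from the loss.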

- `train.py`:
  - A comprehensive training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss.
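
As a rough illustration of the schedule and stopping rule mentioned above, the sketch below writes both in plain Python. The closed-form cosine schedule is the standard one; the `EarlyStopping` class, its `patience` parameter, and the sample numbers are assumptions, not taken from `train.py`.

```python
import math


def cosine_annealing_lr(step, total_steps, base_lr, min_lr=0.0):
    """Learning rate that decays from base_lr to min_lr along a half cosine."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


class EarlyStopping:
    """Signal a stop once validation loss fails to improve `patience` times in a row."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage with made-up validation losses standing in for a real validation pass.
stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.85, 0.7]):
    lr = cosine_annealing_lr(epoch, total_steps=5, base_lr=1e-3)
    if stopper.step(val_loss):
        break
```

In a PyTorch script the schedule would more likely come from `torch.optim.lr_scheduler.CosineAnnealingLR`, which follows the same formula with `T_max` in place of `total_steps` and `eta_min` in place of `min_lr`.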

- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.

- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
2025-10-16 14:21:36 +08:00
2025-10-15 13:54:52 +08:00

DeepHealth

Readme BSD-3-Clause 145 MiB
Languages
Python 60.9%
Jupyter Notebook 39.1%