Add `train_multigpu.py` for distributed data parallel training (a minimal setup sketch follows below).
Update `train.py` to save the training configuration to a JSON file.
Generalize `.gitignore` to exclude all `*.pt` checkpoint files.
Delete the obsolete `train_dpp.py` file.
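A minimal sketch of the distributed setup `train_multigpu.py` presumably follows, using a stand-in model and dataset; the real script's structure, argument handling, and model wiring may differ.

```python
# Launch with: torchrun --nproc_per_node=<N> train_multigpu.py
# The Linear model and random dataset below are stand-ins for the real model and data.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main():
    use_cuda = torch.cuda.is_available()
    dist.init_process_group(backend="nccl" if use_cuda else "gloo")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}" if use_cuda else "cpu")
    if use_cuda:
        torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(16, 4).to(device)                # stand-in for the real model
    model = DDP(model, device_ids=[local_rank] if use_cuda else None)
    dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 4, (1024,)))
    sampler = DistributedSampler(dataset)                     # shards the data across ranks
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for epoch in range(2):
        sampler.set_epoch(epoch)                              # reshuffle the shards each epoch
        for x, y in loader:
            loss = torch.nn.functional.cross_entropy(model(x.to(device)), y.to(device))
            optimizer.zero_grad()
            loss.backward()                                   # gradients are all-reduced across ranks
            optimizer.step()
        if dist.get_rank() == 0:
            print(f"epoch {epoch}: last loss {loss.item():.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```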
In `models.py`:
- Change the temporal attention mask to be strictly causal (`<` instead of `<=`).
- Allow the first token in a sequence to attend to itself, since its attention row would otherwise be fully masked and softmax over it would produce NaNs (see the mask sketch below).
In `train.py`:
- Update hyperparameters:
- `block_length`: 24 -> 48
- `n_embd`: 256 -> 120
- `n_layer`: 8 -> 12
- `n_head`: 8 -> 12
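The updated configuration, together with the JSON dump that `train.py` now writes, might look roughly like this; the key names and output path are assumptions, while the values match the changes above.

```python
import json

# Assumed layout; key names and the output path are illustrative.
config = {
    "block_length": 48,  # was 24
    "n_embd": 120,       # was 256
    "n_layer": 12,       # was 8
    "n_head": 12,        # was 8
}

with open("train_config.json", "w") as f:
    json.dump(config, f, indent=2)
```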
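For the `models.py` mask change above, a rough sketch of how the strictly causal time mask and the first-token fix could be built; the function and tensor names are assumptions, not the actual code.

```python
import torch


def temporal_causal_mask(event_times: torch.Tensor) -> torch.Tensor:
    """Boolean mask of shape (B, T, T); True marks key positions a query may attend to.

    event_times: (B, T) tensor of event times/ages for each token.
    """
    # Strictly causal in time: token i may attend to token j only if t_j < t_i.
    mask = event_times.unsqueeze(2) > event_times.unsqueeze(1)
    # The first token has no strictly earlier event, so its row would be fully
    # masked and softmax would return NaN; let it attend to itself instead.
    mask[:, 0, 0] = True
    return mask
```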
- Refactor the self-attention mechanism in `models.py` to use `nn.MultiheadAttention` for better performance and clarity (sketched below).
- Disable the early-stopping check during warmup epochs in `train.py` to improve training stability.
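One plausible shape of the refactor, assuming the block wraps `nn.MultiheadAttention` with `batch_first=True` and consumes the boolean time mask described above; the class name and interface are illustrative.

```python
import torch
import torch.nn as nn


class TemporalSelfAttention(nn.Module):
    """Self-attention block built on nn.MultiheadAttention (illustrative sketch)."""

    def __init__(self, n_embd: int, n_head: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)

    def forward(self, x: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
        # x: (B, T, n_embd); allowed: (B, T, T) boolean, True where attention is permitted.
        # nn.MultiheadAttention expects True to mean "masked out" and a per-head
        # mask of shape (B * n_head, T, T), hence the inversion and repeat.
        attn_mask = (~allowed).repeat_interleave(self.attn.num_heads, dim=0)
        out, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        return out
```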
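For the early-stopping change, a minimal sketch of the guard; the helper name and default values are assumptions.

```python
def should_stop_early(epoch: int, epochs_without_improvement: int,
                      warmup_epochs: int = 5, patience: int = 10) -> bool:
    """Early stopping is only considered once the warmup epochs have passed."""
    if epoch < warmup_epochs:
        return False  # validation loss is still noisy while the learning rate warms up
    return epochs_without_improvement >= patience
```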
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data.
Key components include:
- `models.py`:
- `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
- `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings (sketched after this list).
- `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing (see the loss sketch below).
- `utils.py`:
- `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation (see the dataset sketch below).
- `train.py`:
- A comprehensive training script that initializes the model, data loaders, and loss function.
- Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss (see the loop sketch below).
- `prepare_data.py`:
- Script for preprocessing raw UK Biobank data into a format suitable for the model.
- `GEMINI.md`:
- Project documentation outlining the structure, coding style, and framework.
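A rough sketch of what `AgeSinusoidalEncoding` could look like: the classic sinusoidal position encoding applied to a continuous age value instead of an integer position index; the frequency scaling is an assumption.

```python
import math

import torch
import torch.nn as nn


class AgeSinusoidalEncoding(nn.Module):
    """Sinusoidal embedding of a continuous age/time value (illustrative sketch)."""

    def __init__(self, n_embd: int, max_period: float = 10000.0):
        super().__init__()
        half = n_embd // 2
        # Geometrically spaced frequencies, as in the original transformer encoding.
        freqs = torch.exp(-math.log(max_period) * torch.arange(half) / half)
        self.register_buffer("freqs", freqs)

    def forward(self, age: torch.Tensor) -> torch.Tensor:
        # age: (B, T) continuous values; returns (B, T, n_embd).
        angles = age.unsqueeze(-1) * self.freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```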
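The `CombinedLoss` pairing could be shaped roughly as below; the exponential waiting-time likelihood is only one plausible survival term, and the actual formulation, weighting, and padding handling in `models.py` may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CombinedLoss(nn.Module):
    """Cross-entropy on the next event type plus a time-to-event term (sketch only)."""

    def __init__(self, time_weight: float = 1.0, pad_token: int = 0):
        super().__init__()
        self.time_weight = time_weight
        self.pad_token = pad_token

    def forward(self, event_logits, log_rate, target_events, target_dt):
        # event_logits: (B, T, V); log_rate, target_dt: (B, T); target_events: (B, T).
        ce = F.cross_entropy(
            event_logits.transpose(1, 2), target_events, ignore_index=self.pad_token
        )
        # Negative log-likelihood of the observed gap dt under an exponential
        # distribution with rate lambda = exp(log_rate): -log(lambda) + lambda * dt.
        mask = (target_events != self.pad_token).float()
        rate = log_rate.exp()
        nll_time = (-log_rate + rate * target_dt) * mask
        survival = nll_time.sum() / mask.sum().clamp(min=1.0)
        return ce + self.time_weight * survival
```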
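A skeletal view of the padding/truncation side of `PatientEventDataset`; the input layout and the way "no event" gaps are imputed are not specified in the commit, so they are only gestured at here.

```python
import torch
from torch.utils.data import Dataset


class PatientEventDataset(Dataset):
    """Pads or truncates each patient's (event, age) sequence to a fixed length (sketch)."""

    def __init__(self, sequences, block_length: int = 48, pad_token: int = 0):
        # sequences: list of (event_ids, ages) list pairs, one per patient (assumed layout).
        self.sequences = sequences
        self.block_length = block_length
        self.pad_token = pad_token

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        events, ages = self.sequences[idx]
        # (The real class would also impute "no event" tokens into long gaps here.)
        events = list(events[: self.block_length])
        ages = list(ages[: self.block_length])
        pad = self.block_length - len(events)
        events = torch.tensor(events + [self.pad_token] * pad, dtype=torch.long)
        ages = torch.tensor(ages + [0.0] * pad, dtype=torch.float)
        return events, ages
```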
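The training loop reduced to its control flow: cosine annealing via `CosineAnnealingLR`, a validation pass each epoch, and early stopping on validation loss. The model, loaders, and hyperparameter values here are stand-ins, not the actual `train.py` contents.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Stand-ins for the real model, loss, and PatientEventDataset loaders.
model = torch.nn.Linear(16, 4)
loss_fn = torch.nn.CrossEntropyLoss()
train_loader = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(10)]
val_loader = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(2)]

max_epochs, patience = 100, 10                               # illustrative values
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)

best_val, stale = float("inf"), 0
for epoch in range(max_epochs):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
    scheduler.step()

    model.eval()
    with torch.no_grad():
        val_loss = sum(loss_fn(model(x), y).item() for x, y in val_loader) / len(val_loader)

    # Early stopping on validation loss (the warmup guard described earlier would wrap this).
    if val_loss < best_val:
        best_val, stale = val_loss, 0
        torch.save(model.state_dict(), "best_model.pt")      # checkpoint name is illustrative
    else:
        stale += 1
        if stale >= patience:
            break
```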