Commit Graph

36 Commits

SHA1 Message Date
ddb7dbfc67 update 2025-10-20 16:22:15 +08:00
88cccdad2e feat: Optimize AUC evaluation with parallel processing 2025-10-20 16:16:50 +08:00
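Commit 88cccdad2e parallelizes the AUC evaluation. The repository's actual implementation isn't shown in the log; the sketch below is an assumption of the general idea — a rank-based (Mann-Whitney) AUC computed per event type, fanned out across an executor. A `ProcessPoolExecutor` would sidestep the GIL for large cohorts; threads keep the sketch simple and safe to import.

```python
from concurrent.futures import ThreadPoolExecutor

def roc_auc(labels, scores):
    """Mann-Whitney AUC: probability a positive case outranks a negative one."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return float("nan")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def evaluate_auc_parallel(per_event, workers=4):
    """per_event: dict event_id -> (labels, scores). Returns dict event_id -> AUC."""
    def _one(item):
        event_id, (labels, scores) = item
        return event_id, roc_auc(labels, scores)
    with ThreadPoolExecutor(max_workers=workers) as ex:
        return dict(ex.map(_one, per_event.items()))
```

Names such as `evaluate_auc_parallel` are illustrative, not taken from the repo.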
8f44018bae update evaluate 2025-10-20 13:47:50 +08:00
1c9e2a2fb3 feat: print model config and add evaluation notebook 2025-10-20 10:14:50 +08:00
6b782b86e1 feat: Add model checkpoints and configurations 2025-10-20 09:38:24 +08:00
9a9de170d1 delete 2025-10-18 22:35:42 +08:00
7e57e5d3b1 refactor: Update survival loss calculation in CombinedLoss 2025-10-18 15:21:10 +08:00
14865ac5b6 Refactor: Remove Jupyter Notebook cell markers 2025-10-18 13:32:26 +08:00
dbc3000192 add evaluation scripts. 2025-10-18 13:26:56 +08:00
082c719975 feat(models): Refactor generate function in TimeAwareGPT2 with competing risks sampling 2025-10-18 12:42:14 +08:00
a631ac6d59 feat: Add load_model function and update training script
Added a `load_model` function to `utils.py` to allow loading of trained models from configuration and state dictionary files.

The `train_iter.py` script was also modified, likely to incorporate or test this new functionality.
2025-10-18 11:07:59 +08:00
f7356b183c feat: Add command-line arguments to train.py 2025-10-18 10:23:12 +08:00
3390bc025e feat: Add iteration-based training scripts (single and multi-GPU) 2025-10-18 10:05:37 +08:00
a832a45c62 config: Tune hyperparameters for multi-GPU training
Increase model size (n_embd, n_layer, n_head) for the multi-GPU configuration.

Explicitly set AdamW betas to (0.9, 0.99).
2025-10-17 15:37:42 +08:00
d760c45baf feat: Add multi-GPU training and improve config/ignore
Add train_multigpu.py for distributed data parallel training.

Update train.py to save the training configuration to a JSON file.

Generalize .gitignore to exclude all *.pt checkpoint files.

Delete obsolete train_dpp.py file.
2025-10-17 14:09:34 +08:00
053f86f4da config: Add weight decay to training configuration
Adds a weight_decay parameter to the TrainConfig and applies it to the AdamW optimizer.
2025-10-17 13:47:37 +08:00
d4d25ac9c7 feat: Add covariate-aware model and piecewise encoder
Introduce PiecewiseLinearEncoder for continuous variable encoding.

Add CovariateAwareGPT2 to extend TimeAwareGPT2 with static and time-varying covariate processing.

The model combines piecewise linear and sinusoidal encodings for covariates and integrates them via concatenation before a final MLP head.

Reorganize models.py for better logical structure.
2025-10-17 12:04:50 +08:00
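Commit d4d25ac9c7 introduces `PiecewiseLinearEncoder` for continuous covariates. A minimal NumPy sketch of the standard piecewise-linear ("hat" basis) scheme that name suggests — knot placement, clipping behaviour, and output size are assumptions, not the repo's actual implementation:

```python
import numpy as np

def piecewise_linear_encode(x, knots):
    """Encode value(s) x as piecewise-linear 'hat' weights over fixed knots.

    Each value maps to a len(knots)-vector of weights that sum to 1: a value
    exactly on a knot is one-hot; between knots, weight is split linearly
    between the two neighbouring knots. Out-of-range values are clipped.
    """
    x = np.asarray(x, dtype=float)
    knots = np.asarray(knots, dtype=float)
    x = np.clip(x, knots[0], knots[-1])
    # index of the left knot of the segment containing each value
    idx = np.clip(np.searchsorted(knots, x, side="right") - 1, 0, len(knots) - 2)
    left, right = knots[idx], knots[idx + 1]
    t = (x - left) / (right - left)  # fractional position within the segment
    out = np.zeros(x.shape + (len(knots),))
    np.put_along_axis(out, idx[..., None], (1 - t)[..., None], axis=-1)
    np.put_along_axis(out, (idx + 1)[..., None], t[..., None], axis=-1)
    return out
```

This keeps the encoding differentiable in `x` and sparse (at most two non-zero weights), which is why it pairs naturally with the sinusoidal encodings mentioned in the commit.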
fe0304a96a feat: Save model with params in name and log losses 2025-10-17 10:44:17 +08:00
7e8d8d307b chore: Ignore small data files 2025-10-17 10:34:24 +08:00
fc0aef4e71 chore: Add .gitignore 2025-10-17 10:32:42 +08:00
02d84a7eca refactor: Use AdamW optimizer and increase early stopping patience 2025-10-17 10:31:12 +08:00
cb7575a229 feat: Update model and training parameters
In `models.py`:
- Change temporal attention mask to be strictly causal (`<` instead of `<=`).
- Add self-attention for the first token in a sequence to prevent NaNs.

In `train.py`:
- Update hyperparameters:
  - `block_length`: 24 -> 48
  - `n_embd`: 256 -> 120
  - `n_layer`: 8 -> 12
  - `n_head`: 8 -> 12
2025-10-16 18:50:15 +08:00
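Commit cb7575a229 makes the temporal mask strictly causal (`<` instead of `<=`) and restores self-attention for tokens that would otherwise see no keys, since a fully-masked softmax row produces NaNs. A NumPy sketch of that mask logic (the repo's code operates on PyTorch tensors; names here are illustrative):

```python
import numpy as np

def temporal_causal_mask(event_times):
    """Boolean attention mask: position i may attend to j iff t[j] < t[i].

    Strict inequality means simultaneous events do not attend to each other.
    Any row left with no visible key (e.g. the earliest events) gets
    self-attention re-enabled, otherwise softmax over an all-masked row
    would yield NaNs.
    """
    t = np.asarray(event_times)
    allowed = t[None, :] < t[:, None]    # allowed[i, j] = t[j] < t[i]
    empty_rows = ~allowed.any(axis=1)    # rows that can see nothing
    allowed[empty_rows, np.arange(len(t))[empty_rows]] = True
    return allowed
```

With `<=`, two events recorded at the same age would attend to each other; the strict `<` removes that leakage, at the cost of the first-token fix handled above.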
e2495f43b0 revert 6e0713048a
revert update attn mask
2025-10-16 18:37:55 +08:00
6e0713048a update attn mask 2025-10-16 18:29:48 +08:00
eec406d79f update ignored events. 2025-10-16 17:10:01 +08:00
e3e533c9ec update 2025-10-16 16:58:30 +08:00
b5172392cb update dpp 2025-10-16 16:46:33 +08:00
6b0b86d9d0 add Multi_GPU support. 2025-10-16 16:28:52 +08:00
c7296381b8 Revert "feat: adapt train.py to multi-GPU environment"
This reverts commit b7aad7a774.
2025-10-16 16:23:38 +08:00
2b20299e36 Revert "fix: average loss for multi-GPU training"
This reverts commit 85502561ee.
2025-10-16 16:23:35 +08:00
85502561ee fix: average loss for multi-GPU training 2025-10-16 16:21:51 +08:00
b7aad7a774 feat: adapt train.py to multi-GPU environment 2025-10-16 16:16:15 +08:00
4181ead03a Refactor: Improve attention mechanism and early stopping
- Refactor the self-attention mechanism in `models.py` to use `nn.MultiheadAttention` for better performance and clarity.
- Disable early stopping check during warmup epochs in `train.py` to improve training stability.
2025-10-16 15:57:27 +08:00
8a757a8b1d feat: Add training and validation data via Git LFS 2025-10-16 14:24:56 +08:00
589d4d0bd2 feat: Implement time-aware GPT-2 for patient event prediction
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data.

Key components include:

- `models.py`:
  - `TimeAwareGPT2`: A custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: A module for creating time-based positional embeddings.
  - `CombinedLoss`: A two-part loss function combining cross-entropy for event prediction and a survival loss for event timing.

- `utils.py`:
  - `PatientEventDataset`: A PyTorch Dataset class to process, batch, and load patient event sequences, including imputation of "no event" gaps and padding/truncation.

- `train.py`:
  - A comprehensive training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a cosine annealing learning rate scheduler, validation, and early stopping based on validation loss.

- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.

- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
2025-10-16 14:21:36 +08:00
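The `AgeSinusoidalEncoding` described in commit 589d4d0bd2 keys positional information on the patient's age at each event rather than the sequence index. A minimal NumPy sketch of the standard Transformer sinusoidal recipe applied to continuous ages — the timescale constant and even embedding size are assumptions:

```python
import numpy as np

def age_sinusoidal_encoding(ages, n_embd, max_timescale=10000.0):
    """Sinusoidal positional encoding driven by continuous age values.

    Even channels carry sin, odd channels cos, over geometrically spaced
    frequencies, so nearby ages get similar embeddings at every scale.
    Assumes n_embd is even.
    """
    assert n_embd % 2 == 0, "n_embd must be even for paired sin/cos channels"
    ages = np.asarray(ages, dtype=float)
    half = n_embd // 2
    freqs = 1.0 / max_timescale ** (np.arange(half) / half)
    angles = ages[..., None] * freqs   # shape (..., half)
    enc = np.empty(ages.shape + (n_embd,))
    enc[..., 0::2] = np.sin(angles)
    enc[..., 1::2] = np.cos(angles)
    return enc
```

Unlike index-based encodings, this leaves the embedding invariant to how many "no event" gaps the `PatientEventDataset` imputes between visits — only the recorded age matters.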
1d4731ae42 Initial commit 2025-10-15 13:54:52 +08:00