feat: Add multi-GPU training and improve config/ignore

Add train_multigpu.py for distributed data parallel training.
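
The new script itself is not shown in this diff view. Below is a minimal sketch of the `torch.distributed` + `DistributedDataParallel` pattern such a script typically follows when launched with `torchrun`. The toy `Linear` model and random tensors are stand-ins (assumptions for illustration); the real script presumably wires in `TimeAwareGPT2` and `PatientEventDataset` as `train.py` does.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def main():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for every spawned process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Stand-in model and data; the real script presumably builds
    # TimeAwareGPT2 and PatientEventDataset as train.py does.
    model = DDP(torch.nn.Linear(16, 16).to(local_rank), device_ids=[local_rank])
    data = TensorDataset(torch.randn(256, 16), torch.randn(256, 16))
    sampler = DistributedSampler(data)  # shards the dataset across ranks
    loader = DataLoader(data, batch_size=32, sampler=sampler)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # re-shuffle shards each epoch
        for x, y in loader:
            x, y = x.to(local_rank), y.to(local_rank)
            loss = torch.nn.functional.mse_loss(model(x), y)
            opt.zero_grad()
            loss.backward()  # DDP all-reduces gradients across ranks here
            opt.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Launched as e.g. `torchrun --nproc_per_node=4 train_multigpu.py`, one process per GPU.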

Update train.py to save the training configuration to a JSON file.

Generalize .gitignore to exclude all *.pt checkpoint files.
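
The generalized rule is presumably a single glob replacing individually listed checkpoint names, along the lines of:

```
# Ignore all PyTorch checkpoint files, regardless of name
*.pt
```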

Delete obsolete train_dpp.py file.
2025-10-17 14:09:34 +08:00
parent 053f86f4da
commit d760c45baf
4 changed files with 282 additions and 401 deletions

train.py

@@ -6,6 +6,7 @@ import numpy as np
 import math
 import tqdm
 import matplotlib.pyplot as plt
+import json
 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
@@ -47,6 +48,13 @@ def main():
     model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
     checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
+    # --- 0. Save Configuration ---
+    config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
+    config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
+    with open(config_filename, 'w') as f:
+        json.dump(config_dict, f, indent=4)
+    print(f"Configuration saved to {config_filename}")
+
     # --- 1. Data Loading ---
     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
     train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
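
Once saved, the JSON can be read back to recover a run's hyperparameters, e.g. when reloading a matching checkpoint. A minimal sketch, assuming hypothetical size values in the filename:

```python
import json

# Hypothetical sizes; substitute those of the run being restored.
config_filename = "config_n_embd_256_n_layer_8_n_head_8.json"
with open(config_filename) as f:
    config_dict = json.load(f)
print(config_dict)  # e.g. {'n_embd': 256, 'n_layer': 8, 'n_head': 8, ...}
```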