feat: Add multi-GPU training and improve config/ignore
Add train_multigpu.py for distributed data parallel training. Update train.py to save the training configuration to a JSON file. Generalize .gitignore to exclude all *.pt checkpoint files. Delete obsolete train_dpp.py file.
This commit is contained in:
8
train.py
8
train.py
@@ -6,6 +6,7 @@ import numpy as np
|
||||
import math
|
||||
import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
import json
|
||||
|
||||
from models import TimeAwareGPT2, CombinedLoss
|
||||
from utils import PatientEventDataset
|
||||
@@ -47,6 +48,13 @@ def main():
|
||||
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
|
||||
checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
|
||||
|
||||
# --- 0. Save Configuration ---
|
||||
config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
|
||||
config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
|
||||
with open(config_filename, 'w') as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
print(f"Configuration saved to {config_filename}")
|
||||
|
||||
# --- 1. Data Loading ---
|
||||
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
|
||||
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
|
||||
|
Reference in New Issue
Block a user