feat: Add multi-GPU training and improve config/ignore
Add train_multigpu.py for distributed data parallel training. Update train.py to save the training configuration to a JSON file. Generalize .gitignore to exclude all *.pt checkpoint files. Delete obsolete train_dpp.py file.
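Note: train_multigpu.py assumes a torchrun-style launch (for example `torchrun --standalone --nproc_per_node=4 train_multigpu.py` on a single node; the GPU count is illustrative). torchrun exports the rank/world-size environment variables that the script's setup_ddp() relies on; a minimal sketch of that assumed environment:

```python
import os

# torchrun sets these for every worker process; setup_ddp() in
# train_multigpu.py reads LOCAL_RANK to select the GPU for each rank.
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
print(f"rank={rank} local_rank={local_rank} world_size={world_size}")
```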
.gitignore (vendored, 2 lines changed)
@@ -5,7 +5,7 @@
__pycache__/

# Model checkpoints
-best_model_checkpoint.pt
+*.pt

# Large data files
ukb_delphi.txt
train.py (8 lines changed)
@@ -6,6 +6,7 @@ import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
+import json

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
@@ -47,6 +48,13 @@ def main():
    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
    checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"

+    # --- 0. Save Configuration ---
+    config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
+    config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
+    with open(config_filename, 'w') as f:
+        json.dump(config_dict, f, indent=4)
+    print(f"Configuration saved to {config_filename}")
+
    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
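Since the configuration is now written with json.dump, it can be read straight back with json.load for later runs or analysis. A minimal sketch (the file name below is illustrative; the numbers come from whatever TrainConfig values a given run used):

```python
import json

# Assumed name following train.py's pattern: config_n_embd_<E>_n_layer_<L>_n_head_<H>.json
config_path = "config_n_embd_256_n_layer_8_n_head_8.json"
with open(config_path) as f:
    saved_config = json.load(f)  # plain dict of the saved attributes
print(saved_config)
```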
train_dpp.py (deleted, 400 lines)
@@ -1,400 +0,0 @@
# train.py (DDP-ready)
import os
import math
import argparse
import numpy as np
import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters (device is set dynamically in main() from local_rank)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


def setup_distributed(backend: str = "nccl"):
    """
    Initialize the distributed process group when launched via torchrun with WORLD_SIZE > 1.
    Returns (is_distributed, world_size, rank, local_rank).
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    is_distributed = world_size > 1
    if is_distributed:
        if not dist.is_initialized():
            dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        local_rank = 0
    return is_distributed, world_size, rank, local_rank


def cleanup_distributed():
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def all_reduce_mean(value: float, device, world_size: int):
    """
    value is a Python float (this process's sum/mean); returns the float averaged over all processes.
    """
    tensor = torch.tensor([value], dtype=torch.float32, device=device)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= world_size
    return float(tensor.item())


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo", "mpi"])
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Distributed initialization
    is_dist, world_size, rank, local_rank = setup_distributed(args.backend)

    # Basic environment setup
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    torch.backends.cudnn.benchmark = True

    config = TrainConfig()
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    config.device = device

    is_main = (rank == 0)

    # --- 1. Data Loading ---
    if is_main:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # Distributed samplers
    if is_dist:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        persistent_workers=True if 4 > 0 else False,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        persistent_workers=True if 4 > 0 else False,
    )

    # --- 2. Model, Optimizer, and Loss Initialization ---
    if is_main:
        print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(device)

    if is_main and hasattr(model, "get_num_params"):
        print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # Wrap the model in DDP
    if is_dist:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank] if device.type == "cuda" else None,
            output_device=local_rank if device.type == "cuda" else None,
            find_unused_parameters=False,
        )

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    # Collect metrics and plot only on the main process
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if is_main:
        print("Starting training...")

    stop_training = False

    for epoch in range(config.max_epoch):
        # Set the epoch on the distributed sampler so shuffling differs across epochs
        if is_dist:
            train_sampler.set_epoch(epoch)

        # --- Learning Rate Scheduling ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        if is_main:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader  # tqdm disabled on non-main processes

        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        for batch in pbar:
            event_seq, time_seq = batch
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += float(loss_ce.item())
            train_loss_surv_acc += float(loss_survival.item())
            train_steps += 1

            if is_main and isinstance(pbar, tqdm.tqdm):
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        # Per-process means
        avg_train_loss_ce_local = train_loss_ce_acc / max(train_steps, 1)
        avg_train_loss_surv_local = train_loss_surv_acc / max(train_steps, 1)

        # Average across all processes
        if is_dist:
            avg_train_loss_ce = all_reduce_mean(avg_train_loss_ce_local, device, world_size)
            avg_train_loss_surv = all_reduce_mean(avg_train_loss_surv_local, device, world_size)
        else:
            avg_train_loss_ce = avg_train_loss_ce_local
            avg_train_loss_surv = avg_train_loss_surv_local

        if is_main:
            train_losses_ce.append(avg_train_loss_ce)
            train_losses_surv.append(avg_train_loss_surv)
            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)

        # --- Validation Phase ---
        if is_main:
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
        else:
            pbar_val = val_loader

        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            for batch in pbar_val:
                event_seq, time_seq = batch
                event_seq = event_seq.to(device, non_blocking=True)
                time_seq = time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += float(loss_ce.item())
                val_loss_surv_acc += float(loss_survival.item())
                val_steps += 1

                if is_main and isinstance(pbar_val, tqdm.tqdm):
                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

        avg_val_loss_ce_local = val_loss_ce_acc / max(val_steps, 1)
        avg_val_loss_surv_local = val_loss_surv_acc / max(val_steps, 1)

        if is_dist:
            avg_val_loss_ce = all_reduce_mean(avg_val_loss_ce_local, device, world_size)
            avg_val_loss_surv = all_reduce_mean(avg_val_loss_surv_local, device, world_size)
        else:
            avg_val_loss_ce = avg_val_loss_ce_local
            avg_val_loss_surv = avg_val_loss_surv_local

        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        # Print and record on the main process
        if is_main:
            print(f"Epoch {epoch+1} Summary: \n"
                  f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f" Learning Rate: {lr:.6f}")
            val_losses_ce.append(avg_val_loss_ce)
            val_losses_surv.append(avg_val_loss_surv)
            val_losses_total.append(total_val_loss)

        # --- Early Stopping Check (based on the aggregated total_val_loss) ---
        improved = False
        if is_main:
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                improved = True
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                # DDP: save module.state_dict()
                state_dict = model.module.state_dict() if isinstance(model, nn.parallel.DistributedDataParallel) else model.state_dict()
                torch.save(state_dict, 'best_model_checkpoint.pt')
            else:
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
            stop_training = patience_counter >= config.early_stopping_patience

        # Broadcast the improved/stop flags to all processes so they exit consistently
        if is_dist:
            flag_tensor = torch.tensor([1 if stop_training else 0], device=device, dtype=torch.int32)
            dist.broadcast(flag_tensor, src=0)
            stop_training = bool(int(flag_tensor.item()))

        if stop_training:
            if is_main:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    # --- Save Best Model at the End (main process only) ---
    if is_main:
        if best_val_loss != float('inf'):
            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
            # For convenience, rebuild a single-device model on the main process,
            # load the best weights, and save them again
            model_single = TimeAwareGPT2(
                vocab_size=vocab_size,
                n_embd=config.n_embd,
                n_layer=config.n_layer,
                n_head=config.n_head,
                pdrop=config.pdrop,
                token_pdrop=config.token_pdrop
            ).to('cpu')
            model_single.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
            print("Saving final best model to best_model.pt")
            torch.save(model_single.state_dict(), 'best_model.pt')
        else:
            print("\nTraining finished. No best model to save as validation loss never improved.")

    # --- Plot and Save Loss Curves ---
    num_epochs = len(train_losses_total)
    if num_epochs > 0:
        epochs = range(1, num_epochs + 1)
        plt.figure(figsize=(18, 5))

        # Plot CE Loss
        plt.subplot(1, 3, 1)
        plt.plot(epochs, train_losses_ce, label='Train CE')
        plt.plot(epochs, val_losses_ce, label='Val CE')
        plt.title('Cross-Entropy Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Plot Survival Loss
        plt.subplot(1, 3, 2)
        plt.plot(epochs, train_losses_surv, label='Train Survival')
        plt.plot(epochs, val_losses_surv, label='Val Survival')
        plt.title('Survival Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Plot Total Loss
        plt.subplot(1, 3, 3)
        plt.plot(epochs, train_losses_total, label='Train Total')
        plt.plot(epochs, val_losses_total, label='Val Total')
        plt.title('Total Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig('loss_curves.png')
        print("\nLoss curves saved to loss_curves.png")

    # Clean up the distributed process group
    cleanup_distributed()


if __name__ == '__main__':
    main()
train_multigpu.py (new file, 273 lines)
@@ -0,0 +1,273 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48  # Sequence length

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_epochs = 10
    early_stopping_patience = 10

    # Loss parameters
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

    # System parameters (will be updated by DDP setup)
    device = 'cuda'

# --- DDP Setup ---
def setup_ddp():
    """Initializes the distributed data parallel environment."""
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return rank, local_rank

def cleanup_ddp():
    """Cleans up the distributed data parallel environment."""
    dist.destroy_process_group()

# --- Main Training Script ---
def main():
    rank, local_rank = setup_ddp()
    is_main_process = (rank == 0)

    config = TrainConfig()
    config.device = f'cuda:{local_rank}'

    if is_main_process:
        model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
        checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"

        # --- 0. Save Configuration ---
        config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json"
        config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
        with open(config_filename, 'w') as f:
            json.dump(config_dict, f, indent=4)
        print(f"Configuration saved to {config_filename}")

    # --- 1. Data Loading ---
    if is_main_process:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main_process:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, sampler=val_sampler, num_workers=4, pin_memory=True)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    if is_main_process:
        print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)
    model = DDP(model, device_ids=[local_rank])

    if is_main_process:
        print(f"Model initialized with {model.module.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if is_main_process:
        print("Starting training...")
    for epoch in range(config.max_epoch):
        train_sampler.set_epoch(epoch)  # Important for shuffling

        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]", disable=not is_main_process)
        for event_seq, time_seq in pbar:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            input_events, input_times = event_seq[:, :-1], time_seq[:, :-1]
            target_events, target_wait_times = event_seq[:, 1:], (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1
            if is_main_process:
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        avg_train_loss_ce = train_loss_ce_acc / train_steps
        avg_train_loss_surv = train_loss_surv_acc / train_steps
        train_losses_ce.append(avg_train_loss_ce)
        train_losses_surv.append(avg_train_loss_surv)
        train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)

        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]", disable=not is_main_process)
            for event_seq, time_seq in pbar_val:
                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

                input_events, input_times = event_seq[:, :-1], time_seq[:, :-1]
                target_events, target_wait_times = event_seq[:, 1:], (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1
                if is_main_process:
                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

        avg_val_loss_ce = val_loss_ce_acc / val_steps
        avg_val_loss_surv = val_loss_surv_acc / val_steps
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv
        val_losses_ce.append(avg_val_loss_ce)
        val_losses_surv.append(avg_val_loss_surv)
        val_losses_total.append(total_val_loss)

        if is_main_process:
            print(f"Epoch {epoch+1} Summary: \n"
                  f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f" Learning Rate: {lr:.6f}")

            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                torch.save(model.module.state_dict(), checkpoint_filename)
            else:
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")

        if patience_counter >= config.early_stopping_patience:
            if is_main_process:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    if is_main_process:
        if best_val_loss != float('inf'):
            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
            # Load the best weights into the module before saving the final model
            model.module.load_state_dict(torch.load(checkpoint_filename))
            print(f"Saving final best model to {model_filename}")
            torch.save(model.module.state_dict(), model_filename)
        else:
            print("\nTraining finished. No best model to save as validation loss never improved.")

        losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
        with open(losses_filename, 'w') as f:
            f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
            for i in range(len(train_losses_total)):
                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
        print(f"\nLosses saved to {losses_filename}")

        num_epochs = len(train_losses_total)
        epochs = range(1, num_epochs + 1)
        plt.figure(figsize=(18, 5))

        plt.subplot(1, 3, 1)
        plt.plot(epochs, train_losses_ce, label='Train CE')
        plt.plot(epochs, val_losses_ce, label='Val CE')
        plt.title('Cross-Entropy Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.subplot(1, 3, 2)
        plt.plot(epochs, train_losses_surv, label='Train Survival')
        plt.plot(epochs, val_losses_surv, label='Val Survival')
        plt.title('Survival Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.subplot(1, 3, 3)
        plt.plot(epochs, train_losses_total, label='Train Total')
        plt.plot(epochs, val_losses_total, label='Val Total')
        plt.title('Total Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig('loss_curves.png')
        print("\nLoss curves saved to loss_curves.png")

    cleanup_ddp()

if __name__ == '__main__':
    main()
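After a multi-GPU run finishes, rank 0 has written the final weights (saved from model.module, i.e. without the DDP wrapper) under the model_filename naming scheme, so they load directly into a plain TimeAwareGPT2 for single-device use. A minimal sketch, with the file name and vocab_size as assumed placeholder values:

```python
import torch
from models import TimeAwareGPT2

# Placeholder values: the real file name encodes the run's n_embd/n_layer/n_head,
# and vocab_size is inferred from the training data by train_multigpu.py.
model_path = "best_model_n_embd_120_n_layer_12_n_head_12.pt"
vocab_size = 1300

model = TimeAwareGPT2(vocab_size=vocab_size, n_embd=120, n_layer=12,
                      n_head=12, pdrop=0.1, token_pdrop=0.1)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # ready for evaluation or generation on a single device
```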