feat: Add iteration-based training scripts (single and multi-GPU)
train_iter.py (new file, 218 lines)
@@ -0,0 +1,218 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48  # Sequence length

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_iter = 200000
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_iter = 1000

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # Example ignored token IDs

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Main Training Script ---
def main():
    config = TrainConfig()

    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.pt"

    # --- 0. Save Configuration ---
    config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.json"
    # Note: vars(config) on an instance with only class-level attributes is an
    # empty dict; read the attributes off the class so the JSON has content.
    config_dict = {k: v for k, v in vars(TrainConfig).items()
                   if not k.startswith('__') and not callable(v)}
    with open(config_filename, 'w') as f:
        json.dump(config_dict, f, indent=4)
    print(f"Configuration saved to {config_filename}")

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
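    # Assumed record layout (not confirmed by this diff): each row of the
    # memmap is one event triple whose last column holds the event token id;
    # only column 2 is relied on below, to size the vocabulary.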
    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    train_iter_loader = iter(itertools.cycle(train_loader))
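    # Caveat: itertools.cycle caches every batch from the first pass and
    # replays them in the same order, so per-epoch reshuffling is lost and
    # memory grows with the dataset size. Re-creating the iterator on
    # StopIteration would avoid both.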
    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)

    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))

    # --- 3. Training Loop ---

    # Lists to store losses
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []

    print("Starting training...")
    pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training")
    for iter_num in pbar:
        # --- Learning Rate Scheduling ---
        # Constant LR during warmup, then cosine decay from lr_initial to lr_final.
        if iter_num < config.warmup_iter:
            lr = config.lr_initial
        else:
            progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
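        # Worked check of the schedule's endpoints:
        #   iter_num = warmup_iter -> progress = 0, cos(0) = 1,    lr = lr_initial
        #   iter_num = max_iter    -> progress = 1, cos(pi) = -1,  lr = lr_final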
        # --- Training Step ---
        model.train()

        event_seq, time_seq = next(train_iter_loader)
        event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

        # Prepare inputs and targets (next-event prediction with a one-step shift)
        input_events = event_seq[:, :-1]
        input_times = time_seq[:, :-1]
        target_events = event_seq[:, 1:]
        target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
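        # Example: events (e1, e2, e3, e4) at times (t1, t2, t3, t4) become
        # inputs (e1, e2, e3) and targets (e2, e3, e4), with wait times
        # (t2-t1, t3-t2, t4-t3) aligned to each prediction step.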
        # Forward pass
        logits = model(input_events, input_times)
        loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
        loss = loss_ce + loss_survival

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses_ce.append(loss_ce.item())
        train_losses_surv.append(loss_survival.item())
        train_losses_total.append(loss.item())

        pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

    print("\nTraining finished.")

    # --- 4. Final Validation ---
    print("Running final validation...")
    model.eval()
    val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
    val_steps = 0

    with torch.no_grad():
        pbar_val = tqdm.tqdm(val_loader, desc="Final Validation")
        for event_seq, time_seq in pbar_val:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

            val_loss_ce_acc += loss_ce.item()
            val_loss_surv_acc += loss_survival.item()
            val_steps += 1
            pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

    avg_val_loss_ce = val_loss_ce_acc / val_steps
    avg_val_loss_surv = val_loss_surv_acc / val_steps
    total_val_loss = avg_val_loss_ce + avg_val_loss_surv

    print(f"Final Validation Summary:\n"
          f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")

    # --- 5. Save Model ---
    print(f"Saving final model to {model_filename}")
    torch.save(model.state_dict(), model_filename)

    # --- 6. Save and Plot Losses ---
    losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter.txt"
    with open(losses_filename, 'w') as f:
        f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
        for i in range(len(train_losses_total)):
            f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
    print(f"\nLosses saved to {losses_filename}")

    # Plot and Save Loss Curves
    iterations = range(1, len(train_losses_total) + 1)

    plt.figure(figsize=(18, 5))

    # Plot CE Loss
    plt.subplot(1, 3, 1)
    plt.plot(iterations, train_losses_ce, label='Train CE')
    plt.title('Cross-Entropy Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Survival Loss
    plt.subplot(1, 3, 2)
    plt.plot(iterations, train_losses_surv, label='Train Survival')
    plt.title('Survival Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Total Loss
    plt.subplot(1, 3, 3)
    plt.plot(iterations, train_losses_total, label='Train Total')
    plt.title('Total Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('loss_curves_iter.png')
    print("\nLoss curves saved to loss_curves_iter.png")


if __name__ == '__main__':
    main()
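# Usage sketch (assumes models.py, utils.py and the .bin data files sit in the
# working directory):
#   python train_iter.py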
train_iter_multigpu.py (new file, 247 lines)
@@ -0,0 +1,247 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48  # Sequence length

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_iter = 200000
    batch_size = 128  # Per GPU
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_iter = 1000

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # Example ignored token IDs

    # System parameters
    device = 'cuda'

# --- DDP Setup ---
def setup_ddp():
    """Initializes the distributed data parallel environment."""
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return rank, local_rank


def cleanup_ddp():
    """Cleans up the distributed data parallel environment."""
    dist.destroy_process_group()
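# Launch sketch: this script expects a torchrun-style launcher, which sets
# LOCAL_RANK and the rendezvous variables consumed by init_process_group, e.g.
#   torchrun --nproc_per_node=<num_gpus> train_iter_multigpu.py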
# --- Main Training Script ---
def main():
    rank, local_rank = setup_ddp()
    is_main_process = (rank == 0)

    config = TrainConfig()
    config.device = f'cuda:{local_rank}'

    if is_main_process:
        model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.pt"
        # --- 0. Save Configuration ---
        config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.json"
        # vars(config) only sees instance attributes (here just 'device'); read
        # the hyperparameters off the class so the dumped JSON is complete.
        config_dict = {k: v for k, v in vars(TrainConfig).items()
                       if not k.startswith('__') and not callable(v)}
        config_dict['device'] = config.device  # record the per-rank override
        with open(config_filename, 'w') as f:
            json.dump(config_dict, f, indent=4)
        print(f"Configuration saved to {config_filename}")

    # --- 1. Data Loading ---
    if is_main_process:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main_process:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, sampler=val_sampler, num_workers=4, pin_memory=True)
    train_iter_loader = iter(itertools.cycle(train_loader))
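    # Caveat: itertools.cycle replays the cached batches from the first pass,
    # so train_sampler.set_epoch() is never exercised and every "epoch" sees
    # the same per-rank shard order.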
    # --- 2. Model, Optimizer, and Loss Initialization ---
    if is_main_process:
        print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)
    model = DDP(model, device_ids=[local_rank])

    if is_main_process:
        print(f"Model initialized with {model.module.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    # betas matched to the single-GPU script for consistency.
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))
    # --- 3. Training Loop ---
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []

    if is_main_process:
        print("Starting training...")
    pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training", disable=not is_main_process)
    for iter_num in pbar:
        # --- Learning Rate Scheduling ---
        # Constant LR during warmup, then cosine decay from lr_initial to lr_final.
        if iter_num < config.warmup_iter:
            lr = config.lr_initial
        else:
            progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Step ---
        model.train()

        event_seq, time_seq = next(train_iter_loader)
        event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

        input_events = event_seq[:, :-1]
        input_times = time_seq[:, :-1]
        target_events = event_seq[:, 1:]
        target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

        logits = model(input_events, input_times)
        loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
        loss = loss_ce + loss_survival

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if is_main_process:
            train_losses_ce.append(loss_ce.item())
            train_losses_surv.append(loss_survival.item())
            train_losses_total.append(loss.item())
            pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

    if is_main_process:
        print("\nTraining finished.")

    # --- 4. Final Validation ---
    if is_main_process:
        print("Running final validation...")
    model.eval()
    val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
    val_steps = 0

    with torch.no_grad():
        pbar_val = tqdm.tqdm(val_loader, desc="Final Validation", disable=not is_main_process)
        for event_seq, time_seq in pbar_val:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

            val_loss_ce_acc += loss_ce.item()
            val_loss_surv_acc += loss_survival.item()
            val_steps += 1
            if is_main_process:
                pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
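    # Note: each rank averages only its own validation shard. A global mean
    # could be obtained by summing across ranks first, e.g. (sketch):
    #   stats = torch.tensor([val_loss_ce_acc, val_loss_surv_acc, val_steps], device=config.device)
    #   dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    #   val_loss_ce_acc, val_loss_surv_acc, val_steps = stats.tolist()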
    avg_val_loss_ce = val_loss_ce_acc / val_steps
    avg_val_loss_surv = val_loss_surv_acc / val_steps
    total_val_loss = avg_val_loss_ce + avg_val_loss_surv

    if is_main_process:
        print(f"Final Validation Summary:\n"
              f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")

        # --- 5. Save Model ---
        print(f"Saving final model to {model_filename}")
        torch.save(model.module.state_dict(), model_filename)

        # --- 6. Save and Plot Losses ---
        losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.txt"
        with open(losses_filename, 'w') as f:
            f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
            for i in range(len(train_losses_total)):
                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
        print(f"\nLosses saved to {losses_filename}")

        # Plot and Save Loss Curves
        iterations = range(1, len(train_losses_total) + 1)

        plt.figure(figsize=(18, 5))

        # Plot CE Loss
        plt.subplot(1, 3, 1)
        plt.plot(iterations, train_losses_ce, label='Train CE')
        plt.title('Cross-Entropy Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Plot Survival Loss
        plt.subplot(1, 3, 2)
        plt.plot(iterations, train_losses_surv, label='Train Survival')
        plt.title('Survival Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Plot Total Loss
        plt.subplot(1, 3, 3)
        plt.plot(iterations, train_losses_total, label='Train Total')
        plt.title('Total Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig('loss_curves_iter_multigpu.png')
        print("\nLoss curves saved to loss_curves_iter_multigpu.png")

    cleanup_ddp()


if __name__ == '__main__':
    main()