import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import os
import time

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 512  # Global batch size, split across GPUs
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    ignored_token_ids = [0, 1]

    # Distributed training parameters
    world_size = max(torch.cuda.device_count(), 1)  # at least 1 so the per-GPU batch-size division also works without GPUs
    distributed = world_size > 1

# --- Main Training Function ---
def train_worker(local_rank, config):
    # Initialize distributed training
    if config.distributed:
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            rank=local_rank,
            world_size=config.world_size
        )
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
        print(f"Worker {local_rank} initialized on device {device}")
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        local_rank = 0

    # --- 1. Data Loading ---
    if local_rank == 0:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")

    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
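
    # The memmapped files above store one event per row as three uint32 values; the third
    # column is the event token id, which is what the vocabulary size is inferred from.
    # Treating the first two columns as the patient identifier and event time is an
    # assumption based on how PatientEventDataset consumes the array below.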

    if local_rank == 0:
        print(f"Inferred vocabulary size: {vocab_size}")
        print(f"Using {config.world_size} GPU(s) for training")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # Per-GPU batch size derived from the global batch size
    per_gpu_batch_size = config.batch_size // config.world_size

    # Data loader setup: in distributed mode a DistributedSampler shards the data across ranks
    if config.distributed:
        train_sampler = DistributedSampler(train_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None

    # Use multiple workers with persistent_workers to avoid re-spawning worker processes each epoch
    train_loader = DataLoader(
        train_dataset,
        batch_size=per_gpu_batch_size,  # per-GPU batch size
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,  # keep worker processes alive between epochs
        prefetch_factor=2         # batches prefetched per worker
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=per_gpu_batch_size,
        sampler=val_sampler,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )

    # --- 2. Model, Optimizer, and Loss Initialization ---
    if local_rank == 0:
        print(f"Initializing model on {device}...")

    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(device)

    # Wrap the model in DDP so gradients are synchronized across ranks
    if config.distributed:
        # find_unused_parameters=False skips the unused-parameter scan and speeds up each step
        model = DDP(model, device_ids=[local_rank], output_device=local_rank,
                    find_unused_parameters=False)

    if local_rank == 0:
        if config.distributed:
            num_params = sum(p.numel() for p in model.module.parameters() if p.requires_grad)
        else:
            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model initialized with {num_params/1e6:.2f}M trainable parameters.")
        print(f"Per GPU batch size: {per_gpu_batch_size}")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    if local_rank == 0:
        train_losses_ce, train_losses_surv, train_losses_total = [], [], []
        val_losses_ce, val_losses_surv, val_losses_total = [], [], []
        print("Starting training...")

    for epoch in range(config.max_epoch):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        # --- Learning Rate Scheduling ---
        # Constant LR for the first warmup_epochs, then cosine decay from lr_initial to lr_final.
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        # Show the progress bar only on rank 0
        if local_rank == 0:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader

        batch_start_time = time.time()
        for batch_idx, (event_seq, time_seq) in enumerate(pbar):
            event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()

            # Gradient synchronization is handled automatically by DDP during backward()
            optimizer.step()

            # Accumulate scalar losses for epoch-level averages (note: .item() syncs with the GPU)
            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1

            if local_rank == 0 and batch_idx % 10 == 0:  # update the progress bar every 10 batches
                batch_time = time.time() - batch_start_time
                pbar.set_postfix({
                    'loss_ce': f'{loss_ce.item():.4f}',
                    'loss_surv': f'{loss_survival.item():.4f}',
                    'lr': f'{lr:.2e}',
                    'batch_time': f'{batch_time:.3f}s'
                })
                batch_start_time = time.time()

        # Synchronize the accumulated losses across ranks once per epoch to limit communication
        if config.distributed:
            train_loss_ce_tensor = torch.tensor([train_loss_ce_acc], device=device)
            train_loss_surv_tensor = torch.tensor([train_loss_surv_acc], device=device)
            train_steps_tensor = torch.tensor([train_steps], device=device)

            # all_reduce sums the per-rank totals
            dist.all_reduce(train_loss_ce_tensor)
            dist.all_reduce(train_loss_surv_tensor)
            dist.all_reduce(train_steps_tensor)

            avg_train_loss_ce = train_loss_ce_tensor.item() / train_steps_tensor.item()
            avg_train_loss_surv = train_loss_surv_tensor.item() / train_steps_tensor.item()
        else:
            avg_train_loss_ce = train_loss_ce_acc / train_steps
            avg_train_loss_surv = train_loss_surv_acc / train_steps

        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            for event_seq, time_seq in val_loader:
                event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1

        # Synchronize the validation losses across ranks
        if config.distributed:
            val_loss_ce_tensor = torch.tensor([val_loss_ce_acc], device=device)
            val_loss_surv_tensor = torch.tensor([val_loss_surv_acc], device=device)
            val_steps_tensor = torch.tensor([val_steps], device=device)

            dist.all_reduce(val_loss_ce_tensor)
            dist.all_reduce(val_loss_surv_tensor)
            dist.all_reduce(val_steps_tensor)

            avg_val_loss_ce = val_loss_ce_tensor.item() / val_steps_tensor.item()
            avg_val_loss_surv = val_loss_surv_tensor.item() / val_steps_tensor.item()
        else:
            avg_val_loss_ce = val_loss_ce_acc / val_steps
            avg_val_loss_surv = val_loss_surv_acc / val_steps

        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        # Logging, checkpointing, and early-stopping bookkeeping run on rank 0 only;
        # the stop decision is then shared so that every rank leaves the loop together.
        stop_training = False
        if local_rank == 0:
            train_losses_ce.append(avg_train_loss_ce)
            train_losses_surv.append(avg_train_loss_surv)
            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
            val_losses_ce.append(avg_val_loss_ce)
            val_losses_surv.append(avg_val_loss_surv)
            val_losses_total.append(total_val_loss)

            print(f"Epoch {epoch+1} Summary: \n"
                  f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f" Learning Rate: {lr:.6f}")

            # Early stopping check
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                if config.distributed:
                    torch.save(model.module.state_dict(), 'best_model_checkpoint.pt')
                else:
                    torch.save(model.state_dict(), 'best_model_checkpoint.pt')
            else:
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                    print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")

            if patience_counter >= config.early_stopping_patience:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
                stop_training = True

        # Broadcast the stop decision from rank 0 every epoch so all ranks break together
        if config.distributed:
            stop_signal = torch.tensor(1 if stop_training else 0, device=device)
            dist.broadcast(stop_signal, src=0)
            stop_training = bool(stop_signal.item())

        if stop_training:
            break

    # Cleanup and save the final model
    if local_rank == 0 and best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        if config.distributed:
            model.module.load_state_dict(torch.load('best_model_checkpoint.pt'))
            torch.save(model.module.state_dict(), 'best_model.pt')
        else:
            model.load_state_dict(torch.load('best_model_checkpoint.pt'))
            torch.save(model.state_dict(), 'best_model.pt')
        print("Final best model saved to best_model.pt")

    if config.distributed:
        dist.destroy_process_group()


def main():
    config = TrainConfig()

    # Environment tweaks for multi-GPU runs
    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'         # do not force synchronous CUDA launches
    os.environ['NCCL_DEBUG'] = 'WARN'                # keep NCCL logging quiet
    os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker'  # exclude loopback/docker network interfaces
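
    # init_method='env://' in train_worker needs a rendezvous address when launched via
    # mp.spawn; these single-node defaults are an assumption (any free port works) and
    # are only applied if the variables are not already set.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')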

    if config.distributed:
        print(f"Starting distributed training with {config.world_size} GPUs")
        mp.spawn(
            train_worker,
            args=(config,),
            nprocs=config.world_size,
            join=True
        )
    else:
        print("Starting single GPU training")
        train_worker(0, config)


if __name__ == '__main__':
    main()