From 85502561eee658b8bf8f53d85b560dc30d558538 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Thu, 16 Oct 2025 16:21:51 +0800 Subject: [PATCH] fix: average loss for multi-GPU training --- train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/train.py b/train.py index a745fa3..157e449 100644 --- a/train.py +++ b/train.py @@ -122,6 +122,11 @@ def main(): loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) loss = loss_ce + loss_survival + # When using DataParallel, loss is a vector of losses from each GPU. + # We need to average them to get a single scalar loss. + if isinstance(model, nn.DataParallel): + loss = loss.mean() + # Backward pass and optimization optimizer.zero_grad() loss.backward()