From 2b20299e369118cff35ee8f3de73b7b5a58632c8 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Thu, 16 Oct 2025 16:23:35 +0800
Subject: [PATCH] Revert "fix: average loss for multi-GPU training"

This reverts commit 85502561eee658b8bf8f53d85b560dc30d558538.
---
 train.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/train.py b/train.py
index 157e449..a745fa3 100644
--- a/train.py
+++ b/train.py
@@ -122,11 +122,6 @@ def main():
         loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
         loss = loss_ce + loss_survival
 
-        # When using DataParallel, loss is a vector of losses from each GPU.
-        # We need to average them to get a single scalar loss.
-        if isinstance(model, nn.DataParallel):
-            loss = loss.mean()
-
         # Backward pass and optimization
         optimizer.zero_grad()
         loss.backward()