Revert "fix: average loss for multi-GPU training"

This reverts commit 85502561ee.
2025-10-16 16:23:35 +08:00
parent 85502561ee
commit 2b20299e36


@@ -122,11 +122,6 @@ def main():
 loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
 loss = loss_ce + loss_survival
-# When using DataParallel, loss is a vector of losses from each GPU.
-# We need to average them to get a single scalar loss.
-if isinstance(model, nn.DataParallel):
-    loss = loss.mean()
 # Backward pass and optimization
 optimizer.zero_grad()
 loss.backward()
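
For context on what the reverted guard did: torch.nn.DataParallel gathers each replica's forward() output onto the default device, so a loss computed inside the wrapped module comes back as one value per GPU and must be reduced to a scalar before backward(). In this hunk, however, loss_fn runs on the already-gathered logits in the training loop, so the loss is a plain scalar, which may be why the guard was dropped. Below is a minimal, self-contained sketch of the per-replica case; ToyModel and the tensor shapes are illustrative assumptions, not taken from this repository.

```python
# Sketch only: illustrates when a DataParallel loss needs .mean(),
# assuming the loss is computed inside the wrapped module's forward().
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Computes the loss inside forward(), so nn.DataParallel returns
    one loss value per GPU replica instead of a single scalar."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x, target):
        logits = self.fc(x)
        # Reduces over this replica's shard of the batch only.
        return nn.functional.cross_entropy(logits, target)


model = ToyModel()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()

x = torch.randn(16, 8)
target = torch.randint(0, 2, (16,))
if torch.cuda.is_available():
    x, target = x.cuda(), target.cuda()

loss = model(x, target)
# With DataParallel on >1 GPU, `loss` has shape (num_gpus,); reduce it to a
# scalar before backward(). On a single device it is already 0-dim and the
# guard is a no-op, which mirrors the check in the reverted commit.
if loss.dim() > 0:
    loss = loss.mean()
loss.backward()
```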