fix: average loss for multi-GPU training

2025-10-16 16:21:51 +08:00
parent b7aad7a774
commit 85502561ee


@@ -122,6 +122,11 @@ def main():
 loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
 loss = loss_ce + loss_survival
+# When using DataParallel, loss is a vector of losses from each GPU.
+# We need to average them to get a single scalar loss.
+if isinstance(model, nn.DataParallel):
+    loss = loss.mean()
 # Backward pass and optimization
 optimizer.zero_grad()
 loss.backward()
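
For context, here is a minimal, hypothetical sketch (not this repository's code) of the situation the guard handles. When a module computes its loss inside forward() and is wrapped in torch.nn.DataParallel, each replica returns a scalar, and DataParallel gathers them into a 1-D tensor with one element per GPU; .backward() requires a scalar, hence the .mean(). The ToyModel name and shapes below are illustrative assumptions.

import torch
import torch.nn as nn

class ToyModel(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, x, target):
        logits = self.fc(x)
        # Loss is computed inside forward(), so each GPU replica
        # returns a 0-dim scalar tensor.
        return self.ce(logits, target)

model = ToyModel()
if torch.cuda.device_count() > 1:
    # DataParallel gathers the per-replica scalars into a vector.
    model = nn.DataParallel(model).cuda()

x = torch.randn(8, 10)
target = torch.randint(0, 2, (8,))
if torch.cuda.is_available():
    x, target = x.cuda(), target.cuda()

loss = model(x, target)
# On N GPUs, loss has shape (N,); on a single device it is a 0-dim scalar.
if isinstance(model, nn.DataParallel):
    loss = loss.mean()
loss.backward()

The isinstance check keeps single-GPU and CPU runs unchanged, since in those cases the loss is already a scalar and calling .mean() would be a no-op anyway.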