Revert "fix: average loss for multi-GPU training"
This reverts commit 85502561ee.
This commit is contained in:
train.py (5 lines changed)
@@ -122,11 +122,6 @@ def main():
|
||||
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
|
||||
loss = loss_ce + loss_survival
|
||||
|
||||
# When using DataParallel, loss is a vector of losses from each GPU.
|
||||
# We need to average them to get a single scalar loss.
|
||||
if isinstance(model, nn.DataParallel):
|
||||
loss = loss.mean()
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
Reference in New Issue
Block a user