fix: average loss for multi-GPU training
train.py | 5 +++++

--- a/train.py
+++ b/train.py
@@ -122,6 +122,11 @@ def main():
         loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
         loss = loss_ce + loss_survival
 
+        # When using DataParallel, loss is a vector of losses from each GPU.
+        # We need to average them to get a single scalar loss.
+        if isinstance(model, nn.DataParallel):
+            loss = loss.mean()
+
         # Backward pass and optimization
         optimizer.zero_grad()
         loss.backward()
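
For context, here is a minimal, self-contained sketch of the situation this fix addresses. The names (ToyModel, its built-in criterion, the toy shapes) are illustrative and not taken from train.py: it assumes the wrapped module returns its loss from forward(), which is the case where nn.DataParallel gathers one scalar per GPU into a vector that must be reduced before backward().

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, target):
        logits = self.fc(x)
        # Returning the loss from forward() means each DataParallel replica
        # computes its own loss; the gathered result is one value per GPU.
        return self.criterion(logits, target)

model = ToyModel()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()

x = torch.randn(8, 16)
target = torch.randint(0, 4, (8,))
if torch.cuda.is_available():
    x, target = x.cuda(), target.cuda()

loss = model(x, target)
# Under DataParallel the gathered loss has one element per GPU, so reduce it
# to a scalar before backward(); a plain (unwrapped) module already returns
# a scalar, and the isinstance check leaves that path untouched.
if isinstance(model, nn.DataParallel):
    loss = loss.mean()
loss.backward()

On a single device, or when the criterion is applied outside the wrapper, the loss is already a scalar, so guarding the .mean() call with isinstance keeps single-GPU training unchanged.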