diff --git a/train.py b/train.py
index a745fa3..157e449 100644
--- a/train.py
+++ b/train.py
@@ -122,6 +122,11 @@ def main():
         loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
         loss = loss_ce + loss_survival
 
+        # When using DataParallel, loss is a vector of losses from each GPU.
+        # We need to average them to get a single scalar loss.
+        if isinstance(model, nn.DataParallel):
+            loss = loss.mean()
+
         # Backward pass and optimization
         optimizer.zero_grad()
         loss.backward()