diff --git a/train.py b/train.py
index 157e449..a745fa3 100644
--- a/train.py
+++ b/train.py
@@ -122,11 +122,6 @@ def main():
         loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
         loss = loss_ce + loss_survival
 
-        # When using DataParallel, loss is a vector of losses from each GPU.
-        # We need to average them to get a single scalar loss.
-        if isinstance(model, nn.DataParallel):
-            loss = loss.mean()
-
         # Backward pass and optimization
         optimizer.zero_grad()
         loss.backward()
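
For reference on why the dropped reduction was a no-op in this code path: `nn.DataParallel` scatters the batch across replicas but gathers the outputs back onto the primary device, so a loss computed outside the wrapped forward (as with `loss_fn(logits, ...)` above) is already a single scalar tensor. A minimal sketch with a toy model and loss, not the ones from `train.py`:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the model and loss in train.py (illustrative only).
model = nn.Linear(16, 4)
if torch.cuda.device_count() > 1:
    # DataParallel scatters inputs across GPUs and gathers outputs onto cuda:0.
    model = nn.DataParallel(model).cuda()

inputs = torch.randn(8, 16)
if next(model.parameters()).is_cuda:
    inputs = inputs.cuda()

logits = model(inputs)       # gathered back onto one device: shape (8, 4)
loss = logits.pow(2).mean()  # computed outside the wrapper, so already a scalar
print(loss.shape)            # torch.Size([])
```

Per-GPU averaging is only needed in the pattern the removed comment described, where the loss is computed inside the module's `forward` and returned per replica.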