diff --git a/evaluate_models.py b/evaluate_models.py index 5f56445..8a1ff36 100644 --- a/evaluate_models.py +++ b/evaluate_models.py @@ -1146,6 +1146,7 @@ def predict_cifs_for_model( time_seq = time_seq.to(device) cont_feats = cont_feats.to(device) cate_feats = cate_feats.to(device) + sexes = sexes.to(device) keep, t_ctx, _ = select_context_indices( event_seq, time_seq, offset_years) @@ -1157,7 +1158,7 @@ def predict_cifs_for_model( time_seq = time_seq[keep] cont_feats = cont_feats[keep] cate_feats = cate_feats[keep] - sexes_k = sexes[keep].to(device) + sexes_k = sexes[keep] t_ctx = t_ctx[keep] h = backbone(event_seq, time_seq, sexes_k,