Fix device transfer for sexes in predict_cifs_for_model function

This commit is contained in:
2026-01-10 17:02:28 +08:00
parent 029f147ab5
commit 87baef3ecf

View File

@@ -1146,6 +1146,7 @@ def predict_cifs_for_model(
time_seq = time_seq.to(device) time_seq = time_seq.to(device)
cont_feats = cont_feats.to(device) cont_feats = cont_feats.to(device)
cate_feats = cate_feats.to(device) cate_feats = cate_feats.to(device)
sexes = sexes.to(device)
keep, t_ctx, _ = select_context_indices( keep, t_ctx, _ = select_context_indices(
event_seq, time_seq, offset_years) event_seq, time_seq, offset_years)
@@ -1157,7 +1158,7 @@ def predict_cifs_for_model(
time_seq = time_seq[keep] time_seq = time_seq[keep]
cont_feats = cont_feats[keep] cont_feats = cont_feats[keep]
cate_feats = cate_feats[keep] cate_feats = cate_feats[keep]
sexes_k = sexes[keep].to(device) sexes_k = sexes[keep]
t_ctx = t_ctx[keep] t_ctx = t_ctx[keep]
h = backbone(event_seq, time_seq, sexes_k, h = backbone(event_seq, time_seq, sexes_k,