Fix device transfer for sexes in predict_cifs_for_model function

This commit is contained in:
2026-01-10 17:02:28 +08:00
parent 029f147ab5
commit 87baef3ecf

View File

@@ -1146,6 +1146,7 @@ def predict_cifs_for_model(
time_seq = time_seq.to(device)
cont_feats = cont_feats.to(device)
cate_feats = cate_feats.to(device)
sexes = sexes.to(device)
keep, t_ctx, _ = select_context_indices(
event_seq, time_seq, offset_years)
@@ -1157,7 +1158,7 @@ def predict_cifs_for_model(
time_seq = time_seq[keep]
cont_feats = cont_feats[keep]
cate_feats = cate_feats[keep]
sexes_k = sexes[keep].to(device)
sexes_k = sexes[keep]
t_ctx = t_ctx[keep]
h = backbone(event_seq, time_seq, sexes_k,