Fix device transfer for sexes in predict_cifs_for_model function
This commit is contained in:
@@ -1146,6 +1146,7 @@ def predict_cifs_for_model(
|
|||||||
time_seq = time_seq.to(device)
|
time_seq = time_seq.to(device)
|
||||||
cont_feats = cont_feats.to(device)
|
cont_feats = cont_feats.to(device)
|
||||||
cate_feats = cate_feats.to(device)
|
cate_feats = cate_feats.to(device)
|
||||||
|
sexes = sexes.to(device)
|
||||||
|
|
||||||
keep, t_ctx, _ = select_context_indices(
|
keep, t_ctx, _ = select_context_indices(
|
||||||
event_seq, time_seq, offset_years)
|
event_seq, time_seq, offset_years)
|
||||||
@@ -1157,7 +1158,7 @@ def predict_cifs_for_model(
|
|||||||
time_seq = time_seq[keep]
|
time_seq = time_seq[keep]
|
||||||
cont_feats = cont_feats[keep]
|
cont_feats = cont_feats[keep]
|
||||||
cate_feats = cate_feats[keep]
|
cate_feats = cate_feats[keep]
|
||||||
sexes_k = sexes[keep].to(device)
|
sexes_k = sexes[keep]
|
||||||
t_ctx = t_ctx[keep]
|
t_ctx = t_ctx[keep]
|
||||||
|
|
||||||
h = backbone(event_seq, time_seq, sexes_k,
|
h = backbone(event_seq, time_seq, sexes_k,
|
||||||
|
|||||||
Reference in New Issue
Block a user