Fix device transfer for sexes in predict_cifs_for_model function
This commit is contained in:
@@ -1146,6 +1146,7 @@ def predict_cifs_for_model(
|
||||
time_seq = time_seq.to(device)
|
||||
cont_feats = cont_feats.to(device)
|
||||
cate_feats = cate_feats.to(device)
|
||||
sexes = sexes.to(device)
|
||||
|
||||
keep, t_ctx, _ = select_context_indices(
|
||||
event_seq, time_seq, offset_years)
|
||||
@@ -1157,7 +1158,7 @@ def predict_cifs_for_model(
|
||||
time_seq = time_seq[keep]
|
||||
cont_feats = cont_feats[keep]
|
||||
cate_feats = cate_feats[keep]
|
||||
sexes_k = sexes[keep].to(device)
|
||||
sexes_k = sexes[keep]
|
||||
t_ctx = t_ctx[keep]
|
||||
|
||||
h = backbone(event_seq, time_seq, sexes_k,
|
||||
|
||||
Reference in New Issue
Block a user