diff --git a/train.py b/train.py index 9d54ac8..5291523 100644 --- a/train.py +++ b/train.py @@ -178,7 +178,7 @@ class Trainer: n_head=cfg.n_head, n_layer=cfg.n_layer, pdrop=cfg.pdrop, - age_encoder=cfg.age_encoder, + age_encoder_type=cfg.age_encoder, n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, @@ -190,7 +190,7 @@ class Trainer: n_head=cfg.n_head, n_layer=cfg.n_layer, pdrop=cfg.pdrop, - age_encoder=cfg.age_encoder, + age_encoder_type=cfg.age_encoder, n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, @@ -212,7 +212,7 @@ class Trainer: n_head=cfg.n_head, n_layer=cfg.n_layer, pdrop=cfg.pdrop, - age_encoder=cfg.age_encoder, + age_encoder_type=cfg.age_encoder, n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, @@ -224,7 +224,7 @@ class Trainer: n_head=cfg.n_head, n_layer=cfg.n_layer, pdrop=cfg.pdrop, - age_encoder=cfg.age_encoder, + age_encoder_type=cfg.age_encoder, n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate,