diff --git a/train.py b/train.py index 5291523..9d10152 100644 --- a/train.py +++ b/train.py @@ -174,6 +174,7 @@ class Trainer: if cfg.model_type == "delphi_fork": self.model = DelphiFork( n_disease=dataset.n_disease, + n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, @@ -182,10 +183,12 @@ class Trainer: n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, ).to(self.device) elif cfg.model_type == "sap_delphi": self.model = SapDelphi( n_disease=dataset.n_disease, + n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, @@ -194,6 +197,7 @@ class Trainer: n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, pretrained_emd_path=cfg.pretrained_emd_path, freeze_pretrained_emd=True, ).to(self.device) @@ -208,6 +212,7 @@ class Trainer: if cfg.model_type == "delphi_fork": self.ema_model = DelphiFork( n_disease=dataset.n_disease, + n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, @@ -216,10 +221,12 @@ class Trainer: n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, ).to(self.device) elif cfg.model_type == "sap_delphi": self.ema_model = SapDelphi( n_disease=dataset.n_disease, + n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, @@ -228,6 +235,7 @@ class Trainer: n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, pretrained_emd_path=cfg.pretrained_emd_path, freeze_pretrained_emd=True, ).to(self.device)