From 9eda00ea48758fcc35e3c24f04699873e134c48b Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Thu, 8 Jan 2026 11:36:23 +0800 Subject: [PATCH] Add n_tech_tokens parameter to DelphiFork and SapDelphi model initializations --- train.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train.py b/train.py index 5291523..9d10152 100644 --- a/train.py +++ b/train.py @@ -174,6 +174,7 @@ class Trainer: if cfg.model_type == "delphi_fork": self.model = DelphiFork( n_disease=dataset.n_disease, + n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, @@ -182,10 +183,12 @@ class Trainer: n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, ).to(self.device) elif cfg.model_type == "sap_delphi": self.model = SapDelphi( n_disease=dataset.n_disease, + n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, @@ -194,6 +197,7 @@ class Trainer: n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, pretrained_emd_path=cfg.pretrained_emd_path, freeze_pretrained_emd=True, ).to(self.device) @@ -208,6 +212,7 @@ class Trainer: if cfg.model_type == "delphi_fork": self.ema_model = DelphiFork( n_disease=dataset.n_disease, + n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, @@ -216,10 +221,12 @@ class Trainer: n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, ).to(self.device) elif cfg.model_type == "sap_delphi": self.ema_model = SapDelphi( n_disease=dataset.n_disease, + n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, @@ -228,6 +235,7 @@ class Trainer: n_dim=n_dim, n_cont=dataset.n_cont, n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, pretrained_emd_path=cfg.pretrained_emd_path, freeze_pretrained_emd=True, ).to(self.device)