Add n_tech_tokens parameter to DelphiFork and SapDelphi model initializations

This commit is contained in:
2026-01-08 11:36:23 +08:00
parent 1d1f568a3f
commit 9eda00ea48

View File

@@ -174,6 +174,7 @@ class Trainer:
if cfg.model_type == "delphi_fork":
self.model = DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd,
n_head=cfg.n_head,
n_layer=cfg.n_layer,
@@ -182,10 +183,12 @@ class Trainer:
n_dim=n_dim,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
).to(self.device)
elif cfg.model_type == "sap_delphi":
self.model = SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd,
n_head=cfg.n_head,
n_layer=cfg.n_layer,
@@ -194,6 +197,7 @@ class Trainer:
n_dim=n_dim,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
pretrained_emd_path=cfg.pretrained_emd_path,
freeze_pretrained_emd=True,
).to(self.device)
@@ -208,6 +212,7 @@ class Trainer:
if cfg.model_type == "delphi_fork":
self.ema_model = DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd,
n_head=cfg.n_head,
n_layer=cfg.n_layer,
@@ -216,10 +221,12 @@ class Trainer:
n_dim=n_dim,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
).to(self.device)
elif cfg.model_type == "sap_delphi":
self.ema_model = SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd,
n_head=cfg.n_head,
n_layer=cfg.n_layer,
@@ -228,6 +235,7 @@ class Trainer:
n_dim=n_dim,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
pretrained_emd_path=cfg.pretrained_emd_path,
freeze_pretrained_emd=True,
).to(self.device)