Add n_tech_tokens parameter to DelphiFork and SapDelphi model initializations

This commit is contained in:
2026-01-08 11:36:23 +08:00
parent 1d1f568a3f
commit 9eda00ea48

View File

@@ -174,6 +174,7 @@ class Trainer:
if cfg.model_type == "delphi_fork": if cfg.model_type == "delphi_fork":
self.model = DelphiFork( self.model = DelphiFork(
n_disease=dataset.n_disease, n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd, n_embd=cfg.n_embd,
n_head=cfg.n_head, n_head=cfg.n_head,
n_layer=cfg.n_layer, n_layer=cfg.n_layer,
@@ -182,10 +183,12 @@ class Trainer:
n_dim=n_dim, n_dim=n_dim,
n_cont=dataset.n_cont, n_cont=dataset.n_cont,
n_cate=dataset.n_cate, n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
).to(self.device) ).to(self.device)
elif cfg.model_type == "sap_delphi": elif cfg.model_type == "sap_delphi":
self.model = SapDelphi( self.model = SapDelphi(
n_disease=dataset.n_disease, n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd, n_embd=cfg.n_embd,
n_head=cfg.n_head, n_head=cfg.n_head,
n_layer=cfg.n_layer, n_layer=cfg.n_layer,
@@ -194,6 +197,7 @@ class Trainer:
n_dim=n_dim, n_dim=n_dim,
n_cont=dataset.n_cont, n_cont=dataset.n_cont,
n_cate=dataset.n_cate, n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
pretrained_emd_path=cfg.pretrained_emd_path, pretrained_emd_path=cfg.pretrained_emd_path,
freeze_pretrained_emd=True, freeze_pretrained_emd=True,
).to(self.device) ).to(self.device)
@@ -208,6 +212,7 @@ class Trainer:
if cfg.model_type == "delphi_fork": if cfg.model_type == "delphi_fork":
self.ema_model = DelphiFork( self.ema_model = DelphiFork(
n_disease=dataset.n_disease, n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd, n_embd=cfg.n_embd,
n_head=cfg.n_head, n_head=cfg.n_head,
n_layer=cfg.n_layer, n_layer=cfg.n_layer,
@@ -216,10 +221,12 @@ class Trainer:
n_dim=n_dim, n_dim=n_dim,
n_cont=dataset.n_cont, n_cont=dataset.n_cont,
n_cate=dataset.n_cate, n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
).to(self.device) ).to(self.device)
elif cfg.model_type == "sap_delphi": elif cfg.model_type == "sap_delphi":
self.ema_model = SapDelphi( self.ema_model = SapDelphi(
n_disease=dataset.n_disease, n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=cfg.n_embd, n_embd=cfg.n_embd,
n_head=cfg.n_head, n_head=cfg.n_head,
n_layer=cfg.n_layer, n_layer=cfg.n_layer,
@@ -228,6 +235,7 @@ class Trainer:
n_dim=n_dim, n_dim=n_dim,
n_cont=dataset.n_cont, n_cont=dataset.n_cont,
n_cate=dataset.n_cate, n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
pretrained_emd_path=cfg.pretrained_emd_path, pretrained_emd_path=cfg.pretrained_emd_path,
freeze_pretrained_emd=True, freeze_pretrained_emd=True,
).to(self.device) ).to(self.device)