Add n_tech_tokens parameter to DelphiFork and SapDelphi model initializations
This commit is contained in:
8
train.py
8
train.py
@@ -174,6 +174,7 @@ class Trainer:
|
||||
if cfg.model_type == "delphi_fork":
|
||||
self.model = DelphiFork(
|
||||
n_disease=dataset.n_disease,
|
||||
n_tech_tokens=2,
|
||||
n_embd=cfg.n_embd,
|
||||
n_head=cfg.n_head,
|
||||
n_layer=cfg.n_layer,
|
||||
@@ -182,10 +183,12 @@ class Trainer:
|
||||
n_dim=n_dim,
|
||||
n_cont=dataset.n_cont,
|
||||
n_cate=dataset.n_cate,
|
||||
cate_dims=dataset.cate_dims,
|
||||
).to(self.device)
|
||||
elif cfg.model_type == "sap_delphi":
|
||||
self.model = SapDelphi(
|
||||
n_disease=dataset.n_disease,
|
||||
n_tech_tokens=2,
|
||||
n_embd=cfg.n_embd,
|
||||
n_head=cfg.n_head,
|
||||
n_layer=cfg.n_layer,
|
||||
@@ -194,6 +197,7 @@ class Trainer:
|
||||
n_dim=n_dim,
|
||||
n_cont=dataset.n_cont,
|
||||
n_cate=dataset.n_cate,
|
||||
cate_dims=dataset.cate_dims,
|
||||
pretrained_emd_path=cfg.pretrained_emd_path,
|
||||
freeze_pretrained_emd=True,
|
||||
).to(self.device)
|
||||
@@ -208,6 +212,7 @@ class Trainer:
|
||||
if cfg.model_type == "delphi_fork":
|
||||
self.ema_model = DelphiFork(
|
||||
n_disease=dataset.n_disease,
|
||||
n_tech_tokens=2,
|
||||
n_embd=cfg.n_embd,
|
||||
n_head=cfg.n_head,
|
||||
n_layer=cfg.n_layer,
|
||||
@@ -216,10 +221,12 @@ class Trainer:
|
||||
n_dim=n_dim,
|
||||
n_cont=dataset.n_cont,
|
||||
n_cate=dataset.n_cate,
|
||||
cate_dims=dataset.cate_dims,
|
||||
).to(self.device)
|
||||
elif cfg.model_type == "sap_delphi":
|
||||
self.ema_model = SapDelphi(
|
||||
n_disease=dataset.n_disease,
|
||||
n_tech_tokens=2,
|
||||
n_embd=cfg.n_embd,
|
||||
n_head=cfg.n_head,
|
||||
n_layer=cfg.n_layer,
|
||||
@@ -228,6 +235,7 @@ class Trainer:
|
||||
n_dim=n_dim,
|
||||
n_cont=dataset.n_cont,
|
||||
n_cate=dataset.n_cate,
|
||||
cate_dims=dataset.cate_dims,
|
||||
pretrained_emd_path=cfg.pretrained_emd_path,
|
||||
freeze_pretrained_emd=True,
|
||||
).to(self.device)
|
||||
|
||||
Reference in New Issue
Block a user