Add n_tech_tokens parameter to DelphiFork and SapDelphi model initializations
This commit is contained in:
8
train.py
8
train.py
@@ -174,6 +174,7 @@ class Trainer:
|
|||||||
if cfg.model_type == "delphi_fork":
|
if cfg.model_type == "delphi_fork":
|
||||||
self.model = DelphiFork(
|
self.model = DelphiFork(
|
||||||
n_disease=dataset.n_disease,
|
n_disease=dataset.n_disease,
|
||||||
|
n_tech_tokens=2,
|
||||||
n_embd=cfg.n_embd,
|
n_embd=cfg.n_embd,
|
||||||
n_head=cfg.n_head,
|
n_head=cfg.n_head,
|
||||||
n_layer=cfg.n_layer,
|
n_layer=cfg.n_layer,
|
||||||
@@ -182,10 +183,12 @@ class Trainer:
|
|||||||
n_dim=n_dim,
|
n_dim=n_dim,
|
||||||
n_cont=dataset.n_cont,
|
n_cont=dataset.n_cont,
|
||||||
n_cate=dataset.n_cate,
|
n_cate=dataset.n_cate,
|
||||||
|
cate_dims=dataset.cate_dims,
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
elif cfg.model_type == "sap_delphi":
|
elif cfg.model_type == "sap_delphi":
|
||||||
self.model = SapDelphi(
|
self.model = SapDelphi(
|
||||||
n_disease=dataset.n_disease,
|
n_disease=dataset.n_disease,
|
||||||
|
n_tech_tokens=2,
|
||||||
n_embd=cfg.n_embd,
|
n_embd=cfg.n_embd,
|
||||||
n_head=cfg.n_head,
|
n_head=cfg.n_head,
|
||||||
n_layer=cfg.n_layer,
|
n_layer=cfg.n_layer,
|
||||||
@@ -194,6 +197,7 @@ class Trainer:
|
|||||||
n_dim=n_dim,
|
n_dim=n_dim,
|
||||||
n_cont=dataset.n_cont,
|
n_cont=dataset.n_cont,
|
||||||
n_cate=dataset.n_cate,
|
n_cate=dataset.n_cate,
|
||||||
|
cate_dims=dataset.cate_dims,
|
||||||
pretrained_emd_path=cfg.pretrained_emd_path,
|
pretrained_emd_path=cfg.pretrained_emd_path,
|
||||||
freeze_pretrained_emd=True,
|
freeze_pretrained_emd=True,
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
@@ -208,6 +212,7 @@ class Trainer:
|
|||||||
if cfg.model_type == "delphi_fork":
|
if cfg.model_type == "delphi_fork":
|
||||||
self.ema_model = DelphiFork(
|
self.ema_model = DelphiFork(
|
||||||
n_disease=dataset.n_disease,
|
n_disease=dataset.n_disease,
|
||||||
|
n_tech_tokens=2,
|
||||||
n_embd=cfg.n_embd,
|
n_embd=cfg.n_embd,
|
||||||
n_head=cfg.n_head,
|
n_head=cfg.n_head,
|
||||||
n_layer=cfg.n_layer,
|
n_layer=cfg.n_layer,
|
||||||
@@ -216,10 +221,12 @@ class Trainer:
|
|||||||
n_dim=n_dim,
|
n_dim=n_dim,
|
||||||
n_cont=dataset.n_cont,
|
n_cont=dataset.n_cont,
|
||||||
n_cate=dataset.n_cate,
|
n_cate=dataset.n_cate,
|
||||||
|
cate_dims=dataset.cate_dims,
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
elif cfg.model_type == "sap_delphi":
|
elif cfg.model_type == "sap_delphi":
|
||||||
self.ema_model = SapDelphi(
|
self.ema_model = SapDelphi(
|
||||||
n_disease=dataset.n_disease,
|
n_disease=dataset.n_disease,
|
||||||
|
n_tech_tokens=2,
|
||||||
n_embd=cfg.n_embd,
|
n_embd=cfg.n_embd,
|
||||||
n_head=cfg.n_head,
|
n_head=cfg.n_head,
|
||||||
n_layer=cfg.n_layer,
|
n_layer=cfg.n_layer,
|
||||||
@@ -228,6 +235,7 @@ class Trainer:
|
|||||||
n_dim=n_dim,
|
n_dim=n_dim,
|
||||||
n_cont=dataset.n_cont,
|
n_cont=dataset.n_cont,
|
||||||
n_cate=dataset.n_cate,
|
n_cate=dataset.n_cate,
|
||||||
|
cate_dims=dataset.cate_dims,
|
||||||
pretrained_emd_path=cfg.pretrained_emd_path,
|
pretrained_emd_path=cfg.pretrained_emd_path,
|
||||||
freeze_pretrained_emd=True,
|
freeze_pretrained_emd=True,
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user