Update Trainer class to rename parameters for clarity in embedding configuration
This commit is contained in:
4
train.py
4
train.py
@@ -234,8 +234,8 @@ class Trainer:
|
|||||||
n_cont=dataset.n_cont,
|
n_cont=dataset.n_cont,
|
||||||
n_cate=dataset.n_cate,
|
n_cate=dataset.n_cate,
|
||||||
cate_dims=dataset.cate_dims,
|
cate_dims=dataset.cate_dims,
|
||||||
pretrained_emd_path=cfg.pretrained_emd_path,
|
pretrained_weights_path=cfg.pretrained_emd_path,
|
||||||
freeze_pretrained_emd=True,
|
freeze_embeddings=True,
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported model type: {cfg.model_type}")
|
raise ValueError(f"Unsupported model type: {cfg.model_type}")
|
||||||
|
|||||||
Reference in New Issue
Block a user