diff --git a/train.py b/train.py index 3585741..eff1727 100644 --- a/train.py +++ b/train.py @@ -234,8 +234,8 @@ class Trainer: n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, - pretrained_emd_path=cfg.pretrained_emd_path, - freeze_pretrained_emd=True, + pretrained_weights_path=cfg.pretrained_emd_path, + freeze_embeddings=True, ).to(self.device) else: raise ValueError(f"Unsupported model type: {cfg.model_type}")