Update Trainer class to rename parameters for clarity in embedding configuration
This commit is contained in:
4
train.py
4
train.py
@@ -234,8 +234,8 @@ class Trainer:
|
||||
n_cont=dataset.n_cont,
|
||||
n_cate=dataset.n_cate,
|
||||
cate_dims=dataset.cate_dims,
|
||||
pretrained_emd_path=cfg.pretrained_emd_path,
|
||||
freeze_pretrained_emd=True,
|
||||
pretrained_weights_path=cfg.pretrained_emd_path,
|
||||
freeze_embeddings=True,
|
||||
).to(self.device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {cfg.model_type}")
|
||||
|
||||
Reference in New Issue
Block a user