Add rank parameter to TrainConfig and update argument parsing for low-rank parameterization
This commit is contained in:
5
train.py
5
train.py
@@ -34,6 +34,7 @@ class TrainConfig:
|
||||
bin_edges: Sequence[float] = field(
|
||||
default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
|
||||
)
|
||||
rank: int = 16
|
||||
# SapDelphi specific
|
||||
pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
|
||||
# Data Parameters
|
||||
@@ -74,6 +75,8 @@ def parse_args() -> TrainConfig:
|
||||
help="Dropout probability.")
|
||||
parser.add_argument("--lambda_reg", type=float,
|
||||
default=1e-4, help="Regularization weight.")
|
||||
parser.add_argument("--rank", type=int, default=16,
|
||||
help="Rank for low-rank parameterization (if applicable).")
|
||||
parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
|
||||
help="Path to pretrained embeddings for SapDelphi.")
|
||||
parser.add_argument("--data_prefix", type=str,
|
||||
@@ -215,6 +218,7 @@ class Trainer:
|
||||
pdrop=cfg.pdrop,
|
||||
age_encoder_type=cfg.age_encoder,
|
||||
n_dim=n_dim,
|
||||
rank=cfg.rank,
|
||||
n_cont=dataset.n_cont,
|
||||
n_cate=dataset.n_cate,
|
||||
cate_dims=dataset.cate_dims,
|
||||
@@ -229,6 +233,7 @@ class Trainer:
|
||||
pdrop=cfg.pdrop,
|
||||
age_encoder_type=cfg.age_encoder,
|
||||
n_dim=n_dim,
|
||||
rank=cfg.rank,
|
||||
n_cont=dataset.n_cont,
|
||||
n_cate=dataset.n_cate,
|
||||
cate_dims=dataset.cate_dims,
|
||||
|
||||
Reference in New Issue
Block a user