Add rank parameter to TrainConfig and update argument parsing for low-rank parameterization

This commit is contained in:
2026-01-09 13:18:09 +08:00
parent 1fa6d55d79
commit dc34d51864

View File

@@ -34,6 +34,7 @@ class TrainConfig:
bin_edges: Sequence[float] = field(
default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
)
rank: int = 16
# SapDelphi specific
pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
# Data Parameters
@@ -74,6 +75,8 @@ def parse_args() -> TrainConfig:
help="Dropout probability.")
parser.add_argument("--lambda_reg", type=float,
default=1e-4, help="Regularization weight.")
parser.add_argument("--rank", type=int, default=16,
help="Rank for low-rank parameterization (if applicable).")
parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
help="Path to pretrained embeddings for SapDelphi.")
parser.add_argument("--data_prefix", type=str,
@@ -215,6 +218,7 @@ class Trainer:
pdrop=cfg.pdrop,
age_encoder_type=cfg.age_encoder,
n_dim=n_dim,
rank=cfg.rank,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
@@ -229,6 +233,7 @@ class Trainer:
pdrop=cfg.pdrop,
age_encoder_type=cfg.age_encoder,
n_dim=n_dim,
rank=cfg.rank,
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,