From dc34d51864a872fbda02cb6d9598a987a5f1acfc Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Fri, 9 Jan 2026 13:18:09 +0800 Subject: [PATCH] Add rank parameter to TrainConfig and update argument parsing for low-rank parameterization --- train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/train.py b/train.py index eef93fd..5d37deb 100644 --- a/train.py +++ b/train.py @@ -34,6 +34,7 @@ class TrainConfig: bin_edges: Sequence[float] = field( default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0] ) + rank: int = 16 # SapDelphi specific pretrained_emd_path: str = "icd10_sapbert_embeddings.npy" # Data Parameters @@ -74,6 +75,8 @@ def parse_args() -> TrainConfig: help="Dropout probability.") parser.add_argument("--lambda_reg", type=float, default=1e-4, help="Regularization weight.") + parser.add_argument("--rank", type=int, default=16, + help="Rank for low-rank parameterization (if applicable).") parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy", help="Path to pretrained embeddings for SapDelphi.") parser.add_argument("--data_prefix", type=str, @@ -215,6 +218,7 @@ class Trainer: pdrop=cfg.pdrop, age_encoder_type=cfg.age_encoder, n_dim=n_dim, + rank=cfg.rank, n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, @@ -229,6 +233,7 @@ class Trainer: pdrop=cfg.pdrop, age_encoder_type=cfg.age_encoder, n_dim=n_dim, + rank=cfg.rank, n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims,