From 7c36f7a0075b4ad46b501f3c2809990453fbf739 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Thu, 8 Jan 2026 11:38:45 +0800 Subject: [PATCH] Update age_encoder parameter choices in TrainConfig and argument parser for clarity --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 9d10152..e3d1343 100644 --- a/train.py +++ b/train.py @@ -25,7 +25,7 @@ class TrainConfig: # Model Parameters model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork' loss_type: Literal['exponential', 'weibull'] = 'weibull' - age_encoder: Literal['sinusoidal', 'learned'] = 'learned' + age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal' full_cov: bool = False n_embd: int = 120 n_head: int = 12 @@ -60,7 +60,7 @@ def parse_args() -> TrainConfig: parser.add_argument("--loss_type", type=str, choices=[ 'exponential', 'weibull'], default='weibull', help="Type of loss function to use.") parser.add_argument("--age_encoder", type=str, choices=[ - 'sinusoidal', 'learned'], default='learned', help="Type of age encoder to use.") + 'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.") parser.add_argument("--n_embd", type=int, default=120, help="Embedding dimension.") parser.add_argument("--n_head", type=int, default=12,