Update age_encoder parameter choices in TrainConfig and argument parser for clarity

This commit is contained in:
2026-01-08 11:38:45 +08:00
parent 9eda00ea48
commit 7c36f7a007

View File

@@ -25,7 +25,7 @@ class TrainConfig:
# Model Parameters # Model Parameters
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork' model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
loss_type: Literal['exponential', 'weibull'] = 'weibull' loss_type: Literal['exponential', 'weibull'] = 'weibull'
age_encoder: Literal['sinusoidal', 'learned'] = 'learned' age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
full_cov: bool = False full_cov: bool = False
n_embd: int = 120 n_embd: int = 120
n_head: int = 12 n_head: int = 12
@@ -60,7 +60,7 @@ def parse_args() -> TrainConfig:
parser.add_argument("--loss_type", type=str, choices=[ parser.add_argument("--loss_type", type=str, choices=[
'exponential', 'weibull'], default='weibull', help="Type of loss function to use.") 'exponential', 'weibull'], default='weibull', help="Type of loss function to use.")
parser.add_argument("--age_encoder", type=str, choices=[ parser.add_argument("--age_encoder", type=str, choices=[
'sinusoidal', 'learned'], default='learned', help="Type of age encoder to use.") 'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
parser.add_argument("--n_embd", type=int, default=120, parser.add_argument("--n_embd", type=int, default=120,
help="Embedding dimension.") help="Embedding dimension.")
parser.add_argument("--n_head", type=int, default=12, parser.add_argument("--n_head", type=int, default=12,