Update age_encoder parameter choices in TrainConfig and argument parser for clarity
This commit is contained in:
4
train.py
4
train.py
@@ -25,7 +25,7 @@ class TrainConfig:
|
||||
# Model Parameters
|
||||
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
|
||||
loss_type: Literal['exponential', 'weibull'] = 'weibull'
|
||||
age_encoder: Literal['sinusoidal', 'learned'] = 'learned'
|
||||
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
|
||||
full_cov: bool = False
|
||||
n_embd: int = 120
|
||||
n_head: int = 12
|
||||
@@ -60,7 +60,7 @@ def parse_args() -> TrainConfig:
|
||||
parser.add_argument("--loss_type", type=str, choices=[
|
||||
'exponential', 'weibull'], default='weibull', help="Type of loss function to use.")
|
||||
parser.add_argument("--age_encoder", type=str, choices=[
|
||||
'sinusoidal', 'learned'], default='learned', help="Type of age encoder to use.")
|
||||
'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
|
||||
parser.add_argument("--n_embd", type=int, default=120,
|
||||
help="Embedding dimension.")
|
||||
parser.add_argument("--n_head", type=int, default=12,
|
||||
|
||||
Reference in New Issue
Block a user