Update age_encoder parameter choices in TrainConfig and argument parser for clarity
This commit is contained in:
4
train.py
4
train.py
@@ -25,7 +25,7 @@ class TrainConfig:
|
|||||||
# Model Parameters
|
# Model Parameters
|
||||||
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
|
model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
|
||||||
loss_type: Literal['exponential', 'weibull'] = 'weibull'
|
loss_type: Literal['exponential', 'weibull'] = 'weibull'
|
||||||
age_encoder: Literal['sinusoidal', 'learned'] = 'learned'
|
age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
|
||||||
full_cov: bool = False
|
full_cov: bool = False
|
||||||
n_embd: int = 120
|
n_embd: int = 120
|
||||||
n_head: int = 12
|
n_head: int = 12
|
||||||
@@ -60,7 +60,7 @@ def parse_args() -> TrainConfig:
|
|||||||
parser.add_argument("--loss_type", type=str, choices=[
|
parser.add_argument("--loss_type", type=str, choices=[
|
||||||
'exponential', 'weibull'], default='weibull', help="Type of loss function to use.")
|
'exponential', 'weibull'], default='weibull', help="Type of loss function to use.")
|
||||||
parser.add_argument("--age_encoder", type=str, choices=[
|
parser.add_argument("--age_encoder", type=str, choices=[
|
||||||
'sinusoidal', 'learned'], default='learned', help="Type of age encoder to use.")
|
'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
|
||||||
parser.add_argument("--n_embd", type=int, default=120,
|
parser.add_argument("--n_embd", type=int, default=120,
|
||||||
help="Embedding dimension.")
|
help="Embedding dimension.")
|
||||||
parser.add_argument("--n_head", type=int, default=12,
|
parser.add_argument("--n_head", type=int, default=12,
|
||||||
|
|||||||
Reference in New Issue
Block a user