Compare commits
3 Commits
cb7575a229
...
7e8d8d307b
Author | SHA1 | Date | |
---|---|---|---|
7e8d8d307b | |||
fc0aef4e71 | |||
02d84a7eca |
17
.gitignore
vendored
Normal file
17
.gitignore
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# IDE settings
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Python cache
|
||||||
|
__pycache__/
|
||||||
|
|
||||||
|
# Model checkpoints
|
||||||
|
best_model_checkpoint.pt
|
||||||
|
|
||||||
|
# Large data files
|
||||||
|
ukb_delphi.txt
|
||||||
|
ukb_real.bin
|
||||||
|
|
||||||
|
# Small data files
|
||||||
|
fields.txt
|
||||||
|
icd10_codes_mod.tsv
|
||||||
|
labels.csv
|
4
requirements.txt
Normal file
4
requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
torch
|
||||||
|
numpy
|
||||||
|
tqdm
|
||||||
|
matplotlib
|
6
train.py
6
train.py
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Adam
|
from torch.optim import AdamW
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
@@ -30,7 +30,7 @@ class TrainConfig:
|
|||||||
lr_initial = 6e-4
|
lr_initial = 6e-4
|
||||||
lr_final = 6e-5
|
lr_final = 6e-5
|
||||||
warmup_epochs = 10
|
warmup_epochs = 10
|
||||||
early_stopping_patience = 5
|
early_stopping_patience = 10
|
||||||
|
|
||||||
# Loss parameters
|
# Loss parameters
|
||||||
# 0 = padding, 1 = "no event"
|
# 0 = padding, 1 = "no event"
|
||||||
@@ -72,7 +72,7 @@ def main():
|
|||||||
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
|
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
|
||||||
|
|
||||||
loss_fn = CombinedLoss(config.ignored_token_ids)
|
loss_fn = CombinedLoss(config.ignored_token_ids)
|
||||||
optimizer = Adam(model.parameters(), lr=config.lr_initial)
|
optimizer = AdamW(model.parameters(), lr=config.lr_initial)
|
||||||
|
|
||||||
# --- 3. Training Loop ---
|
# --- 3. Training Loop ---
|
||||||
best_val_loss = float('inf')
|
best_val_loss = float('inf')
|
||||||
|
Reference in New Issue
Block a user