Compare commits

...

3 Commits

SHA1        Message                                                              Date
7e8d8d307b  chore: Ignore small data files                                       2025-10-17 10:34:24 +08:00
fc0aef4e71  chore: Add .gitignore                                                2025-10-17 10:32:42 +08:00
02d84a7eca  refactor: Use AdamW optimizer and increase early stopping patience   2025-10-17 10:31:12 +08:00
3 changed files with 24 additions and 3 deletions

.gitignore (new file, 17 additions)

@@ -0,0 +1,17 @@
+# IDE settings
+.idea/
+
+# Python cache
+__pycache__/
+
+# Model checkpoints
+best_model_checkpoint.pt
+
+# Large data files
+ukb_delphi.txt
+ukb_real.bin
+
+# Small data files
+fields.txt
+icd10_codes_mod.tsv
+labels.csv
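
A quick way to confirm any of these patterns behaves as intended is `git check-ignore -v <path>`, which prints the rule that excludes a path; for example, `git check-ignore -v ukb_real.bin` should report the ukb_real.bin line of this file.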

requirements.txt (new file, 4 additions)

@@ -0,0 +1,4 @@
+torch
+numpy
+tqdm
+matplotlib
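
None of these dependencies are version-pinned, so `pip install -r requirements.txt` will pull whatever versions are current; pinning (e.g. `torch==2.4.0`) would make the environment reproducible.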

Training script (3 additions, 3 deletions)

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from torch.optim import Adam
+from torch.optim import AdamW
 from torch.utils.data import DataLoader
 import numpy as np
 import math
@@ -30,7 +30,7 @@ class TrainConfig:
     lr_initial = 6e-4
     lr_final = 6e-5
     warmup_epochs = 10
-    early_stopping_patience = 5
+    early_stopping_patience = 10

     # Loss parameters
     # 0 = padding, 1 = "no event"
@@ -72,7 +72,7 @@ def main():
     print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

     loss_fn = CombinedLoss(config.ignored_token_ids)
-    optimizer = Adam(model.parameters(), lr=config.lr_initial)
+    optimizer = AdamW(model.parameters(), lr=config.lr_initial)

     # --- 3. Training Loop ---
     best_val_loss = float('inf')
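
The substantive change in this comparison is the last commit: Adam is swapped for AdamW, which decouples weight decay from the gradient update, and the early-stopping patience is doubled from 5 to 10 epochs. One side effect worth noting: PyTorch's AdamW defaults to weight_decay=1e-2 while Adam defaults to 0, so switching optimizers without passing a weight_decay argument enables mild regularization. The sketch below shows how these two settings typically fit together in a training loop; it is illustrative only, with a stand-in model and validation loss, since the repo's actual loop is not part of this diff.

# Sketch of the two settings changed in this commit. The model, data, and
# validation loss here are stand-ins; the repo's real training loop is not
# shown in this diff.
import torch
import torch.nn as nn
from torch.optim import AdamW

model = nn.Linear(16, 1)  # stand-in for the real model

# AdamW applies weight decay directly to the weights instead of adding an
# L2 term to the gradient as Adam does. Its default weight_decay is 1e-2,
# whereas Adam's is 0, so the swap turns on mild regularization even
# though the diff passes no weight_decay argument.
optimizer = AdamW(model.parameters(), lr=6e-4)  # lr_initial from TrainConfig

early_stopping_patience = 10  # raised from 5 in this commit
best_val_loss = float('inf')
epochs_since_improvement = 0

for epoch in range(100):
    # ... one epoch of training with `optimizer` would run here ...
    val_loss = torch.rand(()).item()  # stand-in for the real validation pass
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_since_improvement = 0
        # Checkpoint name taken from the .gitignore above.
        torch.save(model.state_dict(), "best_model_checkpoint.pt")
    else:
        epochs_since_improvement += 1
        if epochs_since_improvement >= early_stopping_patience:
            print(f"No improvement for {early_stopping_patience} epochs; stopping at epoch {epoch}.")
            break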