Enhance DataLoader configuration and improve tensor transfer efficiency in Trainer class
@@ -46,10 +46,11 @@ class SelfAttention(nn.Module):
         k = reshape_heads(k)
         v = reshape_heads(v)
 
+        dropout_p = self.attn_pdrop if self.training else 0.0
         attn = F.scaled_dot_product_attention(
             q, k, v,
             attn_mask=attn_mask,
-            dropout_p=self.attn_pdrop,
+            dropout_p=dropout_p,
         )  # (B, H, L, d)
 
         attn = attn.transpose(1, 2).contiguous().view(B, L, D)  # (B, L, D)
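Note on the dropout change: F.scaled_dot_product_attention is a pure function, so unlike nn.Dropout it never sees module.train()/module.eval(); a non-zero dropout_p keeps dropping attention weights at inference unless the caller gates it, which is what the new dropout_p local does. A minimal standalone sketch of the gate (shapes are illustrative, not from this repo):

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(2, 4, 8, 16)  # (B, H, L, d), illustrative shapes

    # Functional SDPA applies dropout whenever dropout_p > 0; it has no
    # knowledge of module.train()/module.eval(), so the caller must gate it.
    training = False
    dropout_p = 0.1 if training else 0.0
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
    print(out.shape)  # torch.Size([2, 4, 8, 16])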
dataset.py (26 changed lines)
@@ -35,6 +35,8 @@ class HealthDataset(Dataset):
         self.basic_info = basic_info.convert_dtypes()
         self.patient_ids = self.basic_info.index.tolist()
         self.patient_events = dict(patient_events)
+        for patient_id, records in self.patient_events.items():
+            records.sort(key=lambda x: x[0])
 
         tabular_data = tabular_data.convert_dtypes()
         cont_cols = []
@@ -61,17 +63,26 @@ class HealthDataset(Dataset):
         self.n_cont = self.cont_features.shape[1]
         self.n_cate = self.cate_features.shape[1]
 
+        self._doa = self.basic_info.loc[
+            self.patient_ids, 'date_of_assessment'
+        ].to_numpy(dtype=np.float32)
+        self._sex = self.basic_info.loc[
+            self.patient_ids, 'sex'
+        ].to_numpy(dtype=np.int64)
+
+        self.cont_features = torch.from_numpy(self.cont_features)
+        self.cate_features = torch.from_numpy(self.cate_features)
+
     def __len__(self) -> int:
         return len(self.patient_ids)
 
     def __getitem__(self, idx):
         patient_id = self.patient_ids[idx]
-        records = sorted(self.patient_events.get(
-            patient_id, []), key=lambda x: x[0])
+        records = self.patient_events.get(patient_id, [])
         event_seq = [item[1] for item in records]
         time_seq = [item[0] for item in records]
 
-        doa = self.basic_info.loc[patient_id, 'date_of_assessment']
+        doa = float(self._doa[idx])
 
         insert_pos = np.searchsorted(time_seq, doa)
         time_seq.insert(insert_pos, doa)
@@ -79,11 +90,10 @@ class HealthDataset(Dataset):
         event_seq.insert(insert_pos, 1)
         event_tensor = torch.tensor(event_seq, dtype=torch.long)
         time_tensor = torch.tensor(time_seq, dtype=torch.float)
-        cont_tensor = torch.tensor(
-            self.cont_features[idx, :], dtype=torch.float)
-        cate_tensor = torch.tensor(
-            self.cate_features[idx, :], dtype=torch.long)
-        sex = self.basic_info.loc[patient_id, 'sex']
+        cont_tensor = self.cont_features[idx, :].to(dtype=torch.float)
+        cate_tensor = self.cate_features[idx, :].to(dtype=torch.long)
+        sex = int(self._sex[idx])
 
         return (event_tensor, time_tensor, cont_tensor, cate_tensor, sex)
 
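Note on the dataset changes: every edit in __getitem__ follows one pattern, i.e. hoist pandas .loc lookups and torch.tensor copies into __init__, where they run once per dataset rather than once per sample, and index plain arrays in the hot path. A standalone sketch of the pattern (the class and column names are illustrative, not this repo's):

    import numpy as np
    import pandas as pd
    import torch
    from torch.utils.data import Dataset

    class PrecomputedDataset(Dataset):  # illustrative name
        def __init__(self, df: pd.DataFrame):
            self.ids = df.index.tolist()
            # One vectorized .loc + .to_numpy per column at construction time...
            self._age = df.loc[self.ids, "age"].to_numpy(dtype=np.float32)

        def __len__(self) -> int:
            return len(self.ids)

        def __getitem__(self, idx):
            # ...so the per-item path is plain positional array indexing,
            # with no label-based pandas lookup per call.
            return torch.tensor(self._age[idx])

    df = pd.DataFrame({"age": [40.0, 52.0]}, index=["a", "b"])
    print(PrecomputedDataset(df)[1])  # tensor(52.)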
model.py (25 changed lines)
@@ -263,13 +263,24 @@ def _build_time_padding_mask(
     event_seq: torch.Tensor,
     time_seq: torch.Tensor,
 ) -> torch.Tensor:
-    t_i = time_seq.unsqueeze(-1)
-    t_j = time_seq.unsqueeze(1)
-    time_mask = (t_j <= t_i)  # allow attending only to past or current
-    key_is_valid = (event_seq != 0)  # disallow padded positions
-    allowed = time_mask & key_is_valid.unsqueeze(1)
-    attn_mask = ~allowed  # True means mask for scaled_dot_product_attention
-    return attn_mask.unsqueeze(1)  # (B, 1, L, L)
+    B, L = event_seq.shape
+    device = event_seq.device
+    key_is_valid = (event_seq != 0)
+
+    cache = getattr(_build_time_padding_mask, "_cache", None)
+    if cache is None:
+        cache = {}
+        setattr(_build_time_padding_mask, "_cache", cache)
+    cache_key = (str(device), L)
+    causal = cache.get(cache_key)
+    if causal is None:
+        causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
+        cache[cache_key] = causal
+
+    causal = causal.unsqueeze(0).unsqueeze(0)  # (1,1,L,L)
+    key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2)  # (B,1,1,L)
+    attn_mask = causal | key_pad  # (B,1,L,L)
+    return attn_mask
 
 
 class DelphiFork(nn.Module):
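Note on the mask rewrite: the pairwise time comparison becomes a position-based causal mask, which relies on each patient's event sequence arriving time-sorted, as the new sort in HealthDataset.__init__ now ensures; the (device, L) cache on the function object avoids rebuilding the triu mask on every forward pass. A minimal sketch of that caching idiom (not the repo's exact function):

    import torch

    def causal_mask(L: int, device: torch.device) -> torch.Tensor:
        """Boolean (L, L) mask, True above the diagonal (positions to block)."""
        cache = getattr(causal_mask, "_cache", None)
        if cache is None:
            cache = {}
            setattr(causal_mask, "_cache", cache)
        key = (str(device), L)
        mask = cache.get(key)
        if mask is None:
            # Built once per (device, length); reused on every later call.
            mask = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
            cache[key] = mask
        return mask

    m = causal_mask(4, torch.device("cpu"))
    assert m is causal_mask(4, torch.device("cpu"))  # cache hit, same tensor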
train.py (57 changed lines)
@@ -53,6 +53,9 @@ class TrainConfig:
     grad_clip: float = 1.0
     weight_decay: float = 1e-2
     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    num_workers: int = 0
+    prefetch_factor: int = 2
+    persistent_workers: bool = False
 
 
 def parse_args() -> TrainConfig:
@@ -101,6 +104,12 @@ def parse_args() -> TrainConfig:
                         default=1.0, help="Gradient clipping value.")
     parser.add_argument("--weight_decay", type=float,
                         default=1e-2, help="Weight decay for optimizer.")
+    parser.add_argument("--num_workers", type=int, default=0,
+                        help="DataLoader workers (0 is safest on Windows).")
+    parser.add_argument("--prefetch_factor", type=int, default=2,
+                        help="DataLoader prefetch factor (only used when num_workers>0).")
+    parser.add_argument("--persistent_workers", action='store_true',
+                        help="Keep DataLoader workers alive between epochs (only if num_workers>0).")
     parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                         help="Device to use for training.")
     args = parser.parse_args()
@@ -121,6 +130,17 @@ class Trainer:
         self.device = cfg.device
         self.global_step = 0
 
+        use_cuda = str(self.device).startswith(
+            "cuda") and torch.cuda.is_available()
+        if use_cuda:
+            torch.backends.cudnn.benchmark = True
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+            try:
+                torch.set_float32_matmul_precision("high")
+            except Exception:
+                pass
+
         if cfg.full_cov:
             cov_list = None
         else:
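Note on the CUDA toggles: TF32 matmuls and cudnn autotuning trade a small amount of float32 precision for throughput on Ampere-or-newer GPUs, and cudnn.benchmark pays off when input shapes are stable across steps; the try/except exists because torch.set_float32_matmul_precision is absent from older PyTorch releases. The same guard as a standalone helper (the function name is illustrative):

    import torch

    def enable_cuda_fast_math() -> None:  # illustrative helper name
        """Opt into TF32 matmuls and cudnn autotuning on CUDA machines."""
        if not torch.cuda.is_available():
            return
        torch.backends.cudnn.benchmark = True         # autotune conv algorithms
        torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for matmuls
        torch.backends.cudnn.allow_tf32 = True        # TF32 inside cuDNN
        try:
            torch.set_float32_matmul_precision("high")
        except AttributeError:  # older PyTorch without this API
            pass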
@@ -145,17 +165,27 @@ class Trainer:
             ],
             generator=torch.Generator().manual_seed(cfg.random_seed),
         )
+        pin_memory = use_cuda
+        loader_kwargs = dict(
+            collate_fn=health_collate_fn,
+            pin_memory=pin_memory,
+        )
+        if cfg.num_workers > 0:
+            loader_kwargs["num_workers"] = cfg.num_workers
+            loader_kwargs["prefetch_factor"] = cfg.prefetch_factor
+            loader_kwargs["persistent_workers"] = cfg.persistent_workers
+
         self.train_loader = DataLoader(
             self.train_data,
             batch_size=cfg.batch_size,
             shuffle=True,
-            collate_fn=health_collate_fn,
+            **loader_kwargs,
         )
         self.val_loader = DataLoader(
             self.val_data,
             batch_size=cfg.batch_size,
             shuffle=False,
-            collate_fn=health_collate_fn,
+            **loader_kwargs,
         )
 
         if cfg.loss_type == "exponential":
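Note on loader_kwargs: the conditional build is load-bearing, since DataLoader raises a ValueError if prefetch_factor or persistent_workers=True is passed with num_workers == 0. A standalone sketch (the toy dataset is illustrative):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    data = TensorDataset(torch.arange(8))  # toy dataset
    num_workers = 0  # e.g. read from a config

    loader_kwargs = dict(pin_memory=torch.cuda.is_available())
    if num_workers > 0:
        # Only legal with worker processes; DataLoader raises otherwise.
        loader_kwargs.update(
            num_workers=num_workers,
            prefetch_factor=2,
            persistent_workers=True,
        )

    loader = DataLoader(data, batch_size=4, shuffle=True, **loader_kwargs)
    print(next(iter(loader)))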
@@ -274,19 +304,18 @@ class Trainer:
                 cate_feats,
                 sexes,
             ) = batch
-            event_seq = event_seq.to(self.device)
-            time_seq = time_seq.to(self.device)
-            cont_feats = cont_feats.to(self.device)
-            cate_feats = cate_feats.to(self.device)
-            sexes = sexes.to(self.device)
+            event_seq = event_seq.to(self.device, non_blocking=True)
+            time_seq = time_seq.to(self.device, non_blocking=True)
+            cont_feats = cont_feats.to(self.device, non_blocking=True)
+            cate_feats = cate_feats.to(self.device, non_blocking=True)
+            sexes = sexes.to(self.device, non_blocking=True)
             res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
             if res is None:
                 continue
             dt, b_prev, t_prev, b_next, t_next = res
             self.optimizer.zero_grad()
             lr = self.compute_lr(self.global_step)
-            for param_group in self.optimizer.param_groups:
-                param_group['lr'] = lr
+            self.optimizer.param_groups[0]['lr'] = lr
             logits = self.model(
                 event_seq,
                 time_seq,
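Note on non_blocking=True: it only yields an asynchronous host-to-device copy when the source tensor lives in pinned (page-locked) memory, which is why the loaders above set pin_memory when running on CUDA; on unpinned tensors the flag is a harmless no-op. A minimal sketch of the pairing:

    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
        batch = torch.randn(32, 128).pin_memory()  # what pin_memory=True yields
        # Copy is queued without blocking the host, so Python can keep going.
        gpu_batch = batch.to(device, non_blocking=True)
        print(gpu_batch.device)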
@@ -341,11 +370,11 @@ class Trainer:
                 cate_feats,
                 sexes,
             ) = batch
-            event_seq = event_seq.to(self.device)
-            time_seq = time_seq.to(self.device)
-            cont_feats = cont_feats.to(self.device)
-            cate_feats = cate_feats.to(self.device)
-            sexes = sexes.to(self.device)
+            event_seq = event_seq.to(self.device, non_blocking=True)
+            time_seq = time_seq.to(self.device, non_blocking=True)
+            cont_feats = cont_feats.to(self.device, non_blocking=True)
+            cate_feats = cate_feats.to(self.device, non_blocking=True)
+            sexes = sexes.to(self.device, non_blocking=True)
             res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
             if res is None:
                 continue