From 01a96d37eafe4b041c4f7d081185cd4a13cd8f17 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Thu, 8 Jan 2026 13:20:32 +0800
Subject: [PATCH] Enhance DataLoader configuration and improve tensor transfer efficiency in Trainer class

---
 backbones.py |  3 ++-
 dataset.py   | 26 ++++++++++++++++--------
 model.py     | 25 ++++++++++++++++-------
 train.py     | 57 +++++++++++++++++++++++++++++++++++++++-------------
 4 files changed, 81 insertions(+), 30 deletions(-)

diff --git a/backbones.py b/backbones.py
index 4330fe1..d0efd9e 100644
--- a/backbones.py
+++ b/backbones.py
@@ -46,10 +46,11 @@ class SelfAttention(nn.Module):
         k = reshape_heads(k)
         v = reshape_heads(v)
 
+        dropout_p = self.attn_pdrop if self.training else 0.0
         attn = F.scaled_dot_product_attention(
             q, k, v,
             attn_mask=attn_mask,
-            dropout_p=self.attn_pdrop,
+            dropout_p=dropout_p,
         )  # (B, H, L, d)
         attn = attn.transpose(1, 2).contiguous().view(B, L, D)  # (B, L, D)
 
diff --git a/dataset.py b/dataset.py
index 04929ca..34528f2 100644
--- a/dataset.py
+++ b/dataset.py
@@ -35,6 +35,8 @@ class HealthDataset(Dataset):
         self.basic_info = basic_info.convert_dtypes()
         self.patient_ids = self.basic_info.index.tolist()
         self.patient_events = dict(patient_events)
+        for patient_id, records in self.patient_events.items():
+            records.sort(key=lambda x: x[0])
 
         tabular_data = tabular_data.convert_dtypes()
         cont_cols = []
@@ -61,17 +63,26 @@
         self.n_cont = self.cont_features.shape[1]
         self.n_cate = self.cate_features.shape[1]
 
+        self._doa = self.basic_info.loc[
+            self.patient_ids, 'date_of_assessment'
+        ].to_numpy(dtype=np.float32)
+        self._sex = self.basic_info.loc[
+            self.patient_ids, 'sex'
+        ].to_numpy(dtype=np.int64)
+
+        self.cont_features = torch.from_numpy(self.cont_features)
+        self.cate_features = torch.from_numpy(self.cate_features)
+
     def __len__(self) -> int:
         return len(self.patient_ids)
 
     def __getitem__(self, idx):
         patient_id = self.patient_ids[idx]
-        records = sorted(self.patient_events.get(
-            patient_id, []), key=lambda x: x[0])
+        records = self.patient_events.get(patient_id, [])
        event_seq = [item[1] for item in records]
         time_seq = [item[0] for item in records]
-        doa = self.basic_info.loc[patient_id, 'date_of_assessment']
+        doa = float(self._doa[idx])
 
         insert_pos = np.searchsorted(time_seq, doa)
         time_seq.insert(insert_pos, doa)
@@ -79,11 +90,10 @@
         event_seq.insert(insert_pos, 1)
         event_tensor = torch.tensor(event_seq, dtype=torch.long)
         time_tensor = torch.tensor(time_seq, dtype=torch.float)
-        cont_tensor = torch.tensor(
-            self.cont_features[idx, :], dtype=torch.float)
-        cate_tensor = torch.tensor(
-            self.cate_features[idx, :], dtype=torch.long)
-        sex = self.basic_info.loc[patient_id, 'sex']
+
+        cont_tensor = self.cont_features[idx, :].to(dtype=torch.float)
+        cate_tensor = self.cate_features[idx, :].to(dtype=torch.long)
+        sex = int(self._sex[idx])
 
         return (event_tensor, time_tensor, cont_tensor,
                 cate_tensor, sex)
diff --git a/model.py b/model.py
index ed21378..d9e4523 100644
--- a/model.py
+++ b/model.py
@@ -263,13 +263,24 @@ def _build_time_padding_mask(
     event_seq: torch.Tensor,
     time_seq: torch.Tensor,
 ) -> torch.Tensor:
-    t_i = time_seq.unsqueeze(-1)
-    t_j = time_seq.unsqueeze(1)
-    time_mask = (t_j <= t_i)  # allow attending only to past or current
-    key_is_valid = (event_seq != 0)  # disallow padded positions
-    allowed = time_mask & key_is_valid.unsqueeze(1)
-    attn_mask = ~allowed  # True means mask for scaled_dot_product_attention
-    return attn_mask.unsqueeze(1)  # (B, 1, L, L)
+    B, L = event_seq.shape
+    device = event_seq.device
+    key_is_valid = (event_seq != 0)
+
+    cache = getattr(_build_time_padding_mask, "_cache", None)
+    if cache is None:
+        cache = {}
+        setattr(_build_time_padding_mask, "_cache", cache)
+    cache_key = (str(device), L)
+    causal = cache.get(cache_key)
+    if causal is None:
+        causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
+        cache[cache_key] = causal
+
+    causal = causal.unsqueeze(0).unsqueeze(0)  # (1,1,L,L)
+    key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2)  # (B,1,1,L)
+    attn_mask = causal | key_pad  # (B,1,L,L)
+    return attn_mask
 
 
 class DelphiFork(nn.Module):
diff --git a/train.py b/train.py
index 78ce0c0..3585741 100644
--- a/train.py
+++ b/train.py
@@ -53,6 +53,9 @@
     grad_clip: float = 1.0
     weight_decay: float = 1e-2
     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    num_workers: int = 0
+    prefetch_factor: int = 2
+    persistent_workers: bool = False
 
 
 def parse_args() -> TrainConfig:
@@ -101,6 +104,12 @@
                         default=1.0, help="Gradient clipping value.")
     parser.add_argument("--weight_decay", type=float, default=1e-2,
                         help="Weight decay for optimizer.")
+    parser.add_argument("--num_workers", type=int, default=0,
+                        help="DataLoader workers (0 is safest on Windows).")
+    parser.add_argument("--prefetch_factor", type=int, default=2,
+                        help="DataLoader prefetch factor (only used when num_workers>0).")
+    parser.add_argument("--persistent_workers", action='store_true',
+                        help="Keep DataLoader workers alive between epochs (only if num_workers>0).")
     parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                         help="Device to use for training.")
     args = parser.parse_args()
@@ -121,6 +130,17 @@ class Trainer:
         self.device = cfg.device
         self.global_step = 0
 
+        use_cuda = str(self.device).startswith(
+            "cuda") and torch.cuda.is_available()
+        if use_cuda:
+            torch.backends.cudnn.benchmark = True
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+            try:
+                torch.set_float32_matmul_precision("high")
+            except Exception:
+                pass
+
         if cfg.full_cov:
             cov_list = None
         else:
@@ -145,17 +165,27 @@
             ],
             generator=torch.Generator().manual_seed(cfg.random_seed),
         )
+        pin_memory = use_cuda
+        loader_kwargs = dict(
+            collate_fn=health_collate_fn,
+            pin_memory=pin_memory,
+        )
+        if cfg.num_workers > 0:
+            loader_kwargs["num_workers"] = cfg.num_workers
+            loader_kwargs["prefetch_factor"] = cfg.prefetch_factor
+            loader_kwargs["persistent_workers"] = cfg.persistent_workers
+
         self.train_loader = DataLoader(
             self.train_data,
             batch_size=cfg.batch_size,
             shuffle=True,
-            collate_fn=health_collate_fn,
+            **loader_kwargs,
         )
         self.val_loader = DataLoader(
             self.val_data,
             batch_size=cfg.batch_size,
             shuffle=False,
-            collate_fn=health_collate_fn,
+            **loader_kwargs,
         )
 
         if cfg.loss_type == "exponential":
@@ -274,19 +304,18 @@
                 cate_feats,
                 sexes,
             ) = batch
-            event_seq = event_seq.to(self.device)
-            time_seq = time_seq.to(self.device)
-            cont_feats = cont_feats.to(self.device)
-            cate_feats = cate_feats.to(self.device)
-            sexes = sexes.to(self.device)
+            event_seq = event_seq.to(self.device, non_blocking=True)
+            time_seq = time_seq.to(self.device, non_blocking=True)
+            cont_feats = cont_feats.to(self.device, non_blocking=True)
+            cate_feats = cate_feats.to(self.device, non_blocking=True)
+            sexes = sexes.to(self.device, non_blocking=True)
             res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
             if res is None:
                 continue
             dt, b_prev, t_prev, b_next, t_next = res
             self.optimizer.zero_grad()
             lr = self.compute_lr(self.global_step)
-            for param_group in self.optimizer.param_groups:
-                param_group['lr'] = lr
+            self.optimizer.param_groups[0]['lr'] = lr
             logits = self.model(
                 event_seq,
                 time_seq,
@@ -341,11 +370,11 @@
                 cate_feats,
                 sexes,
             ) = batch
-            event_seq = event_seq.to(self.device)
-            time_seq = time_seq.to(self.device)
-            cont_feats = cont_feats.to(self.device)
-            cate_feats = cate_feats.to(self.device)
-            sexes = sexes.to(self.device)
+            event_seq = event_seq.to(self.device, non_blocking=True)
+            time_seq = time_seq.to(self.device, non_blocking=True)
+            cont_feats = cont_feats.to(self.device, non_blocking=True)
+            cate_feats = cate_feats.to(self.device, non_blocking=True)
+            sexes = sexes.to(self.device, non_blocking=True)
             res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
             if res is None:
                 continue
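
Note on the new loader options (illustrative only, not part of the patch to apply): the sketch below shows how pin_memory, the num_workers-gated keywords, and the non_blocking transfers are expected to fit together. A TensorDataset stands in for HealthDataset, and make_loader is a hypothetical helper; only the keyword wiring mirrors the diff above.

    # Standalone sketch of the DataLoader wiring used in Trainer.__init__.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def make_loader(dataset, batch_size, num_workers=0, prefetch_factor=2,
                    persistent_workers=False, use_cuda=None):
        if use_cuda is None:
            use_cuda = torch.cuda.is_available()
        kwargs = dict(pin_memory=use_cuda)
        if num_workers > 0:
            # prefetch_factor / persistent_workers are only valid with workers.
            kwargs.update(num_workers=num_workers,
                          prefetch_factor=prefetch_factor,
                          persistent_workers=persistent_workers)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, **kwargs)

    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        data = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
        loader = make_loader(data, batch_size=32)
        for xb, yb in loader:
            # non_blocking only overlaps the copy with compute when the source
            # is pinned host memory and the target is a CUDA device; on CPU it
            # is a no-op, so the same call is safe on both.
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

With num_workers=0 the DataLoader never receives prefetch_factor or persistent_workers, which matches the guard in the diff and avoids the error recent PyTorch versions raise when those options are combined with zero workers.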