diff --git a/evaluate.py b/evaluate.py
index 64ab22e..a42e43f 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -35,6 +35,19 @@ from losses import (
 warnings.filterwarnings('ignore')
 
 
+def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
+    """Best-effort torch.compile() wrapper; no-op on PyTorch < 2.0 or on failure."""
+    if not enabled:
+        return module
+    try:
+        torch_compile = getattr(torch, "compile", None)
+        if torch_compile is None:
+            return module
+        return torch_compile(module, mode="reduce-overhead")
+    except Exception:
+        return module
+
+
 def _ensure_dir(path: str) -> str:
     os.makedirs(path, exist_ok=True)
     return path
@@ -225,11 +238,10 @@ class LandmarkEvaluator:
         device: str = 'cuda',
         batch_size: int = 256,
         num_workers: int = 4,
+        compile_model: bool = True,
     ):
-        self.model = model.to(device)
-        self.model.eval()
-        self.head = head.to(device)
-        self.head.eval()
+        self.model = model.to(device).eval()
+        self.head = head.to(device).eval()
         self.loss_fn = loss_fn
         self.dataset = dataset
         self.eval_indices = eval_indices
@@ -237,6 +249,22 @@ class LandmarkEvaluator:
         self.batch_size = batch_size
         self.num_workers = num_workers
 
+        use_cuda = str(self.device).startswith("cuda") and torch.cuda.is_available()
+        if use_cuda:
+            torch.backends.cudnn.benchmark = True
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+            try:
+                torch.set_float32_matmul_precision("high")
+            except Exception:
+                pass
+
+        # JIT/compile optimization (best effort)
+        if compile_model and use_cuda:
+            self.model = _maybe_torch_compile(self.model)
+            self.head = _maybe_torch_compile(self.head)
+
         # Evaluation parameters from design doc
         self.age_cutoffs = [50, 60, 70]
         self.horizons = [0.25, 0.5, 1, 2, 5, 10]
@@ -247,6 +275,150 @@ class LandmarkEvaluator:
         self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs]
         self.horizons_days = [h * 365.25 for h in self.horizons]
 
+    @staticmethod
+    def _last_time(time_batch: torch.Tensor, event_batch: torch.Tensor) -> torch.Tensor:
+        """Compute the last observed (non-padding) time per patient."""
+        real_mask = event_batch >= 1
+        masked = time_batch.masked_fill(~real_mask, float('-inf'))
+        return masked.max(dim=1).values
+
+    @staticmethod
+    def _anchor_indices(
+        time_batch: torch.Tensor,
+        event_batch: torch.Tensor,
+        cutoff_days: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Find the anchor index/time: the last valid record before the cutoff."""
+        real_mask = event_batch >= 1
+        before = time_batch < cutoff_days
+        valid_before = real_mask & before
+        has_anchor = valid_before.any(dim=1)
+
+        # Max of (mask * position) gives the last True position; rows with no
+        # anchor default to index 0 and must be filtered via has_anchor.
+        L = event_batch.size(1)
+        pos = torch.arange(L, device=event_batch.device).view(1, L)
+        anchor_idx = (valid_before.to(torch.long) * pos).max(dim=1).values
+        t_anchor = time_batch.gather(1, anchor_idx.view(-1, 1)).squeeze(1)
+        return has_anchor, anchor_idx, t_anchor
+
+    def _labels_and_validity_for_cutoff(
+        self,
+        time_batch: torch.Tensor,
+        event_batch: torch.Tensor,
+        cutoff_days: float,
+        horizons_days: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Vectorized label + validity computation for all horizons at a cutoff.
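+
+        Track A ("complete_case") marks every disease evaluable when follow-up
+        covers the full horizon (or death occurs inside it); otherwise only the
+        diseases actually observed inside the window are evaluable. Track B
+        ("clean_control") further restricts negatives to patients who never
+        develop the disease and are neither lost nor dead within the window.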
+
+        Returns:
+            labels: (B, H, K) float32 {0,1}
+            valid_cc: (B, H, K) bool
+            valid_clean: (B, H, K) bool
+        """
+        n_tech_tokens = 2
+        K = int(self.dataset.n_disease)
+        death_code = int(K - 1)
+
+        B, L = event_batch.shape
+        H = int(horizons_days.numel())
+
+        # Disease token mask and indices (tokens 0..n_tech_tokens-1 are
+        # padding/technical tokens)
+        is_disease = event_batch >= n_tech_tokens
+        disease_idx = (event_batch - n_tech_tokens).clamp(min=0, max=K - 1)
+
+        # ever_has_disease: (B, K)
+        ever = torch.zeros((B, K), dtype=torch.bool, device=event_batch.device)
+        if is_disease.any():
+            b_idx, t_idx = is_disease.nonzero(as_tuple=True)
+            d_idx = disease_idx[b_idx, t_idx]
+            ever[b_idx, d_idx] = True
+
+        # Events within horizon windows: (B, L, H)
+        offset = time_batch - float(cutoff_days)
+        within = is_disease.unsqueeze(-1) & (offset.unsqueeze(-1) >= 0) & (
+            offset.unsqueeze(-1) <= horizons_days.view(1, 1, H)
+        )
+
+        labels_bool = torch.zeros(
+            (B, H, K), dtype=torch.bool, device=event_batch.device)
+        if within.any():
+            b2, t2, h2 = within.nonzero(as_tuple=True)
+            d2 = disease_idx[b2, t2]
+            labels_bool[b2, h2, d2] = True
+
+        labels = labels_bool.to(torch.float32)
+
+        last_time = self._last_time(time_batch, event_batch)  # (B,)
+        horizon_end = float(cutoff_days) + horizons_days.view(1, H)  # (1, H)
+
+        death_in_horizon = labels_bool[:, :, death_code]  # (B, H)
+        observed_past_horizon = last_time.view(B, 1) > horizon_end  # (B, H)
+        lost_within_horizon = ~observed_past_horizon
+
+        # Track A (Complete-Case):
+        #   - observed past horizon OR death in horizon => valid for all diseases
+        #   - censored within horizon => valid only for diseases observed in the window
+        full_mask = (observed_past_horizon | death_in_horizon).unsqueeze(-1)
+        valid_cc = labels_bool | full_mask
+
+        # Track B (Clean-Control) per disease, excluding anyone who dies in the
+        # window: valid[k] = hit_in_window(k) OR (never_has_k AND not lost_within_window)
+        never = ~ever  # (B, K)
+        valid_clean = (~death_in_horizon).unsqueeze(-1) & (
+            labels_bool
+            | (never.unsqueeze(1) & (~lost_within_horizon).unsqueeze(-1))
+        )
+
+        return labels, valid_cc, valid_clean
+
+    def _compute_risk_scores_many_horizons(
+        self,
+        logits: torch.Tensor,
+        t_start_days: torch.Tensor,
+        horizons_days: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute risk increments for all horizons in one vectorized call.
+
+        Args:
+            logits: model head outputs for anchor points.
+            t_start_days: (B,) time from anchor to cutoff (days).
+            horizons_days: (H,) horizons in days.
+
+        Returns:
+            risk: (B, H, K) float32
+        """
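+        # Risk over (cutoff, cutoff + h] is the CIF increment
+        # CIF(t_start + h) - CIF(t_start), with t_start the anchor-to-cutoff
+        # gap, so predictions use only information available at the anchor.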
+        t_start_days = torch.clamp(t_start_days, min=0)
+        t_end_days = torch.clamp(
+            t_start_days.unsqueeze(1) + horizons_days.view(1, -1), min=0)
+
+        t_query_years = torch.cat(
+            [t_start_days.unsqueeze(1), t_end_days], dim=1) / 365.25  # (B, H+1)
+
+        if not hasattr(self.loss_fn, "calculate_cifs"):
+            raise ValueError(
+                f"Loss function does not support calculate_cifs: {type(self.loss_fn)}")
+
+        # calculate_cifs returns (B, K) for a single query time, else (B, K, T)
+        cifs = self.loss_fn.calculate_cifs(
+            logits, t_query_years, return_survival=False)
+
+        if cifs.ndim == 2:
+            cifs_bt_k = cifs.unsqueeze(1)  # (B, K) -> (B, 1, K)
+        else:
+            cifs_bt_k = cifs.permute(0, 2, 1)  # (B, K, T) -> (B, T, K)
+
+        cif_start = cifs_bt_k[:, :1, :]  # (B, 1, K)
+        cif_end = cifs_bt_k[:, 1:, :]  # (B, H, K)
+        risk = torch.clamp(cif_end - cif_start, min=0)
+        return risk
+
     @torch.no_grad()
     def compute_risk_scores(
         self,
@@ -926,22 +1098,169 @@ class LandmarkEvaluator:
         return results
 
     def run_full_evaluation(self) -> Dict:
-        """
-        Run complete landmark analysis across all cutoffs and horizons.
+        """Run the full evaluation in a single pass over the DataLoader.
 
-        Returns:
-            all_results: Nested dictionary with all evaluation results
+        Key optimizations:
+        - iterate the DataLoader exactly once
+        - run the transformer backbone once per batch
+        - reuse each batch's hidden states across cutoffs (the head runs
+          once per cutoff)
+        - vectorize CIF/risk computation over all horizons in one call
         """
-        all_results = {
+
+        # Build evaluation subset loader
+        indices = (self.eval_indices if self.eval_indices is not None
+                   else list(range(len(self.dataset))))
+        subset = Subset(self.dataset, indices)
+        loader = DataLoader(
+            subset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            collate_fn=health_collate_fn,
+            num_workers=self.num_workers,
+            pin_memory=str(self.device).startswith('cuda'),
+        )
+
+        cutoffs_days = torch.tensor(
+            self.age_cutoffs_days, dtype=torch.float32, device=self.device)  # (C,)
+        horizons_days = torch.tensor(
+            self.horizons_days, dtype=torch.float32, device=self.device)  # (H,)
+        C = int(cutoffs_days.numel())
+        H = int(horizons_days.numel())
+        K = int(self.dataset.n_disease)
+
+        # Buffers: accumulate per-(cutoff, horizon, track) results as lists of
+        # numpy chunks, concatenated once at the end.
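+        # e.g. buffers[(0, 2, "complete_case")]["risk"] holds (b_i, K) float32
+        # chunks for age cutoff 50, horizon 1 year.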
+        buffers: Dict[Tuple[int, int, str], Dict[str, List[np.ndarray]]] = {}
+        for ci in range(C):
+            for hi in range(H):
+                for track in ("complete_case", "clean_control"):
+                    buffers[(ci, hi, track)] = {
+                        "risk": [], "labels": [], "valid": []}
+
+        with torch.inference_mode():
+            for batch in tqdm(loader, desc="Single-pass evaluation", ncols=100):
+                event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
+                event_batch = event_batch.to(self.device, non_blocking=True)
+                time_batch = time_batch.to(self.device, non_blocking=True)
+                cont_batch = cont_batch.to(self.device, non_blocking=True)
+                cate_batch = cate_batch.to(self.device, non_blocking=True)
+                sex_batch = sex_batch.to(self.device, non_blocking=True)
+
+                B, L = event_batch.shape
+                batch_idx = torch.arange(B, device=self.device)
+
+                # Backbone once per batch: (B, L, D)
+                hidden = self.model(
+                    event_batch, time_batch, sex_batch, cont_batch, cate_batch)
+
+                for ci in range(C):
+                    cutoff = float(cutoffs_days[ci].item())
+
+                    has_anchor, anchor_idx, t_anchor = self._anchor_indices(
+                        time_batch, event_batch, cutoff)
+                    if not has_anchor.any():
+                        continue
+
+                    # Hidden states at anchor positions
+                    hidden_anchor = hidden[batch_idx, anchor_idx]  # (B, D)
+                    logits = self.head(hidden_anchor)
+
+                    # Vectorized labels/validity for all horizons
+                    labels_bhk, valid_cc_bhk, valid_clean_bhk = self._labels_and_validity_for_cutoff(
+                        time_batch, event_batch, cutoff, horizons_days
+                    )
+
+                    # Risk scores for all horizons: (B, H, K)
+                    t_start = torch.clamp(cutoff - t_anchor, min=0)
+                    risk_bhk = self._compute_risk_scores_many_horizons(
+                        logits, t_start, horizons_days)
+
+                    # Rows without an anchor are invalid under both tracks
+                    anchor_mask = has_anchor.view(B, 1, 1)
+                    valid_cc_bhk = valid_cc_bhk & anchor_mask
+                    valid_clean_bhk = valid_clean_bhk & anchor_mask
+
+                    # Push per-horizon chunks
+                    for hi in range(H):
+                        for track, valid_bk in (
+                            ("complete_case", valid_cc_bhk[:, hi, :]),
+                            ("clean_control", valid_clean_bhk[:, hi, :]),
+                        ):
+                            row_mask = valid_bk.any(dim=1)
+                            if not row_mask.any():
+                                continue
+
+                            r = risk_bhk[row_mask, hi, :].float().cpu().numpy()
+                            y = labels_bhk[row_mask, hi, :].float().cpu().numpy()
+                            m = valid_bk[row_mask, :].cpu().numpy()
+
+                            buffers[(ci, hi, track)]["risk"].append(r)
+                            buffers[(ci, hi, track)]["labels"].append(y)
+                            buffers[(ci, hi, track)]["valid"].append(m)
+
+        # Assemble results in the original output schema
+        all_results: Dict = {
             'age_cutoffs': self.age_cutoffs,
             'horizons': self.horizons,
             'landmarks': [],
         }
 
-        # Evaluate each landmark
-        for age_cutoff in self.age_cutoffs:
-            for horizon in self.horizons:
-                landmark_results = self.evaluate_landmark(age_cutoff, horizon)
+        for ci, age in enumerate(self.age_cutoffs):
+            for hi, horizon in enumerate(self.horizons):
+                landmark_results = {
+                    'age_cutoff': age,
+                    'horizon': horizon,
+                    'complete_case': {},
+                    'clean_control': {},
+                }
+
+                for track in ("complete_case", "clean_control"):
+                    chunks = buffers[(ci, hi, track)]
+                    if len(chunks["risk"]) == 0:
+                        continue
+
+                    risk_scores = np.concatenate(chunks["risk"], axis=0)
+                    labels = np.concatenate(chunks["labels"], axis=0)
+                    valid_mask = np.concatenate(chunks["valid"], axis=0)
+
+                    auc_scores = self.compute_auc_per_disease(
+                        risk_scores, labels, valid_mask)
+                    mean_auc = np.nanmean(list(auc_scores.values()))
+
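+                    # Per-track summary; the heavier metrics (Brier, capture@K,
+                    # lift/yield, DCA) are added for the complete-case track
+                    # only, below.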
+                    track_out = {
+                        'n_patients': int(valid_mask.shape[0]),
+                        'n_valid': int(valid_mask.sum()),
+                        'n_valid_patients': int(valid_mask.any(axis=1).sum()),
+                        'auc_per_disease': auc_scores,
+                        'mean_auc': mean_auc,
+                    }
+
+                    if track == "complete_case":
+                        brier_metrics = self.compute_brier_score(
+                            risk_scores, labels, valid_mask)
+                        capture_metrics = self.compute_disease_capture_at_k(
+                            risk_scores, labels, valid_mask)
+                        lift_yield_metrics = self.compute_lift_and_yield(
+                            risk_scores, labels, valid_mask)
+                        dca_metrics = self.compute_dca_net_benefit(
+                            risk_scores, labels, valid_mask)
+                        track_out.update({
+                            'brier_score': brier_metrics['brier_score'],
+                            'brier_skill_score': brier_metrics['brier_skill_score'],
+                            'disease_capture_at_k': capture_metrics,
+                            'lift_and_yield': lift_yield_metrics,
+                            'dca': dca_metrics,
+                        })
+
+                    landmark_results[track] = track_out
+                all_results['landmarks'].append(landmark_results)
 
         return all_results
 
@@ -1202,6 +1521,11 @@ def main():
         default=4,
         help='Number of data loader workers'
     )
+    parser.add_argument(
+        '--no_compile',
+        action='store_true',
+        help='Disable torch.compile (use this if your PyTorch build or GPU does not support it)'
+    )
 
     args = parser.parse_args()
 
@@ -1219,6 +1543,7 @@ def main():
         device=args.device,
         batch_size=args.batch_size,
        num_workers=args.num_workers,
+        compile_model=(not args.no_compile),
     )
 
     # Run evaluation
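
Usage note: a minimal sketch of how the new single-pass evaluator is driven
(construction of model, head, loss_fn, dataset, and eval_indices is elided and
assumed to follow the existing setup in evaluate.py; the loss must expose
calculate_cifs):

    evaluator = LandmarkEvaluator(
        model=model,
        head=head,
        loss_fn=loss_fn,
        dataset=dataset,
        eval_indices=eval_indices,
        device="cuda",
        batch_size=256,
        num_workers=4,
        compile_model=True,  # the CLI maps --no_compile to compile_model=False
    )
    results = evaluator.run_full_evaluation()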