diff --git a/evaluate.py b/evaluate.py
index a42e43f..fb2ca66 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -36,18 +36,45 @@ warnings.filterwarnings('ignore')
 
 
 def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
-    """Best-effort torch.compile() wrapper (PyTorch 2.x)."""
+    """Best-effort torch.compile() wrapper (PyTorch 2.x).
+
+    Notes:
+    - Some PyTorch builds run compiled graphs via CUDA Graphs in certain modes.
+      If you keep references to graph outputs across steps, PyTorch may raise:
+      "accessing tensor output of CUDAGraphs that has been overwritten".
+    - We default to settings that avoid cudagraph output-lifetime pitfalls.
+    """
     if not enabled:
         return module
     try:
         torch_compile = getattr(torch, "compile", None)
         if torch_compile is None:
             return module
-        return torch_compile(module, mode="reduce-overhead")
+        # mode= and options= are mutually exclusive in torch.compile(), so pass
+        # options alone (unset mode means "default") with cudagraphs disabled.
+        try:
+            return torch_compile(module, options={"triton.cudagraphs": False})
+        except Exception:
+            # Option unsupported on this build; default mode avoids cudagraphs.
+            return torch_compile(module, mode="default")
     except Exception:
         return module
 
 
+def _maybe_cudagraph_mark_step_begin() -> None:
+    """Best-effort step marker for CUDA Graphs compiled execution."""
+    try:
+        compiler_mod = getattr(torch, "compiler", None)
+        if compiler_mod is None:
+            return
+        mark = getattr(compiler_mod, "cudagraph_mark_step_begin", None)
+        if mark is None:
+            return
+        mark()
+    except Exception:
+        return
+
+
 def _ensure_dir(path: str) -> str:
     os.makedirs(path, exist_ok=True)
     return path
@@ -485,6 +512,9 @@ class LandmarkEvaluator:
 
             # Get model predictions at anchor points
             if has_anchor.any():
+                # If torch.compile uses CUDA Graphs under the hood, mark a new step
+                # before each compiled invocation to avoid output lifetime issues.
+                _maybe_cudagraph_mark_step_begin()
                 # Forward pass
                 hidden = self.model(event_batch, time_batch, sex_batch, cont_batch, cate_batch)
 
@@ -1152,6 +1182,7 @@ class LandmarkEvaluator:
         batch_idx = torch.arange(B, device=self.device)
 
         # Backbone once per batch
+        _maybe_cudagraph_mark_step_begin()
         hidden = self.model(  # (B, L, D)
             event_batch, time_batch, sex_batch, cont_batch, cate_batch)
 