Enhance _maybe_torch_compile and add _maybe_cudagraph_mark_step_begin for improved CUDA Graphs handling
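Intended usage, as a sketch: this diff adds the per-step marker calls inside LandmarkEvaluator, but the call site that wraps the model with _maybe_torch_compile is not shown here, so the setup line below is an assumption for illustration (names such as self.model and the batch tensors come from evaluate.py).

    # Hypothetical setup, e.g. in the evaluator's __init__ (assumed, not in this diff):
    self.model = _maybe_torch_compile(self.model, enabled=True)

    # Per batch, immediately before each compiled forward pass (added in the hunks below):
    _maybe_cudagraph_mark_step_begin()
    hidden = self.model(event_batch, time_batch, sex_batch, cont_batch, cate_batch)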
evaluate.py (35 changed lines)
@@ -36,18 +36,45 @@ warnings.filterwarnings('ignore')
 
 
 def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
-    """Best-effort torch.compile() wrapper (PyTorch 2.x)."""
+    """Best-effort torch.compile() wrapper (PyTorch 2.x).
+
+    Notes:
+    - Some PyTorch builds run compiled graphs via CUDA Graphs in certain modes.
+      If you keep references to graph outputs across steps, PyTorch may raise:
+      "accessing tensor output of CUDAGraphs that has been overwritten".
+    - We default to settings that avoid cudagraph output-lifetime pitfalls.
+    """
     if not enabled:
         return module
     try:
         torch_compile = getattr(torch, "compile", None)
         if torch_compile is None:
             return module
-        return torch_compile(module, mode="reduce-overhead")
+        # Prefer a safer configuration for evaluation code; best-effort disable
+        # cudagraphs. Note: torch.compile() accepts `mode` or `options`, not both.
+        try:
+            return torch_compile(module, options={"triton.cudagraphs": False})
+        except Exception:
+            # Older builds may reject these options; fall back to the default mode.
+            return torch_compile(module, mode="default")
     except Exception:
         return module
 
 
+def _maybe_cudagraph_mark_step_begin() -> None:
+    """Best-effort step marker for CUDA Graphs compiled execution."""
+    try:
+        compiler_mod = getattr(torch, "compiler", None)
+        if compiler_mod is None:
+            return
+        mark = getattr(compiler_mod, "cudagraph_mark_step_begin", None)
+        if mark is None:
+            return
+        mark()
+    except Exception:
+        return
+
+
 def _ensure_dir(path: str) -> str:
     os.makedirs(path, exist_ok=True)
     return path
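Background on the failure mode the new docstring notes describe: a self-contained sketch, not part of this commit, assuming a CUDA device and a PyTorch build (roughly 2.1 or newer) that provides torch.compiler.cudagraph_mark_step_begin. It shows why a step marker (or cloning) matters when compiled code runs under CUDA Graphs.

    import torch

    @torch.compile(mode="reduce-overhead")  # this mode may execute via CUDA Graphs
    def step(x):
        return x * 2

    x = torch.randn(8, device="cuda")
    kept = []
    for _ in range(3):
        # Mark a new step before each compiled call; without this (or without
        # cloning), reading an earlier output after a later run can raise
        # "accessing tensor output of CUDAGraphs that has been overwritten".
        torch.compiler.cudagraph_mark_step_begin()
        out = step(x)
        kept.append(out.clone())  # clone anything retained beyond the current step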
@@ -485,6 +512,9 @@ class LandmarkEvaluator:
 
         # Get model predictions at anchor points
         if has_anchor.any():
+            # If torch.compile uses CUDA Graphs under the hood, mark a new step
+            # before each compiled invocation to avoid output lifetime issues.
+            _maybe_cudagraph_mark_step_begin()
             # Forward pass
             hidden = self.model(event_batch, time_batch,
                                 sex_batch, cont_batch, cate_batch)
@@ -1152,6 +1182,7 @@ class LandmarkEvaluator:
         batch_idx = torch.arange(B, device=self.device)
 
         # Backbone once per batch
+        _maybe_cudagraph_mark_step_begin()
         hidden = self.model(
             # (B, L, D)
             event_batch, time_batch, sex_batch, cont_batch, cate_batch)
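A related note for the two call sites above: if the compiled backbone's outputs are kept across batches (for later aggregation, say), they should be cloned, since CUDA Graphs may reuse the underlying memory on the next step. The bookkeeping below is hypothetical and not part of this diff.

    # Hypothetical (not in this diff): retain per-batch results safely by cloning
    # before the next compiled step can reuse the graph-owned memory.
    retained_hidden = hidden.detach().clone()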