Enhance _maybe_torch_compile and add _maybe_cudagraph_mark_step_begin for improved CUDA Graphs handling
This commit is contained in:
35
evaluate.py
35
evaluate.py
@@ -36,18 +36,45 @@ warnings.filterwarnings('ignore')
|
|||||||
|
|
||||||
|
|
||||||
def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
|
def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
|
||||||
"""Best-effort torch.compile() wrapper (PyTorch 2.x)."""
|
"""Best-effort torch.compile() wrapper (PyTorch 2.x).
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Some PyTorch builds run compiled graphs via CUDA Graphs in certain modes.
|
||||||
|
If you keep references to graph outputs across steps, PyTorch may raise:
|
||||||
|
"accessing tensor output of CUDAGraphs that has been overwritten".
|
||||||
|
- We default to settings that avoid cudagraph output-lifetime pitfalls.
|
||||||
|
"""
|
||||||
if not enabled:
|
if not enabled:
|
||||||
return module
|
return module
|
||||||
try:
|
try:
|
||||||
torch_compile = getattr(torch, "compile", None)
|
torch_compile = getattr(torch, "compile", None)
|
||||||
if torch_compile is None:
|
if torch_compile is None:
|
||||||
return module
|
return module
|
||||||
return torch_compile(module, mode="reduce-overhead")
|
# Prefer a safer mode for evaluation code; best-effort disable cudagraphs.
|
||||||
|
kwargs = {"mode": "default"}
|
||||||
|
try:
|
||||||
|
kwargs["options"] = {"triton.cudagraphs": False}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return torch_compile(module, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_cudagraph_mark_step_begin() -> None:
|
||||||
|
"""Best-effort step marker for CUDA Graphs compiled execution."""
|
||||||
|
try:
|
||||||
|
compiler_mod = getattr(torch, "compiler", None)
|
||||||
|
if compiler_mod is None:
|
||||||
|
return
|
||||||
|
mark = getattr(compiler_mod, "cudagraph_mark_step_begin", None)
|
||||||
|
if mark is None:
|
||||||
|
return
|
||||||
|
mark()
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def _ensure_dir(path: str) -> str:
|
def _ensure_dir(path: str) -> str:
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
return path
|
return path
|
||||||
@@ -485,6 +512,9 @@ class LandmarkEvaluator:
|
|||||||
|
|
||||||
# Get model predictions at anchor points
|
# Get model predictions at anchor points
|
||||||
if has_anchor.any():
|
if has_anchor.any():
|
||||||
|
# If torch.compile uses CUDA Graphs under the hood, mark a new step
|
||||||
|
# before each compiled invocation to avoid output lifetime issues.
|
||||||
|
_maybe_cudagraph_mark_step_begin()
|
||||||
# Forward pass
|
# Forward pass
|
||||||
hidden = self.model(event_batch, time_batch,
|
hidden = self.model(event_batch, time_batch,
|
||||||
sex_batch, cont_batch, cate_batch)
|
sex_batch, cont_batch, cate_batch)
|
||||||
@@ -1152,6 +1182,7 @@ class LandmarkEvaluator:
|
|||||||
batch_idx = torch.arange(B, device=self.device)
|
batch_idx = torch.arange(B, device=self.device)
|
||||||
|
|
||||||
# Backbone once per batch
|
# Backbone once per batch
|
||||||
|
_maybe_cudagraph_mark_step_begin()
|
||||||
hidden = self.model(
|
hidden = self.model(
|
||||||
# (B, L, D)
|
# (B, L, D)
|
||||||
event_batch, time_batch, sex_batch, cont_batch, cate_batch)
|
event_batch, time_batch, sex_batch, cont_batch, cate_batch)
|
||||||
|
|||||||
Reference in New Issue
Block a user