Enhance _maybe_torch_compile and add _maybe_cudagraph_mark_step_begin for improved CUDA Graphs handling

This commit is contained in:
2026-01-18 18:04:54 +08:00
parent a4b19b6e08
commit 6e76d67a10

View File

@@ -36,18 +36,45 @@ warnings.filterwarnings('ignore')
def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module: def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
"""Best-effort torch.compile() wrapper (PyTorch 2.x).""" """Best-effort torch.compile() wrapper (PyTorch 2.x).
Notes:
- Some PyTorch builds run compiled graphs via CUDA Graphs in certain modes.
If you keep references to graph outputs across steps, PyTorch may raise:
"accessing tensor output of CUDAGraphs that has been overwritten".
- We default to settings that avoid cudagraph output-lifetime pitfalls.
"""
if not enabled: if not enabled:
return module return module
try: try:
torch_compile = getattr(torch, "compile", None) torch_compile = getattr(torch, "compile", None)
if torch_compile is None: if torch_compile is None:
return module return module
return torch_compile(module, mode="reduce-overhead") # Prefer a safer mode for evaluation code; best-effort disable cudagraphs.
kwargs = {"mode": "default"}
try:
kwargs["options"] = {"triton.cudagraphs": False}
except Exception:
pass
return torch_compile(module, **kwargs)
except Exception: except Exception:
return module return module
def _maybe_cudagraph_mark_step_begin() -> None:
"""Best-effort step marker for CUDA Graphs compiled execution."""
try:
compiler_mod = getattr(torch, "compiler", None)
if compiler_mod is None:
return
mark = getattr(compiler_mod, "cudagraph_mark_step_begin", None)
if mark is None:
return
mark()
except Exception:
return
def _ensure_dir(path: str) -> str: def _ensure_dir(path: str) -> str:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
return path return path
@@ -485,6 +512,9 @@ class LandmarkEvaluator:
# Get model predictions at anchor points # Get model predictions at anchor points
if has_anchor.any(): if has_anchor.any():
# If torch.compile uses CUDA Graphs under the hood, mark a new step
# before each compiled invocation to avoid output lifetime issues.
_maybe_cudagraph_mark_step_begin()
# Forward pass # Forward pass
hidden = self.model(event_batch, time_batch, hidden = self.model(event_batch, time_batch,
sex_batch, cont_batch, cate_batch) sex_batch, cont_batch, cate_batch)
@@ -1152,6 +1182,7 @@ class LandmarkEvaluator:
batch_idx = torch.arange(B, device=self.device) batch_idx = torch.arange(B, device=self.device)
# Backbone once per batch # Backbone once per batch
_maybe_cudagraph_mark_step_begin()
hidden = self.model( hidden = self.model(
# (B, L, D) # (B, L, D)
event_batch, time_batch, sex_batch, cont_batch, cate_batch) event_batch, time_batch, sex_batch, cont_batch, cate_batch)