
Commit 928d2c0

cenzhaometa authored and facebook-github-bot committed
support cuda-graph mode
Summary:
Introduce `--graph-launches` (default: 0) as a knob to enable cuda-graph mode; when it is non-zero (e.g. `--graph-launches=10`), the captured graph is replayed that many times. In cuda-graph mode:

1. Warm up: run the collective `warm_iters` times on a separate stream and sync with the current stream.
2. Capture the graph: run the collective `iters` times while capturing.
3. Replay the graph `graph_launches` times on the current stream.

> Note: param-bench measures collective latency from the CPU side, which is not very accurate. See the test plan for a trace captured in graph mode (the graph launches are visible), etc.

> TODO: cuda-graph mode does not work with the `async_op=True` case; it produces the following error and needs a follow-up PTD fix:

```
[rank7]: Traceback (most recent call last):
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/__run_xar_main__.py", line 140, in <module>
[rank7]:     __invoke_main()
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/__run_xar_main__.py", line 87, in __invoke_main
[rank7]:     run_as_main(main_module, main_function)
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/__par__/meta_only/bootstrap.py", line 98, in run_as_main
[rank7]:     oss_run_as_main(
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/__par__/bootstrap.py", line 94, in run_as_main
[rank7]:     main()
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
[rank7]:     return f(*args, **kwargs)
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/fb/launcher.py", line 1226, in main
[rank7]:     remote_mpi_launcher(args, more_args)
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/fb/launcher.py", line 475, in remote_mpi_launcher
[rank7]:     local_launcher(args, more_args)
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/fb/launcher.py", line 368, in local_launcher
[rank7]:     commsBench()
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/fb/launcher.py", line 268, in commsBench
[rank7]:     comms_bench()
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 1523, in main
[rank7]:     collBenchObj.runBench(commsParams)
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 1458, in runBench
[rank7]:     self.backendFuncs.benchmark_comms(self.benchTime, commsParams)
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/pytorch_dist_backend.py", line 1206, in benchmark_comms
[rank7]:     benchTime(index, commsParams, self)
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 1236, in benchTime
[rank7]:     self.benchComm(index, commsParams, backendFuncs)
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 1310, in benchComm
[rank7]:     self.runColl(
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 431, in runColl
[rank7]:     return self.run_coll_cuda_graph(comm_fn, dcheck)
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 377, in run_coll_cuda_graph
[rank7]:     with torch.cuda.graph(g):
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/torch/cuda/graphs.py", line 186, in __exit__
[rank7]:     self.cuda_graph.capture_end()
[rank7]:   File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/torch/cuda/graphs.py", line 84, in capture_end
[rank7]:     super().capture_end()
[rank7]: RuntimeError: HIP error: capturing stream has unjoined work
[rank7]: HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank7]: For debugging consider passing AMD_SERIALIZE_KERNEL=3
[rank7]: Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
```

Reviewed By: kingchc, kwen2501

Differential Revision: D70544123

fbshipit-source-id: bb4a5ad8ad1e03a77e8d3528e17d26125b5fe355
1 parent a81194f commit 928d2c0

File tree

4 files changed (+88, -1 lines)


train/comms/pt/comms.py

Lines changed: 85 additions & 0 deletions
```diff
@@ -189,6 +189,12 @@ def readArgs(self, parser):
             default=False,
             help="use device time measurement",
         )
+        parser.add_argument(
+            "--graph-launches",
+            type=int,
+            default=0,
+            help="Number of graph launches for each data-size",
+        )
         return parser.parse_known_args()

     def _checkPt2Pt(self, args):
```
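The flag is registered with a dash, so argparse exposes it on the parsed namespace as `args.graph_launches`, which is the spelling the rest of the change (and the summary) uses. A tiny standalone illustration of that mapping (not the benchmark's parser):

```python
import argparse

# argparse turns dashes in long option names into underscores on the Namespace,
# so "--graph-launches" is read back as args.graph_launches.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--graph-launches",
    type=int,
    default=0,
    help="Number of graph launches for each data-size",
)

args = parser.parse_args(["--graph-launches", "10"])
print(args.graph_launches)  # -> 10
```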
```diff
@@ -315,6 +321,10 @@ def checkArgs(self, args):  # noqa: C901
             logger.error(f"wrong dst_ranks ({args.dst_ranks})")
             comms_utils.gracefulExit()

+        if args.graph_launches > 0 and args.device != "cuda":
+            logger.error("cuda graph is only supported for cuda or rocm device")
+            comms_utils.gracefulExit()
+
     # depnds on data type
     def checkArgsdataType(self, args):  # noqa: C901
         args.b = comms_utils.parsesize(args.b)
@@ -354,7 +364,81 @@ def checkArgsdataType(self, args):  # noqa: C901
         # run a few sanity checks
         self._check_bitwidth(args)

+    def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
+        self.backendFuncs.sync_barrier(
+            self.collectiveArgs, desc="run_coll_cuda_graph_begin"
+        )
+        elapsedTimeNS = 0.0
+
+        # 1. Warmup phase
+        # launch collective on a separate stream and sync with current_stream
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            for _ in range(self.collectiveArgs.numWarmupIters):
+                comm_fn(self.collectiveArgs)
+        torch.cuda.current_stream().wait_stream(s)
+
+        # 2. capturing graph
+        # in cuda graph, we need to use sync mode
+        # TODO: this might need PTD fix (async_op=True won't work under cuda graph)
+        self.collectiveArgs.asyncOp = False
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(self.collectiveArgs.numIters):
+                if dcheck:
+                    # reset input tensor for data validation
+                    self.setTensorVal(self.collectiveArgs.ipTensor)
+                comm_fn(self.collectiveArgs)
+
+        # 3. Replay
+        start = time.monotonic()  # available only in py3
+        for _ in range(self.collectiveArgs.graph_launches):
+            if self.collectiveArgs.enable_profiler:
+                comms_utils.sampleProfiler()
+
+            # [optional] we can feed new input data to ipTensor for each replay
+            g.replay()
+
+        self.backendFuncs.complete_accel_ops(self.collectiveArgs)
+        end = time.monotonic()  # available only in py3
+
+        ensureTensorFlush(self.collectiveArgs.opTensor)
+
+        elapsedTimeNS += (
+            end - start
+        ) * 1e9  # keeping time in NS, helps in divising data by nanoseconds
+
+        memSize = self.backendFuncs.get_mem_size(self.collectiveArgs)
+
+        avgIterNS, algBW = comms_utils.getAlgBW(
+            elapsedTimeNS,
+            memSize,
+            self.collectiveArgs.numIters
+            * self.collectiveArgs.numCollPerIter
+            * self.collectiveArgs.graph_launches,
+        )
+        busBW = self.backendFuncs.getBusBW(
+            self.collectiveArgs.collective,
+            algBW,
+            self.collectiveArgs,
+        )
+
+        # reset group to sync among all global ranks
+        self.collectiveArgs.group = self.backendFuncs.get_default_group()
+        self.backendFuncs.sync_barrier(self.collectiveArgs, desc="runColl_end")
+
+        results = {
+            "timeUS": avgIterNS / 1e3,
+            "algBW": algBW,
+            "busBW": busBW,
+            "memSize": memSize,
+        }
+        return results
+
     def runColl(self, comm_fn=None, dcheck=False):
+        if self.collectiveArgs.graph_launches > 0:
+            return self.run_coll_cuda_graph(comm_fn, dcheck)
         self.backendFuncs.sync_barrier(self.collectiveArgs, desc="runColl_begin")

         elapsedCPUTimeNS = 0.0
@@ -801,6 +885,7 @@ def initCollectiveArgs(self, commsParams):
         self.collectiveArgs.numCollPerIter = commsParams.num_coll
         self.collectiveArgs.include_0B = commsParams.include_0B
         self.collectiveArgs.use_device_time = commsParams.use_device_time
+        self.collectiveArgs.graph_launches = commsParams.graph_launches

         if commsParams.bitwidth < 32:
             comms_utils.initQuantCommCtx(self.collectiveArgs, commsParams)
```
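In `run_coll_cuda_graph` above, the replay wall-clock time is divided by the total number of collectives issued (`numIters * numCollPerIter * graph_launches`). A rough worked illustration of that normalization, assuming `getAlgBW` returns the average time per collective and the corresponding algorithmic bandwidth (the numbers below are made up):

```python
# Hypothetical numbers, purely to illustrate the normalization above.
elapsed_time_ns = 2_000_000_000          # 2 s of graph-replay wall-clock time
mem_size_bytes = 64 * 1024 * 1024        # 64 MiB moved per collective
num_iters, num_coll_per_iter, graph_launches = 20, 1, 10

total_colls = num_iters * num_coll_per_iter * graph_launches   # 200 collectives replayed
avg_iter_ns = elapsed_time_ns / total_colls                    # 10,000,000 ns per collective
alg_bw = mem_size_bytes / avg_iter_ns                          # bytes/ns == GB/s, ~6.7 GB/s
print(avg_iter_ns, alg_bw)
```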

train/comms/pt/comms_utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -883,6 +883,7 @@ def __init__(
         self.groupRanks = groupRanks

         self.include_0B = args.include_0B
+        self.graph_launches = args.graph_launches
         self.num_coll = args.num_coll

```

train/comms/pt/pytorch_backend_utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -128,6 +128,7 @@ def __init__(self) -> None:

         self.include_0B = False
         self.use_device_time = False
+        self.graph_launches = 0


 class backendFunctions(ABC):
```

train/comms/pt/pytorch_dist_backend.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1023,7 +1023,7 @@ def get_current_stream(self, device: torch.device | None):
         return None

     def switch_stream(self, stream, device: torch.device | None):
-        """switch to a new stream and return the current stream"""
+        """switch to a new stream and return the old current stream"""
         if device is None:
             device = self.get_device()
         if stream is not None and device.type == "cuda":
```
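For context, a minimal sketch of the behavior the corrected docstring describes, i.e. switching streams while handing back the previously current one so the caller can restore it later (an illustration, not the backend's actual implementation):

```python
import torch


def switch_stream_sketch(stream: torch.cuda.Stream | None) -> torch.cuda.Stream:
    # Remember the stream that was current before switching, so the caller
    # can switch back once it is done.
    old_stream = torch.cuda.current_stream()
    if stream is not None:
        torch.cuda.set_stream(stream)
    return old_stream
```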
