Commit 928d2c0
support cuda-graph mode
Summary:
Introduce `--graph_launches` (default 0, e.g. `--graph_launches=10`) as a knob to enable cuda-graph mode. When it is non-zero, the captured graph is replayed that many times.
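A minimal sketch of how such a knob could be parsed. Only the flag name and its default come from this commit; the stand-alone parser and the `use_cuda_graph` helper are hypothetical and are not the param-bench implementation.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-alone parser; the real benchmark defines many more options.
    parser = argparse.ArgumentParser(description="cuda-graph knob sketch")
    parser.add_argument(
        "--graph_launches",
        type=int,
        default=0,
        help="number of graph replays; 0 leaves cuda-graph mode disabled",
    )
    return parser


def use_cuda_graph(args: argparse.Namespace) -> bool:
    # Assumption: any positive value switches the benchmark into graph mode.
    return args.graph_launches > 0
```

With no flag given, `use_cuda_graph` returns `False` and the benchmark would stay on its regular path.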
In cuda-graph mode:
1. Warm up: run the collective `warm_iters` times on a separate stream, then sync with the current stream.
2. Capture the graph: run the collective `iters` times.
3. Replay the graph `graph_launches` times on the current stream.
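The three steps above can be sketched with the public `torch.cuda` graph API. This is an illustration under assumptions, not the code added by this commit: `graph_mode_sketch`, `collectives_replayed`, and the zero-argument `comm_fn` callable are hypothetical names, and running it requires a CUDA (or ROCm) build of PyTorch.

```python
def collectives_replayed(iters: int, graph_launches: int) -> int:
    # Each replay re-executes every collective recorded during capture.
    return iters * graph_launches


def graph_mode_sketch(comm_fn, warm_iters: int, iters: int, graph_launches: int) -> None:
    # Imported lazily so this module still loads where PyTorch is not installed.
    import torch

    # 1. Warm up on a side stream, then sync back with the current stream.
    current = torch.cuda.current_stream()
    side = torch.cuda.Stream()
    side.wait_stream(current)
    with torch.cuda.stream(side):
        for _ in range(warm_iters):
            comm_fn()
    current.wait_stream(side)

    # 2. Capture: the collectives issued here are recorded into the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(iters):
            comm_fn()

    # 3. Replay the captured graph on the current stream.
    for _ in range(graph_launches):
        graph.replay()
    torch.cuda.synchronize()
```

Under these assumptions, one run with `iters=5` and `graph_launches=10` replays 50 collectives in total, which is what `collectives_replayed(5, 10)` computes.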
> param-bench measures collective latency from the CPU side, which is not very accurate.
See the test plan for a trace taken in graph mode (the graph launches are visible in it).
> TODO: cuda-graph mode does not work with the `async-op=True` case; it produces the following error. This needs a follow-up with a separate PTD fix.
```
[rank7]: Traceback (most recent call last):
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/__run_xar_main__.py", line 140, in <module>
[rank7]: __invoke_main()
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/__run_xar_main__.py", line 87, in __invoke_main
[rank7]: run_as_main(main_module, main_function)
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/__par__/meta_only/bootstrap.py", line 98, in run_as_main
[rank7]: oss_run_as_main(
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/__par__/bootstrap.py", line 94, in run_as_main
[rank7]: main()
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
[rank7]: return f(*args, **kwargs)
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/fb/launcher.py", line 1226, in main
[rank7]: remote_mpi_launcher(args, more_args)
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/fb/launcher.py", line 475, in remote_mpi_launcher
[rank7]: local_launcher(args, more_args)
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/fb/launcher.py", line 368, in local_launcher
[rank7]: commsBench()
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/fb/launcher.py", line 268, in commsBench
[rank7]: comms_bench()
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 1523, in main
[rank7]: collBenchObj.runBench(commsParams)
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 1458, in runBench
[rank7]: self.backendFuncs.benchmark_comms(self.benchTime, commsParams)
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/pytorch_dist_backend.py", line 1206, in benchmark_comms
[rank7]: benchTime(index, commsParams, self)
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 1236, in benchTime
[rank7]: self.benchComm(index, commsParams, backendFuncs)
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 1310, in benchComm
[rank7]: self.runColl(
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 431, in runColl
[rank7]: return self.run_coll_cuda_graph(comm_fn, dcheck)
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/param_bench/train/comms/pt/comms.py", line 377, in run_coll_cuda_graph
[rank7]: with torch.cuda.graph(g):
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/torch/cuda/graphs.py", line 186, in __exit__
[rank7]: self.cuda_graph.capture_end()
[rank7]: File "/mnt/xarfuse/uid-0/5d817754-seed-nspid4026531836_cgpid202510957-ns-4026531841/torch/cuda/graphs.py", line 84, in capture_end
[rank7]: super().capture_end()
[rank7]: RuntimeError: HIP error: capturing stream has unjoined work
[rank7]: HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank7]: For debugging consider passing AMD_SERIALIZE_KERNEL=3
[rank7]: Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
```
Reviewed By: kingchc, kwen2501
Differential Revision: D70544123
fbshipit-source-id: bb4a5ad8ad1e03a77e8d3528e17d26125b5fe355
1 parent a81194f commit 928d2c0
4 files changed in train/comms/pt: 88 additions, 1 deletion.