
Commit f9c74a6

shengfukevin authored and facebook-github-bot committed
Sanshang: add support to metrics calculation (#197)
Summary: Add support for metrics calculation:
1. Iteration E2E time
2. Bandwidth

This is a copy of #195 for importing it into Meta.

Reviewed By: briancoutinho

Differential Revision: D69155629

Pulled By: shengfukevin
1 parent c5f8d06 commit f9c74a6
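The summary names two metrics: per-iteration end-to-end (E2E) time and bandwidth. The files that implement the metric calculation are not among the hunks reproduced below, so the following Python sketch is only an assumption about how such a bandwidth figure is typically derived; the function name and formula are illustrative, not taken from this commit.

```python
# Hypothetical illustration, not code from this commit: bandwidth is commonly
# reported as bytes moved per unit of end-to-end iteration time.
def bandwidth_gb_per_s(message_bytes: int, iteration_e2e_time_s: float) -> float:
    """Derive achieved bandwidth (GB/s) from message size and iteration E2E time."""
    if iteration_e2e_time_s <= 0:
        raise ValueError("iteration E2E time must be positive")
    return message_bytes / iteration_e2e_time_s / 1e9


# Example: a 1 GiB payload completing in 25 ms is roughly 43 GB/s.
print(bandwidth_gb_per_s(1 << 30, 0.025))
```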

File tree

7 files changed: +697, -114 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 .venv/
 __pycache__/
+./et_replay/vendor_internal/

et_replay/comm/backend/pytorch_dist_backend.py

Lines changed: 8 additions & 1 deletion
@@ -623,6 +623,13 @@ def barrier(self, collectiveArgs, name="dummy", retFlag=False):
         if retFlag:
             return retObj
 
+    def barrier_all_ranks(self):
+        dist.barrier(
+            device_ids=[self.get_device().index]
+            if dist.get_backend() == "nccl"
+            else None
+        )
+
     def sync_barrier(self, collectiveArgs, desc="dummy"):
         # ensure all streams have finished outstanding events before calling barrier
         self.complete_accel_ops(collectiveArgs)
@@ -1031,7 +1038,7 @@ def initialize_groups(self, backend="gloo"):
         # even if they are not going to be members of the group.
         sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
         sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
-        torch.distributed.barrier()
+        self.barrier_all_ranks()
 
         idxed_group_ranks_to_pgId: dict[tuple[int], list[int]] = defaultdict(list)
         for i in range(self.get_world_size()):
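The new `barrier_all_ranks` helper passes `device_ids` to `dist.barrier` only when the backend is NCCL, so the barrier's GPU work runs on this rank's own device rather than defaulting to GPU 0; `initialize_groups` then calls the helper in place of a bare `torch.distributed.barrier()`. Below is a minimal standalone sketch of the same pattern; it assumes `torch.distributed` is already initialized and reads the device index from `LOCAL_RANK`, whereas the backend class above derives it from `self.get_device().index`.

```python
# Standalone sketch of the barrier_all_ranks pattern; assumes
# dist.init_process_group() has already been called. The LOCAL_RANK lookup is
# an assumption for this sketch, not how the class above obtains the index.
import os

import torch
import torch.distributed as dist


def barrier_all_ranks() -> None:
    """Synchronize all ranks; pin the barrier to this rank's GPU under NCCL."""
    if dist.get_backend() == "nccl":
        local_device = int(os.environ.get("LOCAL_RANK", torch.cuda.current_device()))
        dist.barrier(device_ids=[local_device])
    else:
        # Non-NCCL backends (e.g. gloo) do not need a device pinned.
        dist.barrier()
```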

et_replay/comm/comms_utils.py

Lines changed: 1 addition & 45 deletions
@@ -20,10 +20,7 @@
 from typing import Any
 
 try:
-    from param_bench.train.comms.pt.fb.internals import (
-        fbInitProfiler,
-        fbSampleProfiler,
-        fbStartProfiler,
+    from param_bench.et_replay.vendor_internals import (
         initialize_collectiveArgs_internal,
         remove_quantization_handlers,
     )
@@ -390,47 +387,6 @@ def ensureTensorFlush(tensors: list[torch.Tensor] | torch.Tensor) -> Any:
     return x
 
 
-def startProfiler(rank: int, device: str, numWarmupIters: int, numIters: int) -> bool:
-    """
-    Starts internal profiler with given parameters.
-
-    Args:
-        rank: Global rank.
-        device: Type of device "cuda", "cpu", etc.
-        numWarmupIters: Number of warmup iterations.
-        numIters: Number of real iterations.
-    Returns:
-        bool: Returns if internal profile was able to start or not.
-    """
-    if has_internal_libs:
-        fbInitProfiler(
-            rank=rank,
-            device=device,
-            warmup=numWarmupIters,
-            iters=numIters,
-        )
-        fbStartProfiler()
-        return True
-    else:
-        logger.debug("Internal profiler is not available, skip...")
-        return False
-
-
-def sampleProfiler(stop: bool = False) -> None:
-    """
-    Starts internal sample profiler.
-
-    Args:
-        stop: Bool to be passed into sample profiler.
-    Returns:
-        None
-    """
-    if has_internal_libs:
-        fbSampleProfiler(stop)
-    else:
-        logger.debug("Internal profiler is not available, skip...")
-
-
 class commsArgs:
     """
     This class contains all of the args that we can use to perform a single collective.
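This change repoints the optional internal dependency from `param_bench.train.comms.pt.fb.internals` to `param_bench.et_replay.vendor_internals` and drops the `startProfiler`/`sampleProfiler` wrappers around the fb profiler from this file. The surrounding `except` clause and flag handling sit outside the hunk, so the sketch below is only an assumption about the usual shape of such an optional-import guard (the `has_internal_libs` flag does appear in the removed code; `maybe_initialize_internal` is a hypothetical helper for illustration).

```python
# Assumed shape of the optional-import guard; the actual except clause in the
# updated file is outside the hunk shown above.
try:
    from param_bench.et_replay.vendor_internals import (
        initialize_collectiveArgs_internal,
        remove_quantization_handlers,
    )

    has_internal_libs = True
except ImportError:
    has_internal_libs = False


def maybe_initialize_internal(collective_args) -> None:
    """Run the vendor hook only when the internal package could be imported."""
    if has_internal_libs:
        initialize_collectiveArgs_internal(collective_args)
```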
