Refactoring Distributed test cases to be device agnostic [1/n] #145222

Closed
63 changes: 20 additions & 43 deletions test/distributed/_composable/test_replicate_with_compiler.py
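The gist of the diff below: instead of hard-coding CUDA (a use_gpu flag, "cuda" device strings, torch.cuda.set_device calls), the test module derives a device string once at import time and threads it through every test. A minimal sketch of that pattern, assuming get_devtype() reports the available accelerator and falls back to CPU when none is present:

import torch
from torch.testing._internal.common_fsdp import get_devtype

# Resolved once at import time; str() turns the torch.device into a plain
# device string such as "cuda" or "cpu".
device_type = str(get_devtype())

# torch.get_device_module("cuda") resolves to torch.cuda (and likewise for
# other backends), so device_count()/set_device() calls no longer need to
# name a specific backend.
device_module = torch.get_device_module(device_type)
print(device_type, device_module.device_count())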
@@ -2,10 +2,9 @@

import contextlib
import functools
import os
import unittest
from copy import deepcopy
from typing import Callable, Optional
from typing import Callable, Optional, Union

import torch
import torch.distributed as dist
@@ -27,17 +26,20 @@
)
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
DistributedTestBase,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
sm_is_or_higher_than,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils.checkpoint import checkpoint


device_type = str(get_devtype())

DIM = 2000
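
Most of the deletions in the next hunk come from _test_compile: the explicit backend selection and process-group wiring are folded into a single self.create_pg(device) call. For reference, a hedged sketch of the manual setup the removed lines performed (create_pg presumably establishes something equivalent):

import torch
import torch.distributed as dist

def manual_pg_setup(rank: int, world_size: int, file_name: str, device: torch.device) -> None:
    # What the removed lines did by hand: pick a backend from the device,
    # build a FileStore shared by all ranks, and pin the CUDA device
    # for this rank.
    backend = "nccl" if device.type == "cuda" else "gloo"
    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
        store=dist.FileStore(file_name, world_size),
    )
    if device.type == "cuda":
        torch.cuda.set_device(f"cuda:{rank}")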


@@ -72,54 +74,29 @@ def inner_compiler(gm_, example_inputs_):
return _compiler_fn


class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
class MultiProcessInductorTestCase(DistributedTestBase, InductorTestCase):
"""
A version of MultiProcessTestCase that derives from the Inductor TestCase
to handle isolation of the inductor cache dir.
"""


class ReplicateTest(MultiProcessInductorTestCase):
# TODO: consider using all devices? The min(2, ...) here would limit the
# test to always run on 2 GPUs only.
@property
def world_size(self) -> int:
return min(2, torch.cuda.device_count())

def setUp(self) -> None:
super().setUp()
self._spawn_processes()

def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
return min(2, torch.get_device_module(device_type).device_count())

def _test_compile(
self,
*,
use_gpu: bool,
no_sync: bool,
setup_func: Optional[Callable] = None,
no_inductor: bool = False,
no_compile_forward: bool = False,
checkpoint: bool = False,
device: Union[str, torch.device],
):
backend = "nccl" if use_gpu else "gloo"
dist.init_process_group(
backend=backend,
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
if use_gpu:
torch.cuda.set_device(f"cuda:{self.rank}")
device = torch.device("cuda")
else:
device = torch.device("cpu")

self.create_pg(device)
torch._dynamo.config.optimize_ddp = (
"python_reducer_without_compiled_forward"
if no_compile_forward
@@ -202,22 +179,23 @@ def _test_compile(
self.assertEqual(
tuple(model.parameters()), tuple(compiled_ddp_model.parameters())
)
dist.destroy_process_group()

def test_compile_cpu(self):
# Test the coalesced_op with CPU.
torch._inductor.config._fuse_ddp_communication_passes = [
"fuse_ddp_with_coalesced_op",
"schedule_comm_wait",
]
self._test_compile(use_gpu=False, no_sync=False)
self._test_compile(no_sync=False, device="cpu")

def test_compile_cpu_no_sync(self):
# Test the coalesced_op with CPU.
torch._inductor.config._fuse_ddp_communication_passes = [
"fuse_ddp_with_coalesced_op",
"schedule_comm_wait",
]
self._test_compile(use_gpu=False, no_sync=True)
self._test_compile(no_sync=True, device="cpu")

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@@ -226,7 +204,7 @@ def test_compile_cpu_no_sync(self):
reorder_for_locality=False, reorder_for_peak_memory=False
)
def test_compile_gpu(self):
self._test_compile(use_gpu=True, no_sync=False, checkpoint=False)
self._test_compile(no_sync=False, checkpoint=False, device=device_type)

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@@ -235,15 +213,14 @@ def test_compile_gpu(self):
reorder_for_locality=False, reorder_for_peak_memory=False
)
def test_compile_gpu_ac(self):
self._test_compile(use_gpu=True, no_sync=False, checkpoint=True)
self._test_compile(no_sync=False, checkpoint=True, device=device_type)

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_bf16(self):
# Check device capability wrt bf16
device = torch.device("cuda", self.rank % torch.cuda.device_count())
if not sm_is_or_higher_than(device, 8, 0):
if not sm_is_or_higher_than(torch.device(device_type), 8, 0):
self.skipTest("bf16 requires sm >= 8.0")

def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
@@ -254,7 +231,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
None, ddp_default_hooks.bf16_compress_hook
)

self._test_compile(use_gpu=True, no_sync=False, setup_func=setup)
self._test_compile(no_sync=False, setup_func=setup, device=device_type)

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@@ -270,14 +247,14 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:

# TODO: figure out why we need to disable Inductor to avoid test errors.
self._test_compile(
use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
no_sync=False, setup_func=setup, no_inductor=True, device=device_type
)

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_backward_only(self):
self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True)
self._test_compile(no_sync=False, no_compile_forward=True, device=device_type)

def _test_bucketing(self, init_process_group=True, loop=1):
if init_process_group:
@@ -397,7 +374,7 @@ def setUp(self):
# Hmm, why a specific set_device call for rank 0?
self.rank = 0
self.world_size = 4
torch.cuda.set_device("cuda:0")
torch.get_device_module(device_type).set_device(device_type)

store = FakeStore()
dist.init_process_group(
@@ -419,7 +396,7 @@ def test_ddp_tp(self):
ref_model = Net()
compiled_replicate_model = deepcopy(ref_model)
mesh_2d = init_device_mesh(
"cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
)
tp_mesh = mesh_2d["tp"]
dp_mesh = mesh_2d["dp"]
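Taken together, a test case written against the new base class might look roughly like the sketch below. DistributedTestBase, create_pg, get_devtype, and run_tests are taken from the diff; the test body is illustrative, and it assumes the base class handles process spawning and teardown (the removal of setUp/tearDown above suggests it does).

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests

device_type = str(get_devtype())


class MyDeviceAgnosticTest(DistributedTestBase):
    @property
    def world_size(self) -> int:
        # Cap at two ranks, but never request more devices than are present.
        return min(2, torch.get_device_module(device_type).device_count())

    def test_allreduce(self) -> None:
        # create_pg replaces the old init_process_group/FileStore boilerplate
        # and accepts a device string or torch.device, per the new
        # _test_compile signature.
        self.create_pg(device_type)
        t = torch.ones(4, device=device_type)
        dist.all_reduce(t)  # defaults to SUM across ranks
        self.assertEqual(
            t, torch.full((4,), float(self.world_size), device=device_type)
        )


if __name__ == "__main__":
    run_tests()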