
Commit 5506baa

AnantGulati authored and pytorchmergebot committed
Refactoring FSDP2 (_composable/fsdp) test cases to be device agnostic (#149848)
The motivation for this PR is to refactor the existing test cases under test/distributed/_composable/fsdp/ (FSDP2, as it is referred to in torchtitan) to be device agnostic, so that any accelerator type (e.g. CUDA, HPU, XPU) is supported. The changes are in line with the previously merged device-agnostic changes for the FSDP test cases under test/distributed/fsdp/: #139184

Pull Request resolved: #149848
Approved by: https://github.com/kwen2501, https://github.com/guangyey
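The core of the refactor is visible in the hunks below: every hard-coded "cuda" reference is replaced with a device handle derived from get_devtype() in torch.testing._internal.common_fsdp. A minimal sketch of the pattern, assuming an environment where common_fsdp is importable (the test class and method names here are illustrative only, not from the PR):

import torch
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype

# Resolve the accelerator type once at module scope (CUDA, HPU, XPU, ...)
device_type = torch.device(get_devtype())


class ExampleDeviceAgnosticTest(FSDPTest):  # illustrative class, not part of the PR
    @property
    def world_size(self) -> int:
        # Query the matching device module instead of torch.cuda directly
        return min(4, torch.get_device_module(device_type).device_count())

    def _make_inputs(self, batch_size: int, dim: int) -> torch.Tensor:
        # Allocate tensors on the resolved accelerator instead of device="cuda"
        return torch.rand((batch_size, dim), device=device_type)

The same three substitutions (module-level device_type, torch.get_device_module(...).device_count(), and .to(device_type) / device=device_type) repeat across all files in this commit.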
1 parent 6f835a4 commit 5506baa

16 files changed: +539 -362 lines changed

test/distributed/_composable/fsdp/test_fully_shard_autograd.py

Lines changed: 16 additions & 14 deletions
@@ -4,21 +4,20 @@
 import copy
 import functools
 import itertools
-import unittest
 from typing import Any, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.fsdp import fully_shard
 from torch.nn.parallel.scatter_gather import _is_namedtuple
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     DoubleLinear,
     FSDPTest,
     FSDPTestMultiThread,
+    get_devtype,
     MLP,
 )
 from torch.testing._internal.common_utils import run_tests
@@ -28,10 +27,13 @@
 )
 
 
+device_type = torch.device(get_devtype())
+
+
 class TestFullyShardAutograd(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())
 
     def _reduce_1d_partial_grads(
         self, module: nn.Module, group: Optional[dist.ProcessGroup] = None
@@ -58,7 +60,7 @@ def _test_unused_forward_output(self, reshard_after_forward: Union[bool, int]):
         local_batch_size = 2
         global_batch_size, dim = (self.world_size * local_batch_size, 24)
         model = DoubleLinear(dim=dim, use_second_linear=True)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
@@ -68,7 +70,7 @@ def _test_unused_forward_output(self, reshard_after_forward: Union[bool, int]):
         for iter_idx in range(10):
             # Use all forward outputs in the loss/backward for the first half
             # of the iterations and only the 1st forward output for the rest
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -104,7 +106,7 @@ def _test_unused_forward_module(self, reshard_after_forward: Union[bool, int]):
         local_batch_size, dim = (2, 24)
         global_batch_size = self.world_size * local_batch_size
         model = DoubleLinear(dim=dim, use_second_linear=False)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model.lin2, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
@@ -113,7 +115,7 @@ def _test_unused_forward_module(self, reshard_after_forward: Union[bool, int]):
 
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -214,7 +216,7 @@ def forward(self, x: torch.Tensor):
             Module(dim),
             FromContainerType(container_type),
         )
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         for module in model:
             fully_shard(module)
         fully_shard(model)
@@ -223,7 +225,7 @@ def forward(self, x: torch.Tensor):
 
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -245,7 +247,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 2
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_post_acc_grad_hook_runs(self):
         param_name_to_hook_count = collections.defaultdict(int)
 
@@ -260,7 +262,7 @@ def hook(param_name: str, param: torch.Tensor) -> None:
             param_hook = functools.partial(hook, param_name)
             param.register_post_accumulate_grad_hook(param_hook)
 
-        inp = torch.randn((2, 8), device="cuda")
+        inp = torch.randn((2, 8), device=device_type)
         model(inp).sum().backward()
         param_names = {param_name for param_name, _ in model.named_parameters()}
         self.assertEqual(param_names, set(param_name_to_hook_count.keys()))
@@ -271,7 +273,7 @@ def hook(param_name: str, param: torch.Tensor) -> None:
 class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.get_device_module(device_type).device_count(), 2)
 
     @skip_if_lt_x_gpu(2)
     def test_post_acc_grad_hook_optim_parity(self):
@@ -283,7 +285,7 @@ def test_post_acc_grad_hook_optim_parity(self):
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)
 
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         for module in itertools.chain(ref_model.layers, [ref_model]):
             fully_shard(module)
         optim_kwargs = {"lr": 1e-2, "foreach": False}
@@ -312,7 +314,7 @@ def optim_hook(param: nn.Parameter) -> None:
             param.register_post_accumulate_grad_hook(optim_hook)
 
         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type)
         for _ in range(10):
             ref_loss = ref_model(inp).sum()
             ref_loss.backward()
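Besides the tensor-placement changes, this file also swaps the CUDA-only skip guard (@unittest.skipIf(not TEST_CUDA, "no cuda")) for the accelerator-count-based decorator skip_if_lt_x_gpu. A minimal sketch of how that guard reads after the change, assuming the internal test utilities are importable (the class and test names here are illustrative, not from the PR):

import torch
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype
from torch.testing._internal.common_utils import run_tests

device_type = torch.device(get_devtype())


class ExampleAcceleratorGuardTest(FSDPTestMultiThread):  # illustrative, not part of the PR
    @property
    def world_size(self) -> int:
        return 2

    # Before: @unittest.skipIf(not TEST_CUDA, "no cuda"), which only understood CUDA.
    # After: skip based on how many accelerators of the detected type are visible.
    @skip_if_lt_x_gpu(1)
    def test_runs_on_any_accelerator(self):
        inp = torch.randn((2, 8), device=device_type)
        self.assertEqual(inp.device.type, device_type.type)


if __name__ == "__main__":
    run_tests()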

test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py

Lines changed: 14 additions & 9 deletions
@@ -11,7 +11,7 @@
 from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest, MLPStack
+from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLPStack
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
@@ -20,6 +20,9 @@
 )
 
 
+device_type = torch.device(get_devtype())
+
+
 class _TestClipGradNormBase(FSDPTest):
     def _test_clip_grad_norm(
         self,
@@ -33,7 +36,7 @@ def _test_clip_grad_norm(
         dp_mesh: Optional[DeviceMesh] = None,
     ):
         vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
-        dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
+        dp_mesh = dp_mesh or init_device_mesh(device_type.type, (self.world_size,))
         torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
         for _ in range(10):
             ref_optim.zero_grad()
@@ -91,22 +94,24 @@ def _test_clip_grad_norm(
 class TestClipGradNormWorldSize2(_TestClipGradNormBase):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.get_device_module(device_type).device_count(), 2)
 
     @skip_if_lt_x_gpu(2)
     def test_clip_grad_norm_1d(self):
         for norm_type in (2, 1, float("inf")):
             torch.manual_seed(42)
             model_args = ModelArgs(dropout_p=0.0)
             model = Transformer(model_args)
-            ref_model = replicate(copy.deepcopy(model).cuda())
+            ref_model = replicate(copy.deepcopy(model).to(device_type))
             ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
             for module in model.modules():
                 if isinstance(module, TransformerBlock):
                     fully_shard(module)
             fully_shard(model)
             optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-            inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda")
+            inp = torch.randint(
+                0, model.model_args.vocab_size, (3, 16), device=device_type
+            )
             self._test_clip_grad_norm(
                 1, norm_type, ref_model, ref_optim, model, optim, inp
             )
@@ -115,14 +120,14 @@ def test_clip_grad_norm_1d(self):
 class TestClipGradNormWorldSize4(_TestClipGradNormBase):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 4)
+        return min(torch.get_device_module(device_type).device_count(), 4)
 
     @skip_if_lt_x_gpu(4)
     def test_clip_grad_norm_2d(self):
         for norm_type in (2, 1, 3, float("inf")):
             dp_size = 2
             global_mesh = init_device_mesh(
-                "cuda",
+                device_type.type,
                 (dp_size, self.world_size // dp_size),
                 mesh_dim_names=("dp", "tp"),
             )
@@ -132,7 +137,7 @@ def test_clip_grad_norm_2d(self):
             # has some more significant numeric differences from the TP
             model = MLPStack(16, with_seq_parallel=True)
             ref_model = replicate(
-                copy.deepcopy(model).cuda(), process_group=dp_mesh.get_group()
+                copy.deepcopy(model).to(device_type), process_group=dp_mesh.get_group()
             )
             ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
             model.parallelize(
@@ -142,7 +147,7 @@ def test_clip_grad_norm_2d(self):
                 reshard_after_forward=True,
             )
             optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-            inp = torch.randn(2, 16, device="cuda")
+            inp = torch.randn(2, 16, device=device_type)
             self._test_clip_grad_norm(
                 0.5, norm_type, ref_model, ref_optim, model, optim, inp, dp_mesh
             )
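The clip-grad-norm tests apply the same pattern to device meshes: init_device_mesh receives the detected device type string instead of a hard-coded "cuda". A minimal sketch of the 1-D and 2-D mesh construction under those assumptions (mesh sizes are illustrative, and the snippet assumes it runs inside an initialized distributed job, e.g. under torchrun):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.common_fsdp import get_devtype

device_type = torch.device(get_devtype())

# 1-D data-parallel mesh over however many accelerators of that type are visible
world_size = torch.get_device_module(device_type).device_count()
dp_mesh = init_device_mesh(device_type.type, (world_size,))

# 2-D mesh splitting the same ranks into data-parallel and tensor-parallel dims,
# mirroring the ("dp", "tp") layout used in test_clip_grad_norm_2d
dp_size = 2
global_mesh = init_device_mesh(
    device_type.type,
    (dp_size, world_size // dp_size),
    mesh_dim_names=("dp", "tp"),
)
dp_submesh = global_mesh["dp"]  # sub-mesh used for the replicated reference model

Passing device_type.type (a plain string such as "cuda", "hpu", or "xpu") keeps the mesh construction identical across accelerator backends.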
