
Commit 422f307

rahulsingh-intel authored and pytorchmergebot committed

General Changes for multi accelerators

1 parent b30bad7, commit 422f307

17 files changed: +743 -601 lines
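The common thread across the files below is replacing hard-coded CUDA calls with a device type resolved once at import time. Here is a minimal sketch of that pattern, assuming a hypothetical pick_device() helper; the actual tests use get_devtype() from torch.testing._internal.common_fsdp, whose exact fallback order may differ.

# Sketch only: pick_device() is a hypothetical stand-in for get_devtype().
import torch

def pick_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    # HPU builds register a torch.hpu module; guard with hasattr for portability.
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return "hpu"
    return "cpu"

device_type = torch.device(pick_device())

# Device-neutral tensor placement replaces the old hard-coded .cuda() calls:
x = torch.rand(16, 8).to(device_type)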

test/distributed/fsdp/test_fsdp_comm_hooks.py

Lines changed: 24 additions & 18 deletions

@@ -12,35 +12,42 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
 from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_distributed import (
     requires_nccl,
     requires_nccl_version,
     skip_but_pass_in_sandcastle_if,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_fsdp import FSDPTest
+from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
 from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
     parametrize,
     run_tests,
+    TEST_CUDA,
+    TEST_HPU,
 )
 
 
+device_type = torch.device(get_devtype())
+
+
 if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)
 
 # bfloat16 is only supported by CUDA 11+
-BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
-    torch.version.cuda is not None or torch.version.hip is not None
+BFLOAT16_AVAILABLE = (
+    True
+    if (TEST_CUDA and (torch.version.cuda is not None or torch.version.hip is not None))
+    or TEST_HPU
+    else False
 )
 
 
 class Net(nn.Module):
     def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
         # to ensure determinism
         torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
         super().__init__()
 
         if has_wrapping:
@@ -50,12 +57,12 @@ def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
                     nn.ReLU(),
                     FSDP(
                         nn.Linear(16, 8),
-                        device_id=torch.cuda.current_device(),
+                        device_id=device_type,
                         sharding_strategy=sharding_strategy,
                         mixed_precision=mixed_precision,
                     ),
                 ),
-                device_id=torch.cuda.current_device(),
+                device_id=device_type,
                 sharding_strategy=sharding_strategy,
                 mixed_precision=mixed_precision,
             )
@@ -134,13 +141,13 @@ def test_default_communication_hook_behavior(
         """
         out_dim = self.world_size
         net = torch.nn.Linear(1, out_dim, bias=False)
-        inpt = torch.tensor([self.rank]).float().cuda(self.rank)
+        inpt = torch.tensor([self.rank]).float().to(device_type.type)
 
         net_default_hook = FSDP(
             net,
-            device_id=torch.cuda.current_device(),
+            device_id=device_type,
             sharding_strategy=sharding_strategy,
-        ).to(self.rank)
+        ).to(device_type.type)
 
         # Check that by default, `_comm_hook` is None
         for entry in FSDP.fsdp_modules(net_default_hook):
@@ -172,13 +179,12 @@ def _get_submodules(self, fsdp_net):
         ]
 
     def _init_model(self, core, sharding_strategy, mixed_precision=None):
-        device = torch.device("cuda")
         return FSDP(
             core,
-            device_id=torch.cuda.current_device(),
+            device_id=device_type,
             sharding_strategy=sharding_strategy,
             mixed_precision=mixed_precision,
-        ).to(device)
+        ).to(device_type)
 
     @skip_if_lt_x_gpu(2)
     @parametrize("has_wrapping", [True, False])
@@ -277,9 +283,10 @@ def test_registering_hook_hybrid_strategy(self):
             ShardingStrategy.HYBRID_SHARD,
             ShardingStrategy._HYBRID_SHARD_ZERO2,
         ):
-            model = Net(False, None, None).cuda()
+            model = Net(False, None, None).to(device_type)
             fsdp_model = FSDP(
                 model,
+                device_id=device_type,
                 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
                 sharding_strategy=sharding_strategy,
             )
@@ -337,7 +344,6 @@ def _check_low_precision_hook(
     ):
         # keep everything deterministic for input data
         torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
 
         fsdp_with_hook = self._init_model(
             Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
@@ -359,7 +365,7 @@ def _check_low_precision_hook(
         optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
         optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)
 
-        in_data = torch.rand(16, 8).cuda()
+        in_data = torch.rand(16, 8).to(device_type)
         fsdp_with_hook.train()
         fsdp_with_mp.train()
         loss_hook = fsdp_with_hook(in_data).sum()
@@ -426,7 +432,7 @@ def test_bf16_hook(
         )
 
 
-instantiate_parametrized_tests(TestCommunicationHooks)
-
+devices = ("cuda", "hpu")
+instantiate_device_type_tests(TestCommunicationHooks, globals(), only_for=devices)
 if __name__ == "__main__":
     run_tests()
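The closing lines above swap instantiate_parametrized_tests for instantiate_device_type_tests so the suite is generated per backend. A minimal sketch of how that helper is used, with a toy TestExample class standing in for TestCommunicationHooks:

# Sketch only: TestExample is a toy class for illustration.
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

class TestExample(TestCase):
    def test_add(self, device):
        # `device` is injected per backend ("cuda", "hpu", ...) by the instantiation below.
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)

# Creates per-backend copies (e.g. TestExampleCUDA, TestExampleHPU) for whichever
# of the listed backends is actually available in this build.
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestExample, globals(), only_for=devices)

if __name__ == "__main__":
    run_tests()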

test/distributed/fsdp/test_fsdp_flatten_params.py

Lines changed: 15 additions & 13 deletions

@@ -11,16 +11,18 @@
     FlatParamShardMetadata,
     HandleShardingStrategy,
 )
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest
+from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
 from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
     parametrize,
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
 )
 
 
+device_type = torch.device(get_devtype())
+
 if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)
@@ -45,7 +47,7 @@ def world_size(self) -> int:
 
     def _get_default_config(self):
         return {
-            "device": torch.device("cuda"),
+            "device": torch.device(device_type),
             "sharding_strategy": HandleShardingStrategy.FULL_SHARD,
             "offload_params": False,
             "mp_param_dtype": None,
@@ -103,8 +105,8 @@ def _test_partial_flattening(self, half: bool):
         params_to_flatten = encoder_1_params + decoder_0_params
         num_params = [len(encoder_1_params), len(decoder_0_params)]
         numel_to_flatten = sum(p.numel() for p in params_to_flatten)
-        module.encoder.layers[1] = FSDP(module.encoder.layers[1])
-        module.decoder.layers[0] = FSDP(module.decoder.layers[0])
+        module.encoder.layers[1] = FSDP(module.encoder.layers[1], device_id=device_type)
+        module.decoder.layers[0] = FSDP(module.decoder.layers[0], device_id=device_type)
         flat_params = [
             module.encoder.layers[1]._flat_param,
             module.decoder.layers[0]._flat_param,
@@ -173,7 +175,7 @@ def test_empty_module(self):
         module = self._get_empty_module()
         in_data = torch.rand(1)
         ref_out = module(in_data)
-        fsdp_module = FSDP(module)
+        fsdp_module = FSDP(module, device_id=device_type)
         self.assertEqual(len(list(fsdp_module.parameters())), 0)
         self.assertIsNone(fsdp_module._flat_param)
         fsdp_out = fsdp_module(in_data)
@@ -270,9 +272,9 @@ def _test_output_with_shared_params(self, half: bool):
         self._test_output(module)
 
     def _test_output(self, module: nn.Module):
-        module = module.to(self.rank)
+        module = module.to(device_type)
         ref_output = self._get_output(module)
-        fsdp_module = FSDP(module)
+        fsdp_module = FSDP(module, device_id=device_type)
         fsdp_output = self._get_output(fsdp_module)
         self.assertEqual(ref_output, fsdp_output)
 
@@ -295,14 +297,14 @@ def test_pnorm_after_step_with_shared_params(self):
         )
 
     def _test_pnorm_after_step_with_shared_params(self, half: bool):
-        module = self._get_shared_params_transformer().to(self.rank)
+        module = self._get_shared_params_transformer().to(device_type)
         if half:
             module = module.half()
         ref_pnorm_after_step = self._get_pnorm_after_step(module)
-        module = self._get_shared_params_transformer().to(self.rank)  # recreate
+        module = self._get_shared_params_transformer().to(device_type)  # recreate
         if half:
             module = module.half()
-        fsdp_module = FSDP(module)
+        fsdp_module = FSDP(module, device_id=device_type)
         fsdp_pnorm_after_step = self._get_pnorm_after_step(fsdp_module)
         self.assertEqual(ref_pnorm_after_step, fsdp_pnorm_after_step)
 
@@ -648,7 +650,7 @@ def test_flat_param_shard_metadata_with_memory_format(self, memory_format):
         )
 
 
-instantiate_parametrized_tests(TestFlattenParams)
-
+devices = ("cuda", "hpu")
+instantiate_device_type_tests(TestFlattenParams, globals(), only_for=devices)
 if __name__ == "__main__":
     run_tests()
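The recurring edit in this file adds device_id=device_type to each FSDP(...) call so parameter materialization and sharding land on the selected accelerator rather than on an implicit torch.cuda.current_device(). A single-process sketch of that call pattern, assuming a CUDA device and a one-rank NCCL process group purely for illustration:

# Sketch only: a one-rank process group stands in for the multi-rank test setup.
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main() -> None:
    if not (dist.is_available() and torch.cuda.is_available()):
        print("Needs torch.distributed and a CUDA device; skipping sketch.")
        return
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)

    device_type = torch.device("cuda")  # stand-in for torch.device(get_devtype())
    # device_id tells FSDP where to move and shard the CPU-constructed module,
    # instead of assuming torch.cuda.current_device().
    model = FSDP(nn.Linear(8, 8), device_id=device_type)
    loss = model(torch.rand(4, 8, device=device_type)).sum()
    loss.backward()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()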

test/distributed/fsdp/test_fsdp_freezing_weights.py

Lines changed: 13 additions & 10 deletions

@@ -10,16 +10,18 @@
 from torch import distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest, get_full_params
+from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, get_full_params
 from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
     parametrize,
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
 )
 
 
+device_type = torch.device(get_devtype())
+
 if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)
@@ -47,7 +49,6 @@ def __init__(
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
             nn.Flatten(),
         )
-        self.device = torch.cuda.current_device()
         self.head = nn.Linear(64, 10)
         if with_fsdp and freeze_after_wrap_fsdp:
             self.fsdp_wrap(fsdp_kwargs)
@@ -56,6 +57,7 @@ def __init__(
         )
 
     def fsdp_wrap(self, fsdp_kwargs):
+        fsdp_kwargs = {"device_id": device_type}
         self.trunk = FSDP(self.trunk, **fsdp_kwargs)
         self.head = FSDP(self.head, **fsdp_kwargs)
 
@@ -90,6 +92,7 @@ def __init__(
         )
 
     def fsdp_wrap(self, fsdp_kwargs):
+        fsdp_kwargs = {"device_id": device_type}
         for name, child in self.trunk.named_children():
             wrapped_child = FSDP(child, **fsdp_kwargs)
             setattr(self.trunk, name, wrapped_child)
@@ -145,15 +148,15 @@ def _dist_train(
         forward_prefetch,
     ):
         torch.manual_seed(0)
-        batch = torch.randn(size=(2, 3, 224, 224)).cuda()
+        batch = torch.randn(size=(2, 3, 224, 224)).to(device_type)
 
         fsdp_kwargs = {
-            "device_id": self.rank,
             "forward_prefetch": forward_prefetch,
+            "device_id": device_type,
         }
 
         ddp_kwargs = {
-            "device_ids": [self.rank],
+            "device_ids": [device_type],
             "find_unused_parameters": True if disable_autograd else False,
         }
 
@@ -164,7 +167,7 @@ def _dist_train(
             disable_autograd,
             fsdp_kwargs,
         )
-        model = model.cuda()
+        model = model.to(device_type)
 
         # freezing the trunk using requires_grad.
         if freezing_method == FreezingMethod.RequiresGrad:
@@ -178,7 +181,7 @@ def _dist_train(
         else:
            model = DistributedDataParallel(model, **ddp_kwargs)
 
-        target = torch.tensor([0, 1], dtype=torch.long).cuda()
+        target = torch.tensor([0, 1], dtype=torch.long).to(device_type)
         criterion = nn.CrossEntropyLoss()
         optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
 
@@ -245,7 +248,7 @@ def test_freezing_weights(
             self.assertEqual(ddp_param.requires_grad, fsdp_param.requires_grad)
 
 
-instantiate_parametrized_tests(TestFreezingWeights)
-
+devices = ("cuda", "hpu")
+instantiate_device_type_tests(TestFreezingWeights, globals(), only_for=devices)
 if __name__ == "__main__":
     run_tests()
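For context on what this file exercises, here is a minimal sketch of the requires_grad freezing method that TestFreezingWeights compares between DDP and FSDP; the ConvNet below is a toy stand-in for the test's trunk/head model, not its actual network.

# Sketch only: freeze the trunk via requires_grad and train just the head.
import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        self.head = nn.Linear(8, 10)

    def forward(self, x):
        return self.head(self.trunk(x))

model = ConvNet()
# Freeze the trunk: its parameters stop receiving gradients, only the head trains.
for p in model.trunk.parameters():
    p.requires_grad = False

opt = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.1, momentum=0.9
)
loss = nn.CrossEntropyLoss()(model(torch.randn(2, 3, 32, 32)), torch.tensor([0, 1]))
loss.backward()
opt.step()
assert all(p.grad is None for p in model.trunk.parameters())  # trunk stayed frozen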
