 import copy
 import functools
 import itertools
-import unittest
 from typing import Any, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.fsdp import fully_shard
 from torch.nn.parallel.scatter_gather import _is_namedtuple
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     DoubleLinear,
     FSDPTest,
     FSDPTestMultiThread,
+    get_devtype,
     MLP,
 )
 from torch.testing._internal.common_utils import run_tests
@@ -28,10 +27,13 @@
 )
 
 
+device_type = torch.device(get_devtype())
+
+
 class TestFullyShardAutograd(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())
 
     def _reduce_1d_partial_grads(
         self, module: nn.Module, group: Optional[dist.ProcessGroup] = None
@@ -58,7 +60,7 @@ def _test_unused_forward_output(self, reshard_after_forward: Union[bool, int]):
         local_batch_size = 2
         global_batch_size, dim = (self.world_size * local_batch_size, 24)
         model = DoubleLinear(dim=dim, use_second_linear=True)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
@@ -68,7 +70,7 @@ def _test_unused_forward_output(self, reshard_after_forward: Union[bool, int]):
         for iter_idx in range(10):
             # Use all forward outputs in the loss/backward for the first half
             # of the iterations and only the 1st forward output for the rest
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -104,7 +106,7 @@ def _test_unused_forward_module(self, reshard_after_forward: Union[bool, int]):
         local_batch_size, dim = (2, 24)
         global_batch_size = self.world_size * local_batch_size
         model = DoubleLinear(dim=dim, use_second_linear=False)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model.lin2, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
@@ -113,7 +115,7 @@ def _test_unused_forward_module(self, reshard_after_forward: Union[bool, int]):
 
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -214,7 +216,7 @@ def forward(self, x: torch.Tensor):
             Module(dim),
             FromContainerType(container_type),
         )
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         for module in model:
             fully_shard(module)
         fully_shard(model)
@@ -223,7 +225,7 @@ def forward(self, x: torch.Tensor):
 
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -245,7 +247,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 2
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_post_acc_grad_hook_runs(self):
         param_name_to_hook_count = collections.defaultdict(int)
 
@@ -260,7 +262,7 @@ def hook(param_name: str, param: torch.Tensor) -> None:
             param_hook = functools.partial(hook, param_name)
             param.register_post_accumulate_grad_hook(param_hook)
 
-        inp = torch.randn((2, 8), device="cuda")
+        inp = torch.randn((2, 8), device=device_type)
         model(inp).sum().backward()
         param_names = {param_name for param_name, _ in model.named_parameters()}
         self.assertEqual(param_names, set(param_name_to_hook_count.keys()))
@@ -271,7 +273,7 @@ def hook(param_name: str, param: torch.Tensor) -> None:
 class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.get_device_module(device_type).device_count(), 2)
 
     @skip_if_lt_x_gpu(2)
     def test_post_acc_grad_hook_optim_parity(self):
@@ -283,7 +285,7 @@ def test_post_acc_grad_hook_optim_parity(self):
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)
 
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         for module in itertools.chain(ref_model.layers, [ref_model]):
             fully_shard(module)
         optim_kwargs = {"lr": 1e-2, "foreach": False}
@@ -312,7 +314,7 @@ def optim_hook(param: nn.Parameter) -> None:
             param.register_post_accumulate_grad_hook(optim_hook)
 
         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type)
         for _ in range(10):
             ref_loss = ref_model(inp).sum()
             ref_loss.backward()
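
Every hunk above applies the same device-agnostic pattern: resolve a device handle once from get_devtype() and route device queries and tensor placement through it instead of hard-coded "cuda" calls. Below is a minimal standalone sketch of that pattern, not part of the PR, assuming get_devtype is importable as in the diff and the installed PyTorch provides torch.get_device_module:

# Hypothetical illustration: the calls the test now uses, collected in one place.
import torch
from torch.testing._internal.common_fsdp import get_devtype

device_type = torch.device(get_devtype())             # e.g. torch.device("cuda") or torch.device("xpu")
device_module = torch.get_device_module(device_type)  # e.g. torch.cuda, instead of referencing torch.cuda directly

world_size = min(4, device_module.device_count())     # replaces torch.cuda.device_count()
x = torch.rand((2, 8), device=device_type)            # replaces device="cuda"
model = torch.nn.Linear(8, 8).to(device_type)         # replaces .cuda()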