Skip to content

Commit 045a9f8

Browse files
ankurneog and TharinduRusira
authored and committed
Add Intel Gaudi device/HPU to auto load in instantiate_device_type_tests (pytorch#126970)
### Motivation Intel Gaudi accelerator (device name hpu) is seen to have good pass rate with the pytorch framework UTs , however being an out-of-tree device, we face challenges in adapting the device to natively run the existing pytorch UTs under pytorch/test. The UTs however is a good indicator of the device stack health and as such we run them regularly with adaptations. Although we can add Gaudi/HPU device to generate the device specific tests using the TORCH_TEST_DEVICES environment variable, we miss out on lot of features such as executing for specific dtypes, skipping and overriding opInfo. With significant changes introduced every Pytorch release maintaining these adaptations become difficult and time consuming. Hence with this PR we introduce Gaudi device in common_device_type framework, so that the tests are instantiated for Gaudi when the library is loaded. The eventual goal is to introduce Gaudi out-of-tree support as equivalent to in-tree devices ### Changes Add HPUTestBase of type DeviceTypeTestBase specifying appropriate attributes for Gaudi/HPU. Include code to check if intel Gaudi Software library is loaded and if so, add the device to the list of devices considered for instantiation of device type tests ### Additional Context please refer the following RFC : pytorch/rfcs#63 Pull Request resolved: pytorch#126970 Approved by: https://github.com/albanD
1 parent 613524b commit 045a9f8

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

torch/testing/_internal/common_device_type.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
1717
skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
18-
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, \
18+
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, TEST_HPU, \
1919
_TestParametrizer, compose_parametrize_fns, dtype_name, \
2020
TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \
2121
get_tracked_input, clear_tracked_input, PRINT_REPRO_ON_FAILURE, \
@@ -590,6 +590,18 @@ def setUpClass(cls):
590590
def _should_stop_test_suite(self):
591591
return False
592592

593+
class HPUTestBase(DeviceTypeTestBase):
594+
device_type = 'hpu'
595+
primary_device: ClassVar[str]
596+
597+
@classmethod
598+
def get_primary_device(cls):
599+
return cls.primary_device
600+
601+
@classmethod
602+
def setUpClass(cls):
603+
cls.primary_device = 'hpu:0'
604+
593605
class PrivateUse1TestBase(DeviceTypeTestBase):
594606
primary_device: ClassVar[str]
595607
device_mod = None
@@ -701,6 +713,8 @@ def get_desired_device_type_test_bases(except_for=None, only_for=None, include_l
701713
test_bases.append(MPSTestBase)
702714
if only_for == 'xpu' and TEST_XPU and XPUTestBase not in test_bases:
703715
test_bases.append(XPUTestBase)
716+
if TEST_HPU and HPUTestBase not in test_bases:
717+
test_bases.append(HPUTestBase)
704718
# Filter out the device types based on user inputs
705719
desired_device_type_test_bases = filter_desired_device_types(test_bases, except_for, only_for)
706720
if include_lazy:
@@ -1060,6 +1074,10 @@ class skipMPSIf(skipIf):
10601074
def __init__(self, dep, reason):
10611075
super().__init__(dep, reason, device_type='mps')
10621076

1077+
class skipHPUIf(skipIf):
1078+
def __init__(self, dep, reason):
1079+
super().__init__(dep, reason, device_type='hpu')
1080+
10631081
# Skips a test on XLA if the condition is true.
10641082
class skipXLAIf(skipIf):
10651083

@@ -1343,6 +1361,9 @@ def onlyMPS(fn):
13431361
def onlyXPU(fn):
13441362
return onlyOn('xpu')(fn)
13451363

1364+
def onlyHPU(fn):
1365+
return onlyOn('hpu')(fn)
1366+
13461367
def onlyPRIVATEUSE1(fn):
13471368
device_type = torch._C._get_privateuse1_backend_name()
13481369
device_mod = getattr(torch, device_type, None)
@@ -1401,6 +1422,9 @@ def expectedFailureMeta(fn):
14011422
def expectedFailureXLA(fn):
14021423
return expectedFailure('xla')(fn)
14031424

1425+
def expectedFailureHPU(fn):
1426+
return expectedFailure('hpu')(fn)
1427+
14041428
# Skips a test on CPU if LAPACK is not available.
14051429
def skipCPUIfNoLapack(fn):
14061430
return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)
@@ -1578,6 +1602,9 @@ def skipXLA(fn):
15781602
def skipMPS(fn):
15791603
return skipMPSIf(True, "test doesn't work on MPS backend")(fn)
15801604

1605+
def skipHPU(fn):
1606+
return skipHPUIf(True, "test doesn't work on HPU backend")(fn)
1607+
15811608
def skipPRIVATEUSE1(fn):
15821609
return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)
15831610

torch/testing/_internal/common_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,7 @@ def TemporaryDirectoryName(suffix=None):
12361236
TEST_MKL = torch.backends.mkl.is_available()
12371237
TEST_MPS = torch.backends.mps.is_available()
12381238
TEST_XPU = torch.xpu.is_available()
1239+
TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False
12391240
TEST_CUDA = torch.cuda.is_available()
12401241
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
12411242
custom_device_is_available = hasattr(custom_device_mod, "is_available") and custom_device_mod.is_available()
@@ -1622,6 +1623,15 @@ def wrapper(*args, **kwargs):
16221623
fn(*args, **kwargs)
16231624
return wrapper
16241625

1626+
def skipIfHpu(fn):
1627+
@wraps(fn)
1628+
def wrapper(*args, **kwargs):
1629+
if TEST_HPU:
1630+
raise unittest.SkipTest("test doesn't currently work with HPU")
1631+
else:
1632+
fn(*args, **kwargs)
1633+
return wrapper
1634+
16251635
# Skips a test on CUDA if ROCm is available and its version is lower than requested.
16261636
def skipIfRocmVersionLessThan(version=None):
16271637
def dec_fn(fn):

0 commit comments

Comments (0)