|
15 | 15 | import torch
|
16 | 16 | from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
|
17 | 17 | skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
|
18 |
| - IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, \ |
| 18 | + IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, TEST_HPU, \ |
19 | 19 | _TestParametrizer, compose_parametrize_fns, dtype_name, \
|
20 | 20 | TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \
|
21 | 21 | get_tracked_input, clear_tracked_input, PRINT_REPRO_ON_FAILURE, \
|
@@ -590,6 +590,18 @@ def setUpClass(cls):
|
590 | 590 | def _should_stop_test_suite(self):
|
591 | 591 | return False
|
592 | 592 |
|
class HPUTestBase(DeviceTypeTestBase):
    """Device-type test base that runs tests against HPU devices."""

    device_type = 'hpu'
    # Set once in setUpClass; the device string all tests target.
    primary_device: ClassVar[str]

    @classmethod
    def get_primary_device(cls):
        """Return the device string (e.g. 'hpu:0') tests should use."""
        return cls.primary_device

    @classmethod
    def setUpClass(cls):
        # Always targets device index 0 — assumes a single visible HPU;
        # no multi-device selection is performed here (TODO confirm).
        cls.primary_device = 'hpu:0'
593 | 605 | class PrivateUse1TestBase(DeviceTypeTestBase):
|
594 | 606 | primary_device: ClassVar[str]
|
595 | 607 | device_mod = None
|
@@ -701,6 +713,8 @@ def get_desired_device_type_test_bases(except_for=None, only_for=None, include_l
|
701 | 713 | test_bases.append(MPSTestBase)
|
702 | 714 | if only_for == 'xpu' and TEST_XPU and XPUTestBase not in test_bases:
|
703 | 715 | test_bases.append(XPUTestBase)
|
| 716 | + if TEST_HPU and HPUTestBase not in test_bases: |
| 717 | + test_bases.append(HPUTestBase) |
704 | 718 | # Filter out the device types based on user inputs
|
705 | 719 | desired_device_type_test_bases = filter_desired_device_types(test_bases, except_for, only_for)
|
706 | 720 | if include_lazy:
|
class skipMPSIf(skipIf):
    """Skips a test on MPS if the condition is true."""

    def __init__(self, dep, reason):
        super().__init__(dep, reason, device_type='mps')
|
1062 | 1076 |
|
class skipHPUIf(skipIf):
    """Skips a test on HPU if the condition is true."""

    def __init__(self, dep, reason):
        super().__init__(dep, reason, device_type='hpu')
| 1080 | + |
1063 | 1081 | # Skips a test on XLA if the condition is true.
|
1064 | 1082 | class skipXLAIf(skipIf):
|
1065 | 1083 |
|
@@ -1343,6 +1361,9 @@ def onlyMPS(fn):
|
def onlyXPU(fn):
    """Restrict the decorated test to the XPU device type."""
    decorator = onlyOn('xpu')
    return decorator(fn)
|
1345 | 1363 |
|
def onlyHPU(fn):
    """Restrict the decorated test to the HPU device type."""
    decorator = onlyOn('hpu')
    return decorator(fn)
| 1366 | + |
1346 | 1367 | def onlyPRIVATEUSE1(fn):
|
1347 | 1368 | device_type = torch._C._get_privateuse1_backend_name()
|
1348 | 1369 | device_mod = getattr(torch, device_type, None)
|
@@ -1401,6 +1422,9 @@ def expectedFailureMeta(fn):
|
def expectedFailureXLA(fn):
    """Mark the decorated test as an expected failure on XLA."""
    decorator = expectedFailure('xla')
    return decorator(fn)
|
1403 | 1424 |
|
def expectedFailureHPU(fn):
    """Mark the decorated test as an expected failure on HPU."""
    decorator = expectedFailure('hpu')
    return decorator(fn)
| 1427 | + |
def skipCPUIfNoLapack(fn):
    """Skip the decorated test on CPU when torch was built without LAPACK."""
    has_lapack = torch._C.has_lapack
    return skipCPUIf(not has_lapack, "PyTorch compiled without Lapack")(fn)
|
@@ -1578,6 +1602,9 @@ def skipXLA(fn):
|
def skipMPS(fn):
    """Unconditionally skip the decorated test on the MPS backend."""
    skipper = skipMPSIf(True, "test doesn't work on MPS backend")
    return skipper(fn)
|
1580 | 1604 |
|
def skipHPU(fn):
    """Unconditionally skip the decorated test on the HPU backend."""
    skipper = skipHPUIf(True, "test doesn't work on HPU backend")
    return skipper(fn)
| 1607 | + |
def skipPRIVATEUSE1(fn):
    """Unconditionally skip the decorated test on the privateuse1 backend."""
    skipper = skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")
    return skipper(fn)
|
1583 | 1610 |
|
|
0 commit comments