Skip to content

Commit 684c457

Browse files
committed
update fft tests
1 parent 8c50aff commit 684c457

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

dpnp/tests/test_fft.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -563,9 +563,13 @@ def test_basic(self, dtype, n, norm):
563563

564564
result = dpnp.fft.hfft(ia, n=n, norm=norm)
565565
expected = numpy.fft.hfft(a, n=n, norm=norm)
566-
# check_only_type_kind=True since NumPy always returns float64
567-
# but dpnp return float32 if input is float32
568-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
566+
# TODO: change to the commented line when mkl_fft-2.0.0 is released
567+
# and being used with Intel NumPy >= 2.0.0
568+
flag = True
569+
# flag = True if numpy_version() < "2.0.0" else False
570+
assert_dtype_allclose(
571+
result, expected, factor=24, check_only_type_kind=flag
572+
)
569573

570574
@pytest.mark.parametrize(
571575
"dtype", get_all_dtypes(no_none=True, no_complex=True)
@@ -579,7 +583,7 @@ def test_inverse(self, dtype, n, norm):
579583
result = dpnp.fft.ihfft(ia, n=n, norm=norm)
580584
expected = numpy.fft.ihfft(a, n=n, norm=norm)
581585
flag = True if numpy_version() < "2.0.0" else False
582-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
586+
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
583587

584588
def test_error(self):
585589
a = dpnp.ones(11)
@@ -600,14 +604,16 @@ class TestIrfft:
600604
@pytest.mark.parametrize("n", [None, 5, 18])
601605
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
602606
def test_basic(self, dtype, n, norm):
603-
a = generate_random_numpy_array(11)
607+
a = generate_random_numpy_array(11, dtype=dtype)
604608
ia = dpnp.array(a)
605609

606610
result = dpnp.fft.irfft(ia, n=n, norm=norm)
607611
expected = numpy.fft.irfft(a, n=n, norm=norm)
608-
# check_only_type_kind=True since NumPy always returns float64
609-
# but dpnp return float32 if input is float32
610-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
612+
# TODO: change to the commented line when mkl_fft-2.0.0 is released
613+
# and being used with Intel NumPy >= 2.0.0
614+
flag = True
615+
# flag = True if numpy_version() < "2.0.0" else False
616+
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
611617

612618
@pytest.mark.parametrize("dtype", get_complex_dtypes())
613619
@pytest.mark.parametrize("n", [None, 5, 8])
@@ -771,8 +777,11 @@ def test_float16(self):
771777

772778
expected = numpy.fft.rfft(a)
773779
result = dpnp.fft.rfft(ia)
774-
# check_only_type_kind=True since Intel NumPy returns complex128
775-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
780+
# TODO: change to the commented line when mkl_fft-2.0.0 is released
781+
# and being used with Intel NumPy >= 2.0.0
782+
flag = True
783+
# flag = True if numpy_version() < "2.0.0" else False
784+
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
776785

777786
@testing.with_requires("numpy>=2.0.0")
778787
@pytest.mark.parametrize("xp", [numpy, dpnp])
@@ -954,7 +963,8 @@ def test_1d_array(self):
954963

955964
result = dpnp.fft.irfftn(ia)
956965
expected = numpy.fft.irfftn(a)
957-
# TODO: change to the commented line when mkl_fft-gh-180 is merged
966+
# TODO: change to the commented line when mkl_fft-2.0.0 is released
967+
# and being used with Intel NumPy >= 2.0.0
958968
flag = True
959969
# flag = True if numpy_version() < "2.0.0" else False
960970
assert_dtype_allclose(result, expected, check_only_type_kind=flag)

0 commit comments

Comments
 (0)