Skip to content

Commit cfc664e

Browse files
committed
Add tests for compress
1 parent 7ca0d09 commit cfc664e

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

dpnp/tests/test_indexing.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import functools
22

3+
import dpctl
34
import dpctl.tensor as dpt
45
import numpy
56
import pytest
67
from dpctl.tensor._numpy_helper import AxisError
8+
from dpctl.utils import ExecutionPlacementError
79
from numpy.testing import (
810
assert_,
911
assert_array_equal,
@@ -1307,3 +1309,51 @@ def test_error(self):
13071309
dpnp.select([x0], [x1], default=x1)
13081310
with pytest.raises(TypeError):
13091311
dpnp.select([x1], [x1])
1312+
1313+
1314+
def test_compress_basic():
1315+
a = dpnp.arange(16).reshape(4, 4)
1316+
condition = dpnp.asarray([True, False, True])
1317+
r = dpnp.compress(condition, a, axis=0)
1318+
assert_array_equal(r[0], a[0])
1319+
assert_array_equal(r[1], a[2])
1320+
1321+
1322+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1323+
def test_compress_condition_all_dtypes(dtype):
1324+
a = dpnp.arange(10, dtype="i4")
1325+
condition = dpnp.tile(dpnp.asarray([0, 1], dtype=dtype), 5)
1326+
r = dpnp.compress(condition, a)
1327+
assert_array_equal(r, a[1::2])
1328+
1329+
1330+
def test_compress_invalid_out_errors():
1331+
q1 = dpctl.SyclQueue()
1332+
q2 = dpctl.SyclQueue()
1333+
a = dpnp.ones(10, dtype="i4", sycl_queue=q1)
1334+
condition = dpnp.asarray([True], sycl_queue=q1)
1335+
out_bad_shape = dpnp.empty_like(a)
1336+
with pytest.raises(ValueError):
1337+
dpnp.compress(condition, a, out=out_bad_shape)
1338+
out_bad_queue = dpnp.empty(1, dtype="i4", sycl_queue=q2)
1339+
with pytest.raises(ExecutionPlacementError):
1340+
dpnp.compress(condition, a, out=out_bad_queue)
1341+
out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1)
1342+
with pytest.raises(ValueError):
1343+
dpnp.compress(condition, a, out=out_bad_dt)
1344+
out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1)
1345+
out_read_only.flags.writable = False
1346+
with pytest.raises(ValueError):
1347+
dpnp.compress(condition, a, out=out_read_only)
1348+
1349+
1350+
def test_compress_empty_axis():
1351+
a = dpnp.ones((10, 0, 5), dtype="i4")
1352+
condition = [True, False, True]
1353+
r = dpnp.compress(condition, a, axis=0)
1354+
assert r.shape == (2, 0, 5)
1355+
# empty take from empty axis is permitted
1356+
assert dpnp.compress([False], a, axis=1).shape == (10, 0, 5)
1357+
# non-empty take from empty axis raises IndexError
1358+
with pytest.raises(IndexError):
1359+
dpnp.compress(condition, a, axis=1)

0 commit comments

Comments
 (0)