|
1 | 1 | import functools
|
2 | 2 |
|
| 3 | +import dpctl |
3 | 4 | import dpctl.tensor as dpt
|
4 | 5 | import numpy
|
5 | 6 | import pytest
|
6 | 7 | from dpctl.tensor._numpy_helper import AxisError
|
| 8 | +from dpctl.utils import ExecutionPlacementError |
7 | 9 | from numpy.testing import (
|
8 | 10 | assert_,
|
9 | 11 | assert_array_equal,
|
@@ -1307,3 +1309,51 @@ def test_error(self):
|
1307 | 1309 | dpnp.select([x0], [x1], default=x1)
|
1308 | 1310 | with pytest.raises(TypeError):
|
1309 | 1311 | dpnp.select([x1], [x1])
|
| 1312 | + |
| 1313 | + |
| 1314 | +def test_compress_basic(): |
| 1315 | + a = dpnp.arange(16).reshape(4, 4) |
| 1316 | + condition = dpnp.asarray([True, False, True]) |
| 1317 | + r = dpnp.compress(condition, a, axis=0) |
| 1318 | + assert_array_equal(r[0], a[0]) |
| 1319 | + assert_array_equal(r[1], a[2]) |
| 1320 | + |
| 1321 | + |
| 1322 | +@pytest.mark.parametrize("dtype", get_all_dtypes()) |
| 1323 | +def test_compress_condition_all_dtypes(dtype): |
| 1324 | + a = dpnp.arange(10, dtype="i4") |
| 1325 | + condition = dpnp.tile(dpnp.asarray([0, 1], dtype=dtype), 5) |
| 1326 | + r = dpnp.compress(condition, a) |
| 1327 | + assert_array_equal(r, a[1::2]) |
| 1328 | + |
| 1329 | + |
| 1330 | +def test_compress_invalid_out_errors(): |
| 1331 | + q1 = dpctl.SyclQueue() |
| 1332 | + q2 = dpctl.SyclQueue() |
| 1333 | + a = dpnp.ones(10, dtype="i4", sycl_queue=q1) |
| 1334 | + condition = dpnp.asarray([True], sycl_queue=q1) |
| 1335 | + out_bad_shape = dpnp.empty_like(a) |
| 1336 | + with pytest.raises(ValueError): |
| 1337 | + dpnp.compress(condition, a, out=out_bad_shape) |
| 1338 | + out_bad_queue = dpnp.empty(1, dtype="i4", sycl_queue=q2) |
| 1339 | + with pytest.raises(ExecutionPlacementError): |
| 1340 | + dpnp.compress(condition, a, out=out_bad_queue) |
| 1341 | + out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1) |
| 1342 | + with pytest.raises(ValueError): |
| 1343 | + dpnp.compress(condition, a, out=out_bad_dt) |
| 1344 | + out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1) |
| 1345 | + out_read_only.flags.writable = False |
| 1346 | + with pytest.raises(ValueError): |
| 1347 | + dpnp.compress(condition, a, out=out_read_only) |
| 1348 | + |
| 1349 | + |
| 1350 | +def test_compress_empty_axis(): |
| 1351 | + a = dpnp.ones((10, 0, 5), dtype="i4") |
| 1352 | + condition = [True, False, True] |
| 1353 | + r = dpnp.compress(condition, a, axis=0) |
| 1354 | + assert r.shape == (2, 0, 5) |
| 1355 | + # empty take from empty axis is permitted |
| 1356 | + assert dpnp.compress([False], a, axis=1).shape == (10, 0, 5) |
| 1357 | + # non-empty take from empty axis raises IndexError |
| 1358 | + with pytest.raises(IndexError): |
| 1359 | + dpnp.compress(condition, a, axis=1) |
0 commit comments