Skip to content

Commit c0cb300

Browse files
committed
Nicer test
1 parent 592a2fa commit c0cb300

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/test_xarray.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,16 +487,21 @@ def test_mixed_grouping(chunk):
487487
assert (r.sel(v1=[3, 4, 5]) == 0).all().data
488488

489489

490+
@pytest.mark.parametrize("add_nan", [True, False])
490491
@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
491492
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
492493
@pytest.mark.parametrize("chunk", (True, False))
493-
def test_dtype(chunk, dtype, dtype_out, engine):
494+
def test_dtype(add_nan, chunk, dtype, dtype_out, engine):
494495
if chunk and not has_dask:
495496
pytest.skip()
496497

497498
xp = dask.array if chunk else np
498499
data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12))
499500

501+
if add_nan:
502+
data[1, ...] = np.nan
503+
data[0, [0, 2]] = np.nan
504+
500505
arr = xr.DataArray(
501506
data,
502507
dims=("x", "t"),
@@ -511,11 +516,11 @@ def test_dtype(chunk, dtype, dtype_out, engine):
511516

512517
assert actual.dtype == np.dtype("float64")
513518
assert actual.compute().dtype == np.dtype("float64")
514-
assert_equal(expected, actual)
519+
xr.testing.assert_allclose(expected, actual)
515520

516521
actual = xarray_reduce(arr.to_dataset(), "labels", **kwargs)
517522
expected = arr.to_dataset().groupby("labels").mean(dtype="float64")
518523

519524
assert actual.arr.dtype == np.dtype("float64")
520525
assert actual.compute().arr.dtype == np.dtype("float64")
521-
assert_equal(expected, actual.transpose("labels", ...))
526+
xr.testing.assert_allclose(expected, actual.transpose("labels", ...))

0 commit comments

Comments
 (0)