From 369a9081f60c9a512cea49ee797c68350801389e Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 11 Oct 2022 16:40:04 -0600 Subject: [PATCH 1/5] Support nanargmin, nanargmax --- flox/aggregations.py | 4 ++-- tests/conftest.py | 2 +- tests/test_core.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index e85c0699d..13b23fafe 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -421,7 +421,7 @@ def _pick_second(*x): chunk=("nanmax", "nanargmax"), # order is important combine=("max", "argmax"), reduction_type="argreduce", - fill_value=(dtypes.NINF, -1), + fill_value=(dtypes.NINF, 0), final_fill_value=-1, finalize=_pick_second, dtypes=(None, np.intp), @@ -434,7 +434,7 @@ def _pick_second(*x): chunk=("nanmin", "nanargmin"), # order is important combine=("min", "argmin"), reduction_type="argreduce", - fill_value=(dtypes.INF, -1), + fill_value=(dtypes.INF, 0), final_fill_value=-1, finalize=_pick_second, dtypes=(None, np.intp), diff --git a/tests/conftest.py b/tests/conftest.py index 8e5039d28..5c3bb81f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest -@pytest.fixture(scope="module", params=["flox"]) +@pytest.fixture(scope="module", params=["flox", "numpy", "numba"]) def engine(request): if request.param == "numba": try: diff --git a/tests/test_core.py b/tests/test_core.py index 7c152fd10..6973b81c7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -55,7 +55,7 @@ def dask_array_ones(*args): "nansum", "argmax", "nanfirst", - pytest.param("nanargmax", marks=(pytest.mark.skip,)), + "nanargmax", "prod", "nanprod", "mean", @@ -233,8 +233,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): # computing silences a bunch of dask warnings array_ = array.compute() if chunks is not None else array if "arg" in func and add_nan_by: - array_[..., nanmask] = np.nan - expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs) + func_ = f"nan{func}" if "nan" not in func else func + array[..., nanmask] = np.nan + expected = getattr(np, func_)(array, axis=-1, **kwargs) # elif func in ["first", "last"]: # expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs) elif func in ["nanfirst", "nanlast"]: From 3c36b46169478f6210f3feb1f5e043ddc114f550 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 11 May 2023 13:38:48 -0600 Subject: [PATCH 2/5] Fix test --- tests/test_core.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 6973b81c7..e2bd9cfc1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -69,7 +69,7 @@ def dask_array_ones(*args): "min", "nanmin", "argmin", - pytest.param("nanargmin", marks=(pytest.mark.skip,)), + "nanargmin", "any", "all", "nanlast", @@ -233,9 +233,13 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): # computing silences a bunch of dask warnings array_ = array.compute() if chunks is not None else array if "arg" in func and add_nan_by: + # NaNs are in by, but we can't call np.argmax([..., NaN, .. ]) + # That would return index of the NaN + # This way, we insert NaNs where there are NaNs in by, and + # call np.nanargmax func_ = f"nan{func}" if "nan" not in func else func - array[..., nanmask] = np.nan - expected = getattr(np, func_)(array, axis=-1, **kwargs) + array_[..., nanmask] = np.nan + expected = getattr(np, func_)(array_, axis=-1, **kwargs) # elif func in ["first", "last"]: # expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs) elif func in ["nanfirst", "nanlast"]: From 836860479ec2d2eb4875bb733670efdce1a3c133 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 11 May 2023 13:58:52 -0600 Subject: [PATCH 3/5] Add blockwise test --- tests/test_core.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index e2bd9cfc1..5c4db9248 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -264,6 +264,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): params = list(itertools.product(["map-reduce"], [True, False, None])) params.extend(itertools.product(["cohorts"], [False, None])) + if chunks == -1: + params.extend([("blockwise", None)]) + for method, reindex in params: call = partial( groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs @@ -274,11 +277,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): call() continue actual, *groups = call() - if "arg" not in func: - # make sure we use simple combine - assert any("simple-combine" in key for key in actual.dask.layers.keys()) - else: - assert any("grouped-combine" in key for key in actual.dask.layers.keys()) + if method != "blockwise": + if "arg" not in func: + # make sure we use simple combine + assert any("simple-combine" in key for key in actual.dask.layers.keys()) + else: + assert any("grouped-combine" in key for key in actual.dask.layers.keys()) for actual_group, expect in zip(groups, expected_groups): assert_equal(actual_group, expect, tolerance) if "arg" in func: From c282d0a8236bd8d2ed0b65f92127d81fc326316f Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 11 May 2023 14:09:50 -0600 Subject: [PATCH 4/5] Fix blockwise test --- flox/core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index 2444df8e3..d10eb3a64 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1323,8 +1323,10 @@ def dask_groupby_agg( by = dask.array.from_array(by, chunks=chunks) _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :]) - # preprocess the array: for argreductions, this zips the index together with the array block - if agg.preprocess: + # preprocess the array: + # - for argreductions, this zips the index together with the array block + # - not necessary for blockwise, for argreductions + if agg.preprocess and method != "blockwise": array = agg.preprocess(array, axis=axis) # 1. We first apply the groupby-reduction blockwise to generate "intermediates" From ec2cf3f28d8149d47a5b3ffa1b61bf1bfa38e3d3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 11 May 2023 14:11:03 -0600 Subject: [PATCH 5/5] Apply suggestions from code review --- flox/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index d10eb3a64..57ea4556f 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1325,7 +1325,8 @@ def dask_groupby_agg( # preprocess the array: # - for argreductions, this zips the index together with the array block - # - not necessary for blockwise, for argreductions + # - not necessary for blockwise with argreductions + # - if this is needed later, we can fix this then if agg.preprocess and method != "blockwise": array = agg.preprocess(array, axis=axis)