diff --git a/flox/xarray.py b/flox/xarray.py index 9c8fe6108..29b023a0e 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -194,6 +194,7 @@ def xarray_reduce( if skipna is not None and isinstance(func, Aggregation): raise ValueError("skipna must be None when func is an Aggregation.") + nby = len(by) for b in by: if isinstance(b, xr.DataArray) and b.name is None: raise ValueError("Cannot group by unnamed DataArrays.") @@ -203,11 +204,11 @@ def xarray_reduce( keep_attrs = True if isinstance(isbin, bool): - isbin = (isbin,) * len(by) + isbin = (isbin,) * nby if expected_groups is None: - expected_groups = (None,) * len(by) + expected_groups = (None,) * nby if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list - if len(by) == 1: + if nby == 1: expected_groups = (expected_groups,) else: raise ValueError("Needs better message.") @@ -239,6 +240,8 @@ def xarray_reduce( ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) if dim is Ellipsis: + if nby > 1: + raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.") dim = tuple(obj.dims) if by[0].name in ds.dims and not isbin[0]: dim = tuple(d for d in dim if d != by[0].name) @@ -351,7 +354,7 @@ def wrapper(array, *by, func, skipna, **kwargs): missing_dim[k] = v input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims) - input_core_dims += [input_core_dims[-1]] * (len(by) - 1) + input_core_dims += [input_core_dims[-1]] * (nby - 1) actual = xr.apply_ufunc( wrapper, @@ -409,7 +412,7 @@ def wrapper(array, *by, func, skipna, **kwargs): if unindexed_dims: actual = actual.drop_vars(unindexed_dims) - if len(by) == 1: + if nby == 1: for var in actual: if isinstance(obj, xr.DataArray): template = obj diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 90a2d50c4..6669830b5 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -159,6 +159,9 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine): actual = xarray_reduce(da, "labels", "labels2", **kwargs) xr.testing.assert_identical(expected, actual) + with pytest.raises(NotImplementedError): + xarray_reduce(da, "labels", "labels2", dim=..., **kwargs) + @requires_dask def test_dask_groupers_error():