Skip to content

Raise error if multiple by's are used with Ellipsis #149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def xarray_reduce(
if skipna is not None and isinstance(func, Aggregation):
raise ValueError("skipna must be None when func is an Aggregation.")

nby = len(by)
for b in by:
if isinstance(b, xr.DataArray) and b.name is None:
raise ValueError("Cannot group by unnamed DataArrays.")
Expand All @@ -203,11 +204,11 @@ def xarray_reduce(
keep_attrs = True

if isinstance(isbin, bool):
isbin = (isbin,) * len(by)
isbin = (isbin,) * nby
if expected_groups is None:
expected_groups = (None,) * len(by)
expected_groups = (None,) * nby
if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list
if len(by) == 1:
if nby == 1:
expected_groups = (expected_groups,)
else:
raise ValueError("Needs better message.")
Expand Down Expand Up @@ -239,6 +240,8 @@ def xarray_reduce(
ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])

if dim is Ellipsis:
if nby > 1:
raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
dim = tuple(obj.dims)
if by[0].name in ds.dims and not isbin[0]:
dim = tuple(d for d in dim if d != by[0].name)
Expand Down Expand Up @@ -351,7 +354,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
missing_dim[k] = v

input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims)
input_core_dims += [input_core_dims[-1]] * (len(by) - 1)
input_core_dims += [input_core_dims[-1]] * (nby - 1)

actual = xr.apply_ufunc(
wrapper,
Expand Down Expand Up @@ -409,7 +412,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
if unindexed_dims:
actual = actual.drop_vars(unindexed_dims)

if len(by) == 1:
if nby == 1:
for var in actual:
if isinstance(obj, xr.DataArray):
template = obj
Expand Down
3 changes: 3 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine):
actual = xarray_reduce(da, "labels", "labels2", **kwargs)
xr.testing.assert_identical(expected, actual)

with pytest.raises(NotImplementedError):
xarray_reduce(da, "labels", "labels2", dim=..., **kwargs)


@requires_dask
def test_dask_groupers_error():
Expand Down