Skip to content

Commit c1358e7

Browse files
authored
Recognize grouping by IntervalIndex as binning (#205)
* Recognize grouping by IntervalIndex as binning
* Update flox/core.py
* More tests
1 parent cd6eeb5 commit c1358e7

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

flox/core.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1559,7 +1559,10 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
15591559
return (None,) * nby
15601560

15611561
if nby == 1 and not isinstance(expected_groups, tuple):
1562-
return (np.asarray(expected_groups),)
1562+
if isinstance(expected_groups, pd.Index):
1563+
return (expected_groups,)
1564+
else:
1565+
return (np.asarray(expected_groups),)
15631566

15641567
if nby > 1 and not isinstance(expected_groups, tuple): # TODO: test for list
15651568
raise ValueError(
@@ -1734,9 +1737,11 @@ def groupby_reduce(
17341737
# (pd.IntervalIndex or not)
17351738
expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)
17361739

1740+
is_binning = any([isinstance(e, pd.IntervalIndex) for e in expected_groups])
1741+
17371742
# TODO: could restrict this to dask-only
17381743
factorize_early = (nby > 1) or (
1739-
any(isbins) and method == "cohorts" and is_duck_dask_array(array)
1744+
is_binning and method == "cohorts" and is_duck_dask_array(array)
17401745
)
17411746
if factorize_early:
17421747
bys, final_groups, grp_shape = _factorize_multiple(

flox/xarray.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -313,7 +313,9 @@ def xarray_reduce(
313313
group_names: tuple[Any, ...] = ()
314314
group_sizes: dict[Any, int] = {}
315315
for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups, isbins)):
316-
group_name = b_.name if not isbin_ else f"{b_.name}_bins"
316+
group_name = (
317+
f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name
318+
)
317319
group_names += (group_name,)
318320

319321
if isbin_ and isinstance(expect, int):

tests/test_core.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -644,10 +644,17 @@ def test_npg_nanarg_bug(func):
644644
assert_equal(actual, expected)
645645

646646

647+
@pytest.mark.parametrize(
648+
"kwargs",
649+
(
650+
dict(expected_groups=np.array([1, 2, 4, 5]), isbin=True),
651+
dict(expected_groups=pd.IntervalIndex.from_breaks([1, 2, 4, 5])),
652+
),
653+
)
647654
@pytest.mark.parametrize("method", ["cohorts", "map-reduce"])
648655
@pytest.mark.parametrize("chunk_labels", [False, True])
649656
@pytest.mark.parametrize("chunks", ((), (1,), (2,)))
650-
def test_groupby_bins(chunk_labels, chunks, engine, method) -> None:
657+
def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
651658
array = [1, 1, 1, 1, 1, 1]
652659
labels = [0.2, 1.5, 1.9, 2, 3, 20]
653660

@@ -663,14 +670,7 @@ def test_groupby_bins(chunk_labels, chunks, engine, method) -> None:
663670

664671
with raise_if_dask_computes():
665672
actual, groups = groupby_reduce(
666-
array,
667-
labels,
668-
func="count",
669-
expected_groups=np.array([1, 2, 4, 5]),
670-
isbin=True,
671-
fill_value=0,
672-
engine=engine,
673-
method=method,
673+
array, labels, func="count", fill_value=0, engine=engine, method=method, **kwargs
674674
)
675675
expected = np.array([3, 1, 0], dtype=np.intp)
676676
for left, right in zip(groups, pd.IntervalIndex.from_arrays([1, 2, 4], [2, 4, 5]).to_numpy()):

tests/test_xarray.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -457,7 +457,8 @@ def test_datetime_array_reduce(use_cftime, func, engine):
457457

458458

459459
@requires_dask
460-
def test_groupby_bins_indexed_coordinate():
460+
@pytest.mark.parametrize("method", ["cohorts", "map-reduce"])
461+
def test_groupby_bins_indexed_coordinate(method):
461462
ds = (
462463
xr.tutorial.open_dataset("air_temperature")
463464
.isel(time=slice(100))
@@ -472,7 +473,17 @@ def test_groupby_bins_indexed_coordinate():
472473
expected_groups=([40, 50, 60, 70],),
473474
isbin=(True,),
474475
func="mean",
475-
method="split-reduce",
476+
method=method,
477+
)
478+
xr.testing.assert_allclose(expected, actual)
479+
480+
actual = xarray_reduce(
481+
ds,
482+
ds.lat,
483+
dim=ds.air.dims,
484+
expected_groups=pd.IntervalIndex.from_breaks([40, 50, 60, 70]),
485+
func="mean",
486+
method=method,
476487
)
477488
xr.testing.assert_allclose(expected, actual)
478489

0 commit comments

Comments
 (0)