Commit b42b0c7

Squashed commit of the following:

commit cfe87e0
Merge: 1c751a6 1f81338
Author: Deepak Cherian <[email protected]>
Date:   Thu Oct 17 01:13:49 2019 +0000

    Merge branch 'master' into fix/groupby-nan

commit 1c751a6
Author: dcherian <[email protected]>
Date:   Wed Oct 16 19:09:19 2019 -0600

    whats-new

commit 71df146
Author: dcherian <[email protected]>
Date:   Wed Oct 16 19:03:22 2019 -0600

    Add NaTs

commit 1f81338
Author: keewis <[email protected]>
Date:   Wed Oct 16 20:54:27 2019 +0200

    Fixes to the resample docs (pydata#3400)

    * add a missing newline to make sphinx detect the code block
    * update the link to the pandas documentation
    * explicitly state that this only works with datetime dimensions
    * also put the datetime dim requirement into the function description
    * add Series.resample and DataFrame.resample as reference
    * add the changes to whats-new.rst
    * move references to the bottom of the docstring

commit 5bf94a8
Author: dcherian <[email protected]>
Date:   Mon Oct 14 13:25:32 2019 -0600

    Drop nans in grouped variable.
1 parent d8b48b8 commit b42b0c7
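For context on the 1f81338 docs commit: resample only works along a dimension whose labels are datetimes, which is what the docstring change spells out. A minimal sketch of valid usage (not part of this diff; the variable names and frequency are illustrative):

import numpy as np
import pandas as pd
import xarray as xr

# resample requires the resampled dimension to be indexed by datetimes;
# an integer- or float-labeled dimension raises instead
da = xr.DataArray(
    np.arange(10.0),
    coords={"time": pd.date_range("2001-01-01", periods=10, freq="D")},
    dims="time",
)
coarse = da.resample(time="3D").mean()  # 3-day bin means along "time"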

File tree

2 files changed: +80 / -7 lines


xarray/core/groupby.py
Lines changed: 7 additions & 0 deletions

@@ -350,6 +350,13 @@ def __init__(
                 group_indices = [slice(i, i + 1) for i in group_indices]
             unique_coord = group
         else:
+            if group.isnull().any():
+                # drop any NaN valued groups.
+                # also drop obj values where group was NaN
+                # Use where instead of reindex to account for duplicate coordinate labels.
+                obj = obj.where(group.notnull(), drop=True)
+                group = group.dropna(group_dim)
+
             # look through group to find the unique values
             unique_values, group_indices = unique_value_groups(
                 safe_cast_to_index(group), sort=(bins is None)
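
The user-visible effect of this hunk, sketched with one of the cases from the tests below (a rough illustration; the masking in where(..., drop=True) may promote integer data to float):

import numpy as np
import xarray as xr

# the grouping labels contain NaN; those elements are now dropped
# rather than collected into a spurious NaN group
array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])])
result = array.groupby("x").sum()
# result: values [3, 3] on x = [1, 2]; the NaN-labeled points are excluded

The inline comment chooses where over reindex because where masks positionally and so tolerates repeated coordinate labels, which label-based reindexing does not; see the sketch after the tests.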

xarray/tests/test_groupby.py
Lines changed: 73 additions & 7 deletions

@@ -5,7 +5,7 @@
 import xarray as xr
 from xarray.core.groupby import _consolidate_slices
 
-from . import assert_identical, raises_regex
+from . import assert_equal, assert_identical, raises_regex
 
 
 def test_consolidate_slices():
@@ -40,14 +40,14 @@ def test_multi_index_groupby_apply():
         {"foo": (("x", "y"), np.random.randn(3, 4))},
         {"x": ["a", "b", "c"], "y": [1, 2, 3, 4]},
     )
-    doubled = 2 * ds
-    group_doubled = (
+    expected = 2 * ds
+    actual = (
         ds.stack(space=["x", "y"])
         .groupby("space")
         .apply(lambda x: 2 * x)
         .unstack("space")
     )
-    assert doubled.equals(group_doubled)
+    assert_equal(expected, actual)
 
 
 def test_multi_index_groupby_sum():
@@ -58,7 +58,7 @@ def test_multi_index_groupby_sum():
     )
     expected = ds.sum("z")
     actual = ds.stack(space=["x", "y"]).groupby("space").sum("z").unstack("space")
-    assert expected.equals(actual)
+    assert_equal(expected, actual)
 
 
 def test_groupby_da_datetime():
@@ -78,15 +78,15 @@ def test_groupby_da_datetime():
     expected = xr.DataArray(
         [3, 7], coords=dict(reference_date=reference_dates), dims="reference_date"
     )
-    assert actual.equals(expected)
+    assert_equal(expected, actual)
 
 
 def test_groupby_duplicate_coordinate_labels():
     # fix for http://stackoverflow.com/questions/38065129
     array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])])
     expected = xr.DataArray([3, 3], [("x", [1, 2])])
     actual = array.groupby("x").sum()
-    assert expected.equals(actual)
+    assert_equal(expected, actual)
 
 
 def test_groupby_input_mutation():
@@ -255,6 +255,72 @@ def test_groupby_repr_datetime(obj):
     assert actual == expected
 
 
+def test_groupby_drops_nans():
+    # GH2383
+    # nan in 2D data variable (requires stacking)
+    ds = xr.Dataset(
+        {
+            "variable": (("lat", "lon", "time"), np.arange(60.0).reshape((4, 3, 5))),
+            "id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))),
+        },
+        coords={"lat": np.arange(4), "lon": np.arange(3), "time": np.arange(5)},
+    )
+
+    ds["id"].values[0, 0] = np.nan
+    ds["id"].values[3, 0] = np.nan
+    ds["id"].values[-1, -1] = np.nan
+
+    grouped = ds.groupby(ds.id)
+
+    # non reduction operation
+    expected = ds.copy()
+    expected.variable.values[0, 0, :] = np.nan
+    expected.variable.values[-1, -1, :] = np.nan
+    expected.variable.values[3, 0, :] = np.nan
+    actual = grouped.apply(lambda x: x).transpose(*ds.variable.dims)
+    assert_identical(actual, expected)
+
+    # reduction along grouped dimension
+    actual = grouped.mean()
+    stacked = ds.stack({"xy": ["lat", "lon"]})
+    expected = (
+        stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset()
+    )
+    expected["id"] = stacked.id.values
+    assert_identical(actual, expected.dropna("id").transpose(*actual.dims))
+
+    # reduction operation along a different dimension
+    actual = grouped.mean("time")
+    expected = ds.mean("time").where(ds.id.notnull())
+    assert_identical(actual, expected)
+
+    # NaN in non-dimensional coordinate
+    array = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])])
+    array["x1"] = ("x", [1, 1, np.nan])
+    expected = xr.DataArray(3, [("x1", [1])])
+    actual = array.groupby("x1").sum()
+    assert_equal(expected, actual)
+
+    # NaT in non-dimensional coordinate
+    array["t"] = (
+        "x",
+        [
+            np.datetime64("2001-01-01"),
+            np.datetime64("2001-01-01"),
+            np.datetime64("NaT"),
+        ],
+    )
+    expected = xr.DataArray(3, [("t", [np.datetime64("2001-01-01")])])
+    actual = array.groupby("t").sum()
+    assert_equal(expected, actual)
+
+    # test for repeated coordinate labels
+    array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])])
+    expected = xr.DataArray([3, 3], [("x", [1, 2])])
+    actual = array.groupby("x").sum()
+    assert_equal(expected, actual)
+
+
 def test_groupby_grouping_errors():
     dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})
     with raises_regex(ValueError, "None of the data falls within bins with edges"):