Skip to content

Commit c3a75a0

Browse files
committed
Fix iteration on rolling windows to handle no padding
Also changes the object returned during iteration, so that it is the same (modulo transposition) as the view created during rolling_window.construct() along the rolling axis.
1 parent 462e0e1 commit c3a75a0

File tree

2 files changed

+59
-24
lines changed

2 files changed

+59
-24
lines changed

xarray/core/rolling.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,47 @@ def __init__(
263263
def __iter__(self):
264264
if len(self.dim) > 1:
265265
raise ValueError("__iter__ is only supported for 1d-rolling")
266-
stops = np.arange(1, len(self.window_labels) + 1)
267-
starts = stops - int(self.window[0])
268-
starts[: int(self.window[0])] = 0
269-
for (label, start, stop) in zip(self.window_labels, starts, stops):
270-
window = self.obj.isel(**{self.dim[0]: slice(start, stop)})
266+
dim = self.dim[0]
267+
center = self.center[0]
268+
pad = self.pad[0]
269+
window = self.window[0]
270+
center_offset = window // 2 if center else 0
271+
272+
pads = utils.get_pads(self.dim, self.window, self.center, self.pad)
273+
start_pad, end_pad = pads[dim]
274+
275+
# Select the proper subset of labels, based on whether or not to center and/or pad
276+
first_label_idx = 0 if pad else center_offset if center else window - 1
277+
last_label_idx = (
278+
len(self.obj[dim])
279+
if pad or not center
280+
else len(self.obj[dim]) - center_offset
281+
)
282+
283+
labels = (
284+
self.obj[dim][slice(first_label_idx, last_label_idx)]
285+
if self.obj[dim].coords
286+
else np.arange(last_label_idx - first_label_idx)
287+
)
288+
289+
padded_obj = self.obj.pad(pads, mode="constant", constant_values=dtypes.NA)
290+
291+
if pad and not center:
292+
first_stop = 1
293+
last_stop = len(self.obj[dim])
294+
elif pad and center:
295+
first_stop = end_pad + 1
296+
last_stop = len(self.obj[dim]) + end_pad
297+
elif not pad:
298+
first_stop = window
299+
last_stop = len(self.obj[dim])
300+
301+
# These are indices into the padded array, so we need to add start_pad
302+
stops = np.arange(first_stop, last_stop + 1) + start_pad
303+
starts = stops - window
304+
305+
for (label, start, stop) in zip(labels, starts, stops):
306+
window = padded_obj.isel({self.dim[0]: slice(start, stop)})
271307

272308
counts = window.count(dim=self.dim[0])
273309
window = window.where(counts >= self.min_periods)
@@ -486,15 +522,17 @@ def _counts(self, keep_attrs):
486522
# array is faster to be reduced than object array.
487523
# The use of skipna==False is also faster since it does not need to
488524
# copy the strided array.
525+
output_dim_coords = self._get_rolling_dim_coords()
489526
counts = (
490527
self.obj.notnull(keep_attrs=keep_attrs)
491528
.rolling(
492529
center={d: self.center[i] for i, d in enumerate(self.dim)},
530+
pad={p: self.pad[i] for i, p in enumerate(self.pad)},
493531
**{d: w for d, w in zip(self.dim, self.window)},
494532
)
495533
.construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs)
496534
.sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs)
497-
)
535+
).sel(output_dim_coords)
498536
return counts
499537

500538
def _bottleneck_reduce(self, func, keep_attrs, **kwargs):

xarray/tests/test_dataarray.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6506,23 +6506,16 @@ def test_isin(da):
65066506
@pytest.mark.parametrize("da", (1, 2), indirect=True)
65076507
@pytest.mark.parametrize("center", (True, False, None))
65086508
@pytest.mark.parametrize("pad", (True, False, None))
6509-
def test_rolling_iter(da, center, pad):
6510-
rolling_obj = da.rolling(time=7, center=center, pad=pad)
6509+
@pytest.mark.parametrize("min_periods", (1, 6, None))
6510+
@pytest.mark.parametrize("window", (6, 7))
6511+
def test_rolling_iter(da, center, pad, min_periods, window):
6512+
rolling_obj = da.rolling(
6513+
time=window, center=center, pad=pad, min_periods=min_periods
6514+
)
65116515
rolling_obj_mean = rolling_obj.mean()
65126516

6513-
if pad:
6514-
expected_times = da["time"]
6515-
else:
6516-
if center:
6517-
expected_times = da["time"][slice(3, -3)]
6518-
else:
6519-
expected_times = da["time"][slice(6, None)]
6520-
6521-
assert len(rolling_obj.window_labels) == len(expected_times)
6522-
assert_identical(rolling_obj.window_labels, expected_times)
6523-
65246517
for i, (label, window_da) in enumerate(rolling_obj):
6525-
assert label == da["time"].isel(time=i)
6518+
assert label == rolling_obj_mean["time"].isel(time=i)
65266519

65276520
actual = rolling_obj_mean.isel(time=i)
65286521
expected = window_da.mean("time")
@@ -6531,10 +6524,14 @@ def test_rolling_iter(da, center, pad):
65316524
# as well as the closeness of the values.
65326525
assert_array_equal(actual.isnull(), expected.isnull())
65336526
if (~actual.isnull()).sum() > 0:
6534-
np.allclose(
6535-
actual.values[actual.values.nonzero()],
6536-
expected.values[expected.values.nonzero()],
6537-
)
6527+
if actual.ndim == 0:
6528+
actual_values = actual.values
6529+
expected_values = expected.values
6530+
else:
6531+
actual_values = actual.values[actual.values.nonzero()]
6532+
expected_values = expected.values[expected.values.nonzero()]
6533+
6534+
assert np.allclose(actual_values, expected_values)
65386535

65396536

65406537
@pytest.mark.parametrize("da", (1,), indirect=True)

0 commit comments

Comments
 (0)