Skip to content

Commit 61c5e43

Browse files
committed
Add some extra tests
1 parent 5f86df0 commit 61c5e43

File tree

2 files changed

+61
-23
lines changed

2 files changed

+61
-23
lines changed

xarray/core/variable.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,19 +2084,33 @@ def rolling_window(
20842084
var = self
20852085

20862086
if utils.is_scalar(dim):
2087-
for arg in [window, window_dim, center]:
2088-
assert utils.is_scalar(arg)
2087+
for name, arg in zip(
2088+
["window", "window_dim", "center"], [window, window_dim, center]
2089+
):
2090+
if not utils.is_scalar(arg):
2091+
raise ValueError(
2092+
f"Expected {name}={arg!r} to be a scalar like 'dim'."
2093+
)
20892094
dim = [dim]
2090-
window = [window]
2091-
window_dim = [window_dim]
2092-
center = [center]
2093-
else:
2094-
if len(dim) != len(window):
2095-
raise ValueError(
2096-
"'dim', 'window', 'window_dim', and 'center' must be the same length. "
2097-
f"Received dim={dim!r}, window={window!r}, window_dim={window_dim!r},"
2098-
f" and center={center!r}."
2099-
)
2095+
2096+
# dim is now a list
2097+
nroll = len(dim)
2098+
if utils.is_scalar(window):
2099+
window = [window] * nroll
2100+
if utils.is_scalar(window_dim):
2101+
window_dim = [window_dim] * nroll
2102+
if utils.is_scalar(center):
2103+
center = [center] * nroll
2104+
if (
2105+
len(dim) != len(window)
2106+
or len(dim) != len(window_dim)
2107+
or len(dim) != len(center)
2108+
):
2109+
raise ValueError(
2110+
"'dim', 'window', 'window_dim', and 'center' must be the same length. "
2111+
f"Received dim={dim!r}, window={window!r}, window_dim={window_dim!r},"
2112+
f" and center={center!r}."
2113+
)
21002114

21012115
pads = {}
21022116
for d, win, cent in zip(dim, window, center):

xarray/tests/test_variable.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,27 @@ def test_nd_rolling(self, center, dims):
948948
)
949949
assert_equal(actual, expected)
950950

951+
@pytest.mark.parametrize(
952+
("dim, window, window_dim, center"),
953+
[
954+
("x", [3, 3], "x_w", True),
955+
("x", 3, ("x_w", "x_w"), True),
956+
("x", 3, "x_w", [True, True]),
957+
],
958+
)
959+
def test_rolling_window_errors(self, dim, window, window_dim, center):
960+
x = self.cls(
961+
("x", "y", "z"),
962+
np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float),
963+
)
964+
with pytest.raises(ValueError):
965+
x.rolling_window(
966+
dim=dim,
967+
window=window,
968+
window_dim=window_dim,
969+
center=center,
970+
)
971+
951972

952973
class TestVariable(VariableSubclassobjects):
953974
cls = staticmethod(Variable)
@@ -2198,23 +2219,23 @@ def test_datetime64(self):
21982219

21992220
# These tests make use of multi-dimensional variables, which are not valid
22002221
# IndexVariable objects:
2201-
@pytest.mark.xfail
2222+
@pytest.mark.skip
22022223
def test_getitem_error(self):
22032224
super().test_getitem_error()
22042225

2205-
@pytest.mark.xfail
2226+
@pytest.mark.skip
22062227
def test_getitem_advanced(self):
22072228
super().test_getitem_advanced()
22082229

2209-
@pytest.mark.xfail
2230+
@pytest.mark.skip
22102231
def test_getitem_fancy(self):
22112232
super().test_getitem_fancy()
22122233

2213-
@pytest.mark.xfail
2234+
@pytest.mark.skip
22142235
def test_getitem_uint(self):
22152236
super().test_getitem_fancy()
22162237

2217-
@pytest.mark.xfail
2238+
@pytest.mark.skip
22182239
@pytest.mark.parametrize(
22192240
"mode",
22202241
[
@@ -2233,24 +2254,27 @@ def test_getitem_uint(self):
22332254
def test_pad(self, mode, xr_arg, np_arg):
22342255
super().test_pad(mode, xr_arg, np_arg)
22352256

2236-
@pytest.mark.xfail
2237-
@pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS)
2257+
@pytest.mark.skip
22382258
def test_pad_constant_values(self, xr_arg, np_arg):
22392259
super().test_pad_constant_values(xr_arg, np_arg)
22402260

2241-
@pytest.mark.xfail
2261+
@pytest.mark.skip
22422262
def test_rolling_window(self):
22432263
super().test_rolling_window()
22442264

2245-
@pytest.mark.xfail
2265+
@pytest.mark.skip
22462266
def test_rolling_1d(self):
22472267
super().test_rolling_1d()
22482268

2249-
@pytest.mark.xfail
2269+
@pytest.mark.skip
22502270
def test_nd_rolling(self):
22512271
super().test_nd_rolling()
22522272

2253-
@pytest.mark.xfail
2273+
@pytest.mark.skip
2274+
def test_rolling_window_errors(self):
2275+
super().test_rolling_window_errors()
2276+
2277+
@pytest.mark.skip
22542278
def test_coarsen_2d(self):
22552279
super().test_coarsen_2d()
22562280

0 commit comments

Comments
 (0)