Skip to content

Commit 145f25f

Browse files
authored
More consistency checks (#2859)
* Enable additional invariant checks in xarray's test suite
* Tweak internal consistency checks
* Various small fixes
* Always use internal invariant checks
* Fix coordinates type from DataArray.transpose()
1 parent 4c758e6 commit 145f25f

File tree

7 files changed

+143
-52
lines changed

7 files changed

+143
-52
lines changed

xarray/core/coordinates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _update_coords(self, coords):
193193

194194
self._data._variables = variables
195195
self._data._coord_names.update(new_coord_names)
196-
self._data._dims = dict(dims)
196+
self._data._dims = dims
197197
self._data._indexes = None
198198

199199
def __delitem__(self, key):

xarray/core/dataarray.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import warnings
44
from collections import OrderedDict
5+
from typing import Any
56

67
import numpy as np
78
import pandas as pd
@@ -67,7 +68,7 @@ def _infer_coords_and_dims(shape, coords, dims):
6768
for dim, coord in zip(dims, coords):
6869
var = as_variable(coord, name=dim)
6970
var.dims = (dim,)
70-
new_coords[dim] = var
71+
new_coords[dim] = var.to_index_variable()
7172

7273
sizes = dict(zip(dims, shape))
7374
for k, v in new_coords.items():
@@ -1442,7 +1443,7 @@ def transpose(self, *dims, transpose_coords=None) -> 'DataArray':
14421443

14431444
variable = self.variable.transpose(*dims)
14441445
if transpose_coords:
1445-
coords = {}
1446+
coords = OrderedDict() # type: OrderedDict[Any, Variable]
14461447
for name, coord in self.coords.items():
14471448
coord_dims = tuple(dim for dim in dims if dim in coord.dims)
14481449
coords[name] = coord.variable.transpose(*coord_dims)

xarray/core/dataset.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def calculate_dimensions(variables):
100100
Returns dictionary mapping from dimension names to sizes. Raises ValueError
101101
if any of the dimension sizes conflict.
102102
"""
103-
dims = OrderedDict()
103+
dims = {}
104104
last_used = {}
105105
scalar_vars = set(k for k, v in variables.items() if not v.dims)
106106
for k, var in variables.items():
@@ -692,7 +692,7 @@ def _construct_direct(cls, variables, coord_names, dims, attrs=None,
692692

693693
@classmethod
694694
def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
695-
dims = dict(calculate_dimensions(variables))
695+
dims = calculate_dimensions(variables)
696696
return cls._construct_direct(variables, coord_names, dims, attrs)
697697

698698
# TODO(shoyer): renable type checking on this signature when pytype has a
@@ -753,18 +753,20 @@ def _replace_with_new_dims( # type: ignore
753753
coord_names: set = None,
754754
attrs: 'Optional[OrderedDict]' = __default,
755755
indexes: 'Optional[OrderedDict[Any, pd.Index]]' = __default,
756+
encoding: Optional[dict] = __default,
756757
inplace: bool = False,
757758
) -> T:
758759
"""Replace variables with recalculated dimensions."""
759-
dims = dict(calculate_dimensions(variables))
760+
dims = calculate_dimensions(variables)
760761
return self._replace(
761-
variables, coord_names, dims, attrs, indexes, inplace=inplace)
762+
variables, coord_names, dims, attrs, indexes, encoding,
763+
inplace=inplace)
762764

763765
def _replace_vars_and_dims( # type: ignore
764766
self: T,
765767
variables: 'OrderedDict[Any, Variable]' = None,
766768
coord_names: set = None,
767-
dims: 'OrderedDict[Any, int]' = None,
769+
dims: Dict[Any, int] = None,
768770
attrs: 'Optional[OrderedDict]' = __default,
769771
inplace: bool = False,
770772
) -> T:
@@ -1080,6 +1082,7 @@ def __delitem__(self, key):
10801082
"""
10811083
del self._variables[key]
10821084
self._coord_names.discard(key)
1085+
self._dims = calculate_dimensions(self._variables)
10831086

10841087
# mutable objects should not be hashable
10851088
# https://github.com/python/mypy/issues/4266
@@ -2469,7 +2472,7 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs):
24692472
else:
24702473
# If dims includes a label of a non-dimension coordinate,
24712474
# it will be promoted to a 1D coordinate with a single value.
2472-
variables[k] = v.set_dims(k)
2475+
variables[k] = v.set_dims(k).to_index_variable()
24732476

24742477
new_dims = self._dims.copy()
24752478
new_dims.update(dim)
@@ -3556,12 +3559,15 @@ def from_dict(cls, d):
35563559
def _unary_op(f, keep_attrs=False):
35573560
@functools.wraps(f)
35583561
def func(self, *args, **kwargs):
3559-
ds = self.coords.to_dataset()
3560-
for k in self.data_vars:
3561-
ds._variables[k] = f(self._variables[k], *args, **kwargs)
3562-
if keep_attrs:
3563-
ds._attrs = self._attrs
3564-
return ds
3562+
variables = OrderedDict()
3563+
for k, v in self._variables.items():
3564+
if k in self._coord_names:
3565+
variables[k] = v
3566+
else:
3567+
variables[k] = f(v, *args, **kwargs)
3568+
attrs = self._attrs if keep_attrs else None
3569+
return self._replace_with_new_dims(
3570+
variables, attrs=attrs, encoding=None)
35653571

35663572
return func
35673573

xarray/core/merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def merge_core(objs,
473473
'coordinates or not in the merged result: %s'
474474
% ambiguous_coords)
475475

476-
return variables, coord_names, dict(dims)
476+
return variables, coord_names, dims
477477

478478

479479
def merge(objects, compat='no_conflicts', join='outer', fill_value=dtypes.NA):

xarray/testing.py

Lines changed: 109 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
"""Testing functions exposed to the user API"""
22
from collections import OrderedDict
3+
from typing import Hashable, Union
34

45
import numpy as np
56
import pandas as pd
67

7-
from xarray.core import duck_array_ops, formatting
8+
from xarray.core import duck_array_ops
9+
from xarray.core import formatting
10+
from xarray.core.dataarray import DataArray
11+
from xarray.core.dataset import Dataset
12+
from xarray.core.variable import IndexVariable, Variable
813
from xarray.core.indexes import default_indexes
914

1015

@@ -48,12 +53,11 @@ def assert_equal(a, b):
4853
assert_identical, assert_allclose, Dataset.equals, DataArray.equals,
4954
numpy.testing.assert_array_equal
5055
"""
51-
import xarray as xr
5256
__tracebackhide__ = True # noqa: F841
5357
assert type(a) == type(b) # noqa
54-
if isinstance(a, (xr.Variable, xr.DataArray)):
58+
if isinstance(a, (Variable, DataArray)):
5559
assert a.equals(b), formatting.diff_array_repr(a, b, 'equals')
56-
elif isinstance(a, xr.Dataset):
60+
elif isinstance(a, Dataset):
5761
assert a.equals(b), formatting.diff_dataset_repr(a, b, 'equals')
5862
else:
5963
raise TypeError('{} not supported by assertion comparison'
@@ -77,15 +81,14 @@ def assert_identical(a, b):
7781
--------
7882
assert_equal, assert_allclose, Dataset.equals, DataArray.equals
7983
"""
80-
import xarray as xr
8184
__tracebackhide__ = True # noqa: F841
8285
assert type(a) == type(b) # noqa
83-
if isinstance(a, xr.Variable):
86+
if isinstance(a, Variable):
8487
assert a.identical(b), formatting.diff_array_repr(a, b, 'identical')
85-
elif isinstance(a, xr.DataArray):
88+
elif isinstance(a, DataArray):
8689
assert a.name == b.name
8790
assert a.identical(b), formatting.diff_array_repr(a, b, 'identical')
88-
elif isinstance(a, (xr.Dataset, xr.Variable)):
91+
elif isinstance(a, (Dataset, Variable)):
8992
assert a.identical(b), formatting.diff_dataset_repr(a, b, 'identical')
9093
else:
9194
raise TypeError('{} not supported by assertion comparison'
@@ -117,15 +120,14 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
117120
--------
118121
assert_identical, assert_equal, numpy.testing.assert_allclose
119122
"""
120-
import xarray as xr
121123
__tracebackhide__ = True # noqa: F841
122124
assert type(a) == type(b) # noqa
123125
kwargs = dict(rtol=rtol, atol=atol, decode_bytes=decode_bytes)
124-
if isinstance(a, xr.Variable):
126+
if isinstance(a, Variable):
125127
assert a.dims == b.dims
126128
allclose = _data_allclose_or_equiv(a.values, b.values, **kwargs)
127129
assert allclose, '{}\n{}'.format(a.values, b.values)
128-
elif isinstance(a, xr.DataArray):
130+
elif isinstance(a, DataArray):
129131
assert_allclose(a.variable, b.variable, **kwargs)
130132
assert set(a.coords) == set(b.coords)
131133
for v in a.coords.variables:
@@ -135,7 +137,7 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
135137
b.coords[v].values, **kwargs)
136138
assert allclose, '{}\n{}'.format(a.coords[v].values,
137139
b.coords[v].values)
138-
elif isinstance(a, xr.Dataset):
140+
elif isinstance(a, Dataset):
139141
assert set(a.data_vars) == set(b.data_vars)
140142
assert set(a.coords) == set(b.coords)
141143
for k in list(a.variables) + list(a.coords):
@@ -147,14 +149,12 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
147149

148150

149151
def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims):
150-
import xarray as xr
151-
152152
assert isinstance(indexes, OrderedDict), indexes
153153
assert all(isinstance(v, pd.Index) for v in indexes.values()), \
154154
{k: type(v) for k, v in indexes.items()}
155155

156156
index_vars = {k for k, v in possible_coord_variables.items()
157-
if isinstance(v, xr.IndexVariable)}
157+
if isinstance(v, IndexVariable)}
158158
assert indexes.keys() <= index_vars, (set(indexes), index_vars)
159159

160160
# Note: when we support non-default indexes, these checks should be opt-in
@@ -166,17 +166,97 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims):
166166
(indexes, defaults)
167167

168168

169-
def _assert_indexes_invariants(a):
170-
"""Separate helper function for checking indexes invariants only."""
171-
import xarray as xr
172-
173-
if isinstance(a, xr.DataArray):
174-
if a._indexes is not None:
175-
_assert_indexes_invariants_checks(a._indexes, a._coords, a.dims)
176-
elif isinstance(a, xr.Dataset):
177-
if a._indexes is not None:
178-
_assert_indexes_invariants_checks(
179-
a._indexes, a._variables, a._dims)
180-
elif isinstance(a, xr.Variable):
181-
# no indexes
182-
pass
169+
def _assert_variable_invariants(var: Variable, name: Hashable = None):
170+
if name is None:
171+
name_or_empty = () # type: tuple
172+
else:
173+
name_or_empty = (name,)
174+
assert isinstance(var._dims, tuple), name_or_empty + (var._dims,)
175+
assert len(var._dims) == len(var._data.shape), \
176+
name_or_empty + (var._dims, var._data.shape)
177+
assert isinstance(var._encoding, (type(None), dict)), \
178+
name_or_empty + (var._encoding,)
179+
assert isinstance(var._attrs, (type(None), OrderedDict)), \
180+
name_or_empty + (var._attrs,)
181+
182+
183+
def _assert_dataarray_invariants(da: DataArray):
184+
assert isinstance(da._variable, Variable), da._variable
185+
_assert_variable_invariants(da._variable)
186+
187+
assert isinstance(da._coords, OrderedDict), da._coords
188+
assert all(
189+
isinstance(v, Variable) for v in da._coords.values()), da._coords
190+
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), \
191+
(da.dims, {k: v.dims for k, v in da._coords.items()})
192+
assert all(isinstance(v, IndexVariable)
193+
for (k, v) in da._coords.items()
194+
if v.dims == (k,)), \
195+
{k: type(v) for k, v in da._coords.items()}
196+
for k, v in da._coords.items():
197+
_assert_variable_invariants(v, k)
198+
199+
if da._indexes is not None:
200+
_assert_indexes_invariants_checks(da._indexes, da._coords, da.dims)
201+
202+
assert da._initialized is True
203+
204+
205+
def _assert_dataset_invariants(ds: Dataset):
206+
assert isinstance(ds._variables, OrderedDict), type(ds._variables)
207+
assert all(
208+
isinstance(v, Variable) for v in ds._variables.values()), \
209+
ds._variables
210+
for k, v in ds._variables.items():
211+
_assert_variable_invariants(v, k)
212+
213+
assert isinstance(ds._coord_names, set), ds._coord_names
214+
assert ds._coord_names <= ds._variables.keys(), \
215+
(ds._coord_names, set(ds._variables))
216+
217+
assert type(ds._dims) is dict, ds._dims
218+
assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims
219+
var_dims = set() # type: set
220+
for v in ds._variables.values():
221+
var_dims.update(v.dims)
222+
assert ds._dims.keys() == var_dims, (set(ds._dims), var_dims)
223+
assert all(ds._dims[k] == v.sizes[k]
224+
for v in ds._variables.values()
225+
for k in v.sizes), \
226+
(ds._dims, {k: v.sizes for k, v in ds._variables.items()})
227+
assert all(isinstance(v, IndexVariable)
228+
for (k, v) in ds._variables.items()
229+
if v.dims == (k,)), \
230+
{k: type(v) for k, v in ds._variables.items() if v.dims == (k,)}
231+
assert all(v.dims == (k,)
232+
for (k, v) in ds._variables.items()
233+
if k in ds._dims), \
234+
{k: v.dims for k, v in ds._variables.items() if k in ds._dims}
235+
236+
if ds._indexes is not None:
237+
_assert_indexes_invariants_checks(ds._indexes, ds._variables, ds._dims)
238+
239+
assert isinstance(ds._encoding, (type(None), dict))
240+
assert isinstance(ds._attrs, (type(None), OrderedDict))
241+
assert ds._initialized is True
242+
243+
244+
def _assert_internal_invariants(
245+
xarray_obj: Union[DataArray, Dataset, Variable],
246+
):
247+
"""Validate that an xarray object satisfies its own internal invariants.
248+
249+
This exists for the benefit of xarray's own test suite, but may be useful
250+
in external projects if they (ill-advisedly) create objects using xarray's
251+
private APIs.
252+
"""
253+
if isinstance(xarray_obj, Variable):
254+
_assert_variable_invariants(xarray_obj)
255+
elif isinstance(xarray_obj, DataArray):
256+
_assert_dataarray_invariants(xarray_obj)
257+
elif isinstance(xarray_obj, Dataset):
258+
_assert_dataset_invariants(xarray_obj)
259+
else:
260+
raise TypeError(
261+
'{} is not a supported type for xarray invariant checks'
262+
.format(type(xarray_obj)))

xarray/tests/__init__.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,21 +168,20 @@ def source_ndarray(array):
168168

169169
# Internal versions of xarray's test functions that validate additional
170170
# invariants
171-
# TODO: add more invariant checks.
172171

173172
def assert_equal(a, b):
174173
xarray.testing.assert_equal(a, b)
175-
xarray.testing._assert_indexes_invariants(a)
176-
xarray.testing._assert_indexes_invariants(b)
174+
xarray.testing._assert_internal_invariants(a)
175+
xarray.testing._assert_internal_invariants(b)
177176

178177

179178
def assert_identical(a, b):
180179
xarray.testing.assert_identical(a, b)
181-
xarray.testing._assert_indexes_invariants(a)
182-
xarray.testing._assert_indexes_invariants(b)
180+
xarray.testing._assert_internal_invariants(a)
181+
xarray.testing._assert_internal_invariants(b)
183182

184183

185184
def assert_allclose(a, b, **kwargs):
186185
xarray.testing.assert_allclose(a, b, **kwargs)
187-
xarray.testing._assert_indexes_invariants(a)
188-
xarray.testing._assert_indexes_invariants(b)
186+
xarray.testing._assert_internal_invariants(a)
187+
xarray.testing._assert_internal_invariants(b)

xarray/tests/test_dataset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2752,6 +2752,11 @@ def test_delitem(self):
27522752
assert set(data.variables) == all_items - set(['var1', 'numbers'])
27532753
assert 'numbers' not in data.coords
27542754

2755+
expected = Dataset()
2756+
actual = Dataset({'y': ('x', [1, 2])})
2757+
del actual['y']
2758+
assert_identical(expected, actual)
2759+
27552760
def test_squeeze(self):
27562761
data = Dataset({'foo': (['x', 'y', 'z'], [[[1], [2]]])})
27572762
for args in [[], [['x']], [['x', 'z']]]:

0 commit comments

Comments
 (0)