Skip to content

Commit f8fd295

Browse files
authored
More explicit index handling in dataset.py (#2816)
* More explicit index handling in dataset.py. The only part that's left is explicit index handling in merge. * Fix indexes in Dataset.diff
1 parent 9c0fb6c commit f8fd295

File tree

2 files changed

+75
-29
lines changed

2 files changed

+75
-29
lines changed

xarray/core/dataset.py

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
from .coordinates import (DatasetCoordinates, LevelCoordinatesSource,
2222
assert_coordinate_consistent, remap_label_indexers)
2323
from .duck_array_ops import datetime_to_numeric
24-
from .indexes import Indexes, default_indexes, isel_variable_and_index
25-
from .merge import (dataset_merge_method, dataset_update_method,
26-
merge_data_and_coords, merge_variables)
24+
from .indexes import (
25+
Indexes, default_indexes, isel_variable_and_index, roll_index,
26+
)
27+
from .merge import (
28+
dataset_merge_method, dataset_update_method, merge_data_and_coords,
29+
merge_variables)
2730
from .options import OPTIONS, _get_keep_attrs
2831
from .pycompat import TYPE_CHECKING, dask_array_type
2932
from .utils import (Frozen, SortedKeysDict, _check_inplace,
@@ -2693,17 +2696,18 @@ def reorder_levels(self, dim_order=None, inplace=None,
26932696
inplace = _check_inplace(inplace)
26942697
dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs,
26952698
'reorder_levels')
2696-
replace_variables = {}
2699+
variables = self._variables.copy()
2700+
indexes = OrderedDict(self.indexes)
26972701
for dim, order in dim_order.items():
26982702
coord = self._variables[dim]
2699-
index = coord.to_index()
2703+
index = self.indexes[dim]
27002704
if not isinstance(index, pd.MultiIndex):
27012705
raise ValueError("coordinate %r has no MultiIndex" % dim)
2702-
replace_variables[dim] = IndexVariable(coord.dims,
2703-
index.reorder_levels(order))
2704-
variables = self._variables.copy()
2705-
variables.update(replace_variables)
2706-
return self._replace_vars_and_dims(variables, inplace=inplace)
2706+
new_index = index.reorder_levels(order)
2707+
variables[dim] = IndexVariable(coord.dims, new_index)
2708+
indexes[dim] = new_index
2709+
2710+
return self._replace(variables, indexes=indexes, inplace=inplace)
27072711

27082712
def _stack_once(self, dims, new_dim):
27092713
variables = OrderedDict()
@@ -2733,7 +2737,12 @@ def _stack_once(self, dims, new_dim):
27332737

27342738
coord_names = set(self._coord_names) - set(dims) | set([new_dim])
27352739

2736-
return self._replace_vars_and_dims(variables, coord_names)
2740+
indexes = OrderedDict((k, v) for k, v in self.indexes.items()
2741+
if k not in dims)
2742+
indexes[new_dim] = idx
2743+
2744+
return self._replace_with_new_dims(
2745+
variables, coord_names=coord_names, indexes=indexes)
27372746

27382747
def stack(self, dimensions=None, **dimensions_kwargs):
27392748
"""
@@ -2900,6 +2909,9 @@ def _unstack_once(self, dim):
29002909
new_dim_sizes = [lev.size for lev in index.levels]
29012910

29022911
variables = OrderedDict()
2912+
indexes = OrderedDict(
2913+
(k, v) for k, v in self.indexes.items() if k != dim)
2914+
29032915
for name, var in obj.variables.items():
29042916
if name != dim:
29052917
if dim in var.dims:
@@ -2910,10 +2922,12 @@ def _unstack_once(self, dim):
29102922

29112923
for name, lev in zip(new_dim_names, index.levels):
29122924
variables[name] = IndexVariable(name, lev)
2925+
indexes[name] = lev
29132926

29142927
coord_names = set(self._coord_names) - set([dim]) | set(new_dim_names)
29152928

2916-
return self._replace_vars_and_dims(variables, coord_names)
2929+
return self._replace_with_new_dims(
2930+
variables, coord_names=coord_names, indexes=indexes)
29172931

29182932
def unstack(self, dim=None):
29192933
"""
@@ -3098,7 +3112,10 @@ def _drop_vars(self, names, errors='raise'):
30983112
variables = OrderedDict((k, v) for k, v in self._variables.items()
30993113
if k not in drop)
31003114
coord_names = set(k for k in self._coord_names if k in variables)
3101-
return self._replace_vars_and_dims(variables, coord_names)
3115+
indexes = OrderedDict((k, v) for k, v in self.indexes.items()
3116+
if k not in drop)
3117+
return self._replace_with_new_dims(
3118+
variables, coord_names=coord_names, indexes=indexes)
31023119

31033120
def drop_dims(self, drop_dims, *, errors='raise'):
31043121
"""Drop dimensions and associated variables from this dataset.
@@ -3133,12 +3150,7 @@ def drop_dims(self, drop_dims, *, errors='raise'):
31333150

31343151
drop_vars = set(k for k, v in self._variables.items()
31353152
for d in v.dims if d in drop_dims)
3136-
3137-
variables = OrderedDict((k, v) for k, v in self._variables.items()
3138-
if k not in drop_vars)
3139-
coord_names = set(k for k in self._coord_names if k in variables)
3140-
3141-
return self._replace_with_new_dims(variables, coord_names)
3153+
return self._drop_vars(drop_vars)
31423154

31433155
def transpose(self, *dims):
31443156
"""Return a new Dataset object with all array dimensions transposed.
@@ -3457,8 +3469,11 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False,
34573469
**kwargs)
34583470

34593471
coord_names = set(k for k in self.coords if k in variables)
3472+
indexes = OrderedDict((k, v) for k, v in self.indexes.items()
3473+
if k in variables)
34603474
attrs = self.attrs if keep_attrs else None
3461-
return self._replace_vars_and_dims(variables, coord_names, attrs=attrs)
3475+
return self._replace_with_new_dims(
3476+
variables, coord_names=coord_names, attrs=attrs, indexes=indexes)
34623477

34633478
def apply(self, func, keep_attrs=None, args=(), **kwargs):
34643479
"""Apply a function over the data variables in this dataset.
@@ -3854,8 +3869,9 @@ def func(self, other):
38543869
other = other.reindex_like(self, copy=False)
38553870
g = ops.inplace_to_noninplace_op(f)
38563871
ds = self._calculate_binary_op(g, other, inplace=True)
3857-
self._replace_vars_and_dims(ds._variables, ds._coord_names,
3858-
attrs=ds._attrs, inplace=True)
3872+
self._replace_with_new_dims(ds._variables, ds._coord_names,
3873+
attrs=ds._attrs, indexes=ds._indexes,
3874+
inplace=True)
38593875
return self
38603876

38613877
return func
@@ -3980,7 +3996,11 @@ def diff(self, dim, n=1, label='upper'):
39803996
else:
39813997
variables[name] = var
39823998

3983-
difference = self._replace_vars_and_dims(variables)
3999+
indexes = OrderedDict(self.indexes)
4000+
if dim in indexes:
4001+
indexes[dim] = indexes[dim][kwargs_new[dim]]
4002+
4003+
difference = self._replace_with_new_dims(variables, indexes=indexes)
39844004

39854005
if n > 1:
39864006
return difference.diff(dim, n - 1)
@@ -4042,7 +4062,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
40424062
else:
40434063
variables[name] = var
40444064

4045-
return self._replace_vars_and_dims(variables)
4065+
return self._replace(variables)
40464066

40474067
def roll(self, shifts=None, roll_coords=None, **shifts_kwargs):
40484068
"""Roll this dataset by an offset along one or more dimensions.
@@ -4109,7 +4129,16 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs):
41094129
else:
41104130
variables[k] = v
41114131

4112-
return self._replace_vars_and_dims(variables)
4132+
if roll_coords:
4133+
indexes = OrderedDict()
4134+
for k, v in self.indexes.items():
4135+
(dim,) = self.variables[k].dims
4136+
if dim in shifts:
4137+
indexes[k] = roll_index(v, shifts[dim])
4138+
else:
4139+
indexes = OrderedDict(self.indexes)
4140+
4141+
return self._replace(variables, indexes=indexes)
41134142

41144143
def sortby(self, variables, ascending=True):
41154144
"""
@@ -4251,10 +4280,14 @@ def quantile(self, q, dim=None, interpolation='linear',
42514280

42524281
# construct the new dataset
42534282
coord_names = set(k for k in self.coords if k in variables)
4283+
indexes = OrderedDict(
4284+
(k, v) for k, v in self.indexes.items() if k in variables
4285+
)
42544286
if keep_attrs is None:
42554287
keep_attrs = _get_keep_attrs(default=False)
42564288
attrs = self.attrs if keep_attrs else None
4257-
new = self._replace_vars_and_dims(variables, coord_names, attrs=attrs)
4289+
new = self._replace_with_new_dims(
4290+
variables, coord_names=coord_names, attrs=attrs, indexes=indexes)
42584291
if 'quantile' in new.dims:
42594292
new.coords['quantile'] = Variable('quantile', q)
42604293
else:
@@ -4305,7 +4338,7 @@ def rank(self, dim, pct=False, keep_attrs=None):
43054338
if keep_attrs is None:
43064339
keep_attrs = _get_keep_attrs(default=False)
43074340
attrs = self.attrs if keep_attrs else None
4308-
return self._replace_vars_and_dims(variables, coord_names, attrs=attrs)
4341+
return self._replace(variables, coord_names, attrs=attrs)
43094342

43104343
def differentiate(self, coord, edge_order=1, datetime_unit=None):
43114344
""" Differentiate with the second order accurate central
@@ -4363,7 +4396,7 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None):
43634396
variables[k] = Variable(v.dims, grad)
43644397
else:
43654398
variables[k] = v
4366-
return self._replace_vars_and_dims(variables)
4399+
return self._replace(variables)
43674400

43684401
def integrate(self, coord, datetime_unit=None):
43694402
""" integrate the array with the trapezoidal rule.
@@ -4435,7 +4468,11 @@ def _integrate_one(self, coord, datetime_unit=None):
44354468
variables[k] = Variable(v_dims, integ)
44364469
else:
44374470
variables[k] = v
4438-
return self._replace_vars_and_dims(variables, coord_names=coord_names)
4471+
indexes = OrderedDict(
4472+
(k, v) for k, v in self.indexes.items() if k in variables
4473+
)
4474+
return self._replace_with_new_dims(
4475+
variables, coord_names=coord_names, indexes=indexes)
44394476

44404477
@property
44414478
def real(self):

xarray/core/indexes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,12 @@ def isel_variable_and_index(
8787
indexer = indexer.data
8888
new_index = index[indexer]
8989
return new_variable, new_index
90+
91+
92+
def roll_index(index: pd.Index, count: int, axis: int = 0) -> pd.Index:
    """Roll a pandas.Index.

    Entries pushed past the last position wrap around to the front, so
    the result always has the same length and the same entries as
    ``index``, only rotated.

    Parameters
    ----------
    index : pd.Index
        The index to rotate. It is not modified.
    count : int
        Number of positions to shift towards the end. Negative values
        and values larger than the index length are accepted; the shift
        is reduced modulo the length of ``index``.
    axis : int, optional
        Accepted for signature compatibility but not used by this
        implementation.

    Returns
    -------
    pd.Index
        The rotated index, or a full slice of ``index`` when there is
        nothing to rotate.
    """
    size = index.shape[0]
    if size == 0:
        # An empty index has nothing to rotate; returning early also
        # avoids ``count % 0`` raising ZeroDivisionError.
        return index[:]

    # Normalise so that 0 <= count < size (handles negative and
    # oversized shifts).
    count %= size
    if count == 0:
        return index[:]

    # The last ``count`` entries move to the front.
    return index[-count:].append(index[:-count])

0 commit comments

Comments
 (0)