diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 77a4728542f..a0ca5151b8d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -21,9 +21,12 @@ from .coordinates import (DatasetCoordinates, LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers) from .duck_array_ops import datetime_to_numeric -from .indexes import Indexes, default_indexes, isel_variable_and_index -from .merge import (dataset_merge_method, dataset_update_method, - merge_data_and_coords, merge_variables) +from .indexes import ( + Indexes, default_indexes, isel_variable_and_index, roll_index, +) +from .merge import ( + dataset_merge_method, dataset_update_method, merge_data_and_coords, + merge_variables) from .options import OPTIONS, _get_keep_attrs from .pycompat import TYPE_CHECKING, dask_array_type from .utils import (Frozen, SortedKeysDict, _check_inplace, @@ -2693,17 +2696,18 @@ def reorder_levels(self, dim_order=None, inplace=None, inplace = _check_inplace(inplace) dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, 'reorder_levels') - replace_variables = {} + variables = self._variables.copy() + indexes = OrderedDict(self.indexes) for dim, order in dim_order.items(): coord = self._variables[dim] - index = coord.to_index() + index = self.indexes[dim] if not isinstance(index, pd.MultiIndex): raise ValueError("coordinate %r has no MultiIndex" % dim) - replace_variables[dim] = IndexVariable(coord.dims, - index.reorder_levels(order)) - variables = self._variables.copy() - variables.update(replace_variables) - return self._replace_vars_and_dims(variables, inplace=inplace) + new_index = index.reorder_levels(order) + variables[dim] = IndexVariable(coord.dims, new_index) + indexes[dim] = new_index + + return self._replace(variables, indexes=indexes, inplace=inplace) def _stack_once(self, dims, new_dim): variables = OrderedDict() @@ -2733,7 +2737,12 @@ def _stack_once(self, dims, new_dim): coord_names = set(self._coord_names) - set(dims) | 
set([new_dim]) - return self._replace_vars_and_dims(variables, coord_names) + indexes = OrderedDict((k, v) for k, v in self.indexes.items() + if k not in dims) + indexes[new_dim] = idx + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes) def stack(self, dimensions=None, **dimensions_kwargs): """ @@ -2900,6 +2909,9 @@ def _unstack_once(self, dim): new_dim_sizes = [lev.size for lev in index.levels] variables = OrderedDict() + indexes = OrderedDict( + (k, v) for k, v in self.indexes.items() if k != dim) + for name, var in obj.variables.items(): if name != dim: if dim in var.dims: @@ -2910,10 +2922,12 @@ def _unstack_once(self, dim): for name, lev in zip(new_dim_names, index.levels): variables[name] = IndexVariable(name, lev) + indexes[name] = lev coord_names = set(self._coord_names) - set([dim]) | set(new_dim_names) - return self._replace_vars_and_dims(variables, coord_names) + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes) def unstack(self, dim=None): """ @@ -3098,7 +3112,10 @@ def _drop_vars(self, names, errors='raise'): variables = OrderedDict((k, v) for k, v in self._variables.items() if k not in drop) coord_names = set(k for k in self._coord_names if k in variables) - return self._replace_vars_and_dims(variables, coord_names) + indexes = OrderedDict((k, v) for k, v in self.indexes.items() + if k not in drop) + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes) def drop_dims(self, drop_dims, *, errors='raise'): """Drop dimensions and associated variables from this dataset. 
@@ -3133,12 +3150,7 @@ def drop_dims(self, drop_dims, *, errors='raise'): drop_vars = set(k for k, v in self._variables.items() for d in v.dims if d in drop_dims) - - variables = OrderedDict((k, v) for k, v in self._variables.items() - if k not in drop_vars) - coord_names = set(k for k in self._coord_names if k in variables) - - return self._replace_with_new_dims(variables, coord_names) + return self._drop_vars(drop_vars) def transpose(self, *dims): """Return a new Dataset object with all array dimensions transposed. @@ -3457,8 +3469,11 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, **kwargs) coord_names = set(k for k in self.coords if k in variables) + indexes = OrderedDict((k, v) for k, v in self.indexes.items() + if k in variables) attrs = self.attrs if keep_attrs else None - return self._replace_vars_and_dims(variables, coord_names, attrs=attrs) + return self._replace_with_new_dims( + variables, coord_names=coord_names, attrs=attrs, indexes=indexes) def apply(self, func, keep_attrs=None, args=(), **kwargs): """Apply a function over the data variables in this dataset. 
@@ -3854,8 +3869,9 @@ def func(self, other): other = other.reindex_like(self, copy=False) g = ops.inplace_to_noninplace_op(f) ds = self._calculate_binary_op(g, other, inplace=True) - self._replace_vars_and_dims(ds._variables, ds._coord_names, - attrs=ds._attrs, inplace=True) + self._replace_with_new_dims(ds._variables, ds._coord_names, + attrs=ds._attrs, indexes=ds._indexes, + inplace=True) return self return func @@ -3980,7 +3996,11 @@ def diff(self, dim, n=1, label='upper'): else: variables[name] = var - difference = self._replace_vars_and_dims(variables) + indexes = OrderedDict(self.indexes) + if dim in indexes: + indexes[dim] = indexes[dim][kwargs_new[dim]] + + difference = self._replace_with_new_dims(variables, indexes=indexes) if n > 1: return difference.diff(dim, n - 1) @@ -4042,7 +4062,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): else: variables[name] = var - return self._replace_vars_and_dims(variables) + return self._replace(variables) def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): """Roll this dataset by an offset along one or more dimensions. 
@@ -4109,7 +4129,16 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): else: variables[k] = v - return self._replace_vars_and_dims(variables) + if roll_coords: + indexes = OrderedDict() + for k, v in self.indexes.items(): + (dim,) = self.variables[k].dims + if dim in shifts: + indexes[k] = roll_index(v, shifts[dim]) + else: + indexes = OrderedDict(self.indexes) + + return self._replace(variables, indexes=indexes) def sortby(self, variables, ascending=True): """ @@ -4251,10 +4280,14 @@ def quantile(self, q, dim=None, interpolation='linear', # construct the new dataset coord_names = set(k for k in self.coords if k in variables) + indexes = OrderedDict( + (k, v) for k, v in self.indexes.items() if k in variables + ) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None - new = self._replace_vars_and_dims(variables, coord_names, attrs=attrs) + new = self._replace_with_new_dims( + variables, coord_names=coord_names, attrs=attrs, indexes=indexes) if 'quantile' in new.dims: new.coords['quantile'] = Variable('quantile', q) else: @@ -4305,7 +4338,7 @@ def rank(self, dim, pct=False, keep_attrs=None): if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None - return self._replace_vars_and_dims(variables, coord_names, attrs=attrs) + return self._replace(variables, coord_names, attrs=attrs) def differentiate(self, coord, edge_order=1, datetime_unit=None): """ Differentiate with the second order accurate central @@ -4363,7 +4396,7 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None): variables[k] = Variable(v.dims, grad) else: variables[k] = v - return self._replace_vars_and_dims(variables) + return self._replace(variables) def integrate(self, coord, datetime_unit=None): """ integrate the array with the trapezoidal rule. 
def roll_index(index: pd.Index, count: int, axis: int = 0) -> pd.Index:
    """Roll a pandas.Index by ``count`` positions, wrapping around.

    Elements shifted past the last position re-appear at the front, so
    rolling ``[1, 2, 3, 4]`` by 1 gives ``[4, 1, 2, 3]``.

    Parameters
    ----------
    index : pd.Index
        Index to roll.
    count : int
        Number of places to shift by. May be negative or larger than the
        length of the index; it is reduced modulo the index length.
    axis : int, optional
        Not used by this implementation.

    Returns
    -------
    pd.Index
        A new index holding the same elements in rolled order.
    """
    size = index.shape[0]
    if size == 0:
        # Nothing to roll; ``count % 0`` would raise ZeroDivisionError.
        return index[:]

    # Normalise so that 0 <= count < size (also handles negative shifts).
    count %= size
    if count != 0:
        return index[-count:].append(index[:-count])
    return index[:]