From d9e4fd242cec742a3a818aeed9d752951d4f1472 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 7 May 2025 13:54:52 +0200 Subject: [PATCH 01/10] alignment of n-dimensional indexes vs. exclude dim Support checking exact alignment of indexes that use multiple dimensions in the cases where some of those dimensions are included in alignment while others are excluded. Added `exclude_dims` keyword argument to `Index.equals()` (and still support old signature with future warning). Also fixed bug: indexes associated with scalar coordinates were ignored during alignment. Added tests as well. --- xarray/core/indexes.py | 66 +++++++++++++++++++++++--- xarray/structure/alignment.py | 49 +++++++++++--------- xarray/tests/test_dataset.py | 87 +++++++++++++++++++++++++++++++++++ 3 files changed, 175 insertions(+), 27 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 8babb885a5e..f290ef2bfd1 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -2,9 +2,10 @@ import collections.abc import copy +import inspect from collections import defaultdict -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload import numpy as np import pandas as pd @@ -348,7 +349,15 @@ def reindex_like(self, other: Self) -> dict[Hashable, Any]: """ raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") - def equals(self, other: Index) -> bool: + @overload + def equals(self, other: Index) -> bool: ... + + @overload + def equals( + self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None + ) -> bool: ... + + def equals(self, other: Index, **kwargs) -> bool: """Compare this index with another index of the same type. Implementation is optional but required in order to support alignment. @@ -357,6 +366,16 @@ def equals(self, other: Index) -> bool: ---------- other : Index The other Index object to compare with this object. + exclude_dims : frozenset of hashable, optional + All the dimensions that are excluded from alignment, or None by default + (i.e., when this method is not called in the context of alignment). + For a n-dimensional index it allows ignoring any relevant dimension found + in ``exclude_dims`` when comparing this index with the other index. + For a 1-dimensional index it can be always safely ignored as this + method is not called when all of the index's dimensions are also excluded + from alignment + (note: the index's dimensions correspond to the union of the dimensions + of all coordinate variables associated with this index). Returns ------- @@ -863,7 +882,7 @@ def sel( return IndexSelResult({self.dim: indexer}) - def equals(self, other: Index): + def equals(self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None): if not isinstance(other, PandasIndex): return False return self.index.equals(other.index) and self.dim == other.dim @@ -1542,7 +1561,9 @@ def sel( return IndexSelResult(results) - def equals(self, other: Index) -> bool: + def equals( + self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None + ) -> bool: if not isinstance(other, CoordinateTransformIndex): return False return self.transform.equals(other.transform) @@ -1925,6 +1946,36 @@ def default_indexes( return indexes +def _wrap_index_equals( + index: Index, +) -> Callable[[Index, frozenset[Hashable]], bool]: + # TODO: remove this Index.equals() wrapper (backward compatibility) + + sig = inspect.signature(index.equals) + + if len(sig.parameters) == 1: + index_cls_name = type(index).__module__ + "." + type(index).__qualname__ + emit_user_level_warning( + f"the signature ``{index_cls_name}.equals(self, other)`` is deprecated. " + f"Please update it to " + f"``{index_cls_name}.equals(self, other, *, exclude_dims=None)`` " + "or kindly ask the maintainers doing it. " + "See documentation of xarray.Index.equals() for more info.", + FutureWarning, + ) + exclude_dims_kwarg = False + else: + exclude_dims_kwarg = True + + def equals_wrapper(other: Index, exclude_dims: frozenset[Hashable]) -> bool: + if exclude_dims_kwarg: + return index.equals(other, exclude_dims=exclude_dims) + else: + return index.equals(other) + + return equals_wrapper + + def indexes_equal( index: Index, other_index: Index, @@ -1966,6 +2017,7 @@ def indexes_equal( def indexes_all_equal( elements: Sequence[tuple[Index, dict[Hashable, Variable]]], + exclude_dims: frozenset[Hashable], ) -> bool: """Check if indexes are all equal. @@ -1990,9 +2042,11 @@ def check_variables(): same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:]) if same_type: + index_equals_func = _wrap_index_equals(indexes[0]) try: not_equal = any( - not indexes[0].equals(other_idx) for other_idx in indexes[1:] + not index_equals_func(other_idx, exclude_dims) + for other_idx in indexes[1:] ) except NotImplementedError: not_equal = check_variables() diff --git a/xarray/structure/alignment.py b/xarray/structure/alignment.py index ea90519143c..49d2709343e 100644 --- a/xarray/structure/alignment.py +++ b/xarray/structure/alignment.py @@ -216,30 +216,37 @@ def _normalize_indexes( normalized_indexes = {} normalized_index_vars = {} - for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index(): - coord_names_and_dims = [] - all_dims: set[Hashable] = set() - for name, var in index_vars.items(): + for idx, idx_vars in Indexes(xr_indexes, xr_variables).group_by_index(): + idx_coord_names_and_dims = [] + idx_all_dims: set[Hashable] = set() + + for name, var in idx_vars.items(): dims = var.dims - coord_names_and_dims.append((name, dims)) - all_dims.update(dims) - - exclude_dims = all_dims & self.exclude_dims - if exclude_dims == all_dims: - continue - elif exclude_dims: - excl_dims_str = ", ".join(str(d) for d in exclude_dims) - incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims) - raise AlignmentError( - f"cannot exclude dimension(s) {excl_dims_str} from alignment because " - "these are used by an index together with non-excluded dimensions " - f"{incl_dims_str}" - ) + idx_coord_names_and_dims.append((name, dims)) + idx_all_dims.update(dims) + + # We can ignore an index if all the dimensions it uses are also excluded + # from the alignment (do not ignore the index if it has no related dimension, i.e., + # it is associated with one or more scalar coordinates). + if idx_all_dims: + exclude_dims = idx_all_dims & self.exclude_dims + if exclude_dims == idx_all_dims: + continue + elif exclude_dims and self.join != "exact": + excl_dims_str = ", ".join(str(d) for d in exclude_dims) + incl_dims_str = ", ".join( + str(d) for d in idx_all_dims - exclude_dims + ) + raise AlignmentError( + f"cannot exclude dimension(s) {excl_dims_str} from non-exact alignment " + "because these are used by an index together with non-excluded dimensions " + f"{incl_dims_str}" + ) - key = (tuple(coord_names_and_dims), type(idx)) + key = (tuple(idx_coord_names_and_dims), type(idx)) normalized_indexes[key] = idx - normalized_index_vars[key] = index_vars + normalized_index_vars[key] = idx_vars return normalized_indexes, normalized_index_vars @@ -298,7 +305,7 @@ def _need_reindex(self, dim, cmp_indexes) -> bool: pandas). This is useful, e.g., for overwriting such duplicate indexes. """ - if not indexes_all_equal(cmp_indexes): + if not indexes_all_equal(cmp_indexes, self.exclude_dims): # always reindex when matching indexes are not equal return True diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ac186a7d351..c001b3b69fc 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2615,6 +2615,93 @@ def test_align_index_var_attrs(self, join) -> None: assert ds.x.attrs == {"units": "m"} assert ds_noattr.x.attrs == {} + def test_align_scalar_index(self) -> None: + # ensure that indexes associated with scalar coordinates are not ignored + # during alignment + class ScalarIndex(Index): + def __init__(self, value: int): + self.value = value + + @classmethod + def from_variables(cls, variables, *, options): + var = next(iter(variables.values())) + return cls(int(var.values)) + + def equals(self, other, *, exclude_dims=None): + return isinstance(other, ScalarIndex) and other.value == self.value + + ds1 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex) + ds2 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex) + + actual = xr.align(ds1, ds2, join="exact") + assert_identical(actual[0], ds1, check_default_indexes=False) + assert_identical(actual[1], ds2, check_default_indexes=False) + + ds3 = Dataset(coords={"x": 1}).set_xindex("x", ScalarIndex) + + with pytest.raises(AlignmentError, match="cannot align objects"): + xr.align(ds1, ds3, join="exact") + + def test_align_multi_dim_index_exclude_dims(self) -> None: + class XYIndex(Index): + def __init__(self, x: PandasIndex, y: PandasIndex): + self.x: PandasIndex = x + self.y: PandasIndex = y + + @classmethod + def from_variables(cls, variables, *, options): + return cls( + x=PandasIndex.from_variables( + {"x": variables["x"]}, options=options + ), + y=PandasIndex.from_variables( + {"y": variables["y"]}, options=options + ), + ) + + def equals(self, other, exclude_dims=None): + x_eq = True if self.x.dim in exclude_dims else self.x.equals(other.x) + y_eq = True if self.y.dim in exclude_dims else self.y.equals(other.y) + return x_eq and y_eq + + ds1 = ( + Dataset(coords={"x": [1, 2], "y": [3, 4]}) + .drop_indexes(["x", "y"]) + .set_xindex(["x", "y"], XYIndex) + ) + ds2 = ( + Dataset(coords={"x": [1, 2], "y": [5, 6]}) + .drop_indexes(["x", "y"]) + .set_xindex(["x", "y"], XYIndex) + ) + + actual = xr.align(ds1, ds2, join="exact", exclude="y") + assert_identical(actual[0], ds1, check_default_indexes=False) + assert_identical(actual[1], ds2, check_default_indexes=False) + + with pytest.raises( + AlignmentError, match="cannot align objects.*index.*not equal" + ): + xr.align(ds1, ds2, join="exact") + + with pytest.raises(AlignmentError, match="cannot exclude dimension"): + xr.align(ds1, ds2, join="outer", exclude="y") + + def test_align_index_equals_future_warning(self) -> None: + # TODO: remove this test once the deprecation cycle is completed + class DeprecatedEqualsSignatureIndex(PandasIndex): + def equals(self, other: Index) -> bool: # type: ignore[override] + return super().equals(other, exclude_dims=None) + + ds = ( + Dataset(coords={"x": [1, 2]}) + .drop_indexes("x") + .set_xindex("x", DeprecatedEqualsSignatureIndex) + ) + + with pytest.warns(FutureWarning, match="signature.*deprecated"): + xr.align(ds, ds.copy(), join="exact") + def test_broadcast(self) -> None: ds = Dataset( {"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])} From 42e33c20f786228f2f9fd07608d42fa660a9a5bd Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 12 May 2025 08:36:49 +0200 Subject: [PATCH 02/10] Update xarray/core/indexes.py Co-authored-by: Deepak Cherian --- xarray/core/indexes.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f290ef2bfd1..39dfe9e1ba1 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -367,14 +367,13 @@ def equals(self, other: Index, **kwargs) -> bool: other : Index The other Index object to compare with this object. exclude_dims : frozenset of hashable, optional - All the dimensions that are excluded from alignment, or None by default + Dimensions excluded from checking. It is None by default, (i.e., when this method is not called in the context of alignment). - For a n-dimensional index it allows ignoring any relevant dimension found - in ``exclude_dims`` when comparing this index with the other index. - For a 1-dimensional index it can be always safely ignored as this - method is not called when all of the index's dimensions are also excluded - from alignment - (note: the index's dimensions correspond to the union of the dimensions + For a n-dimensional index this option allows an Index to optionally, + ignore any dimension in ``exclude_dims`` when comparing + ``self`` with ``other``. For a 1-dimensional index this kwarg can be safely ignored , + as this method is not called when all of the index's dimensions are also excluded + from alignment (note: the index's dimensions correspond to the union of the dimensions of all coordinate variables associated with this index). Returns From 466214db1fce2a8cf698d9834ced465c338c51b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 May 2025 06:37:12 +0000 Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/indexes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 39dfe9e1ba1..528fb8ac661 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -369,8 +369,8 @@ def equals(self, other: Index, **kwargs) -> bool: exclude_dims : frozenset of hashable, optional Dimensions excluded from checking. It is None by default, (i.e., when this method is not called in the context of alignment). - For a n-dimensional index this option allows an Index to optionally, - ignore any dimension in ``exclude_dims`` when comparing + For a n-dimensional index this option allows an Index to optionally, + ignore any dimension in ``exclude_dims`` when comparing ``self`` with ``other``. For a 1-dimensional index this kwarg can be safely ignored , as this method is not called when all of the index's dimensions are also excluded from alignment (note: the index's dimensions correspond to the union of the dimensions From 8b144d9ce60bbdb3a6cbde7c7749b3c3573cdb6d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 12 May 2025 08:38:43 +0200 Subject: [PATCH 04/10] Update xarray/core/indexes.py Co-authored-by: Deepak Cherian --- xarray/core/indexes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 528fb8ac661..7e6e3264285 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1958,7 +1958,7 @@ def _wrap_index_equals( f"the signature ``{index_cls_name}.equals(self, other)`` is deprecated. " f"Please update it to " f"``{index_cls_name}.equals(self, other, *, exclude_dims=None)`` " - "or kindly ask the maintainers doing it. " + "or kindly ask the maintainers of ``{index_cls_name}`` to do it. " "See documentation of xarray.Index.equals() for more info.", FutureWarning, ) From b9f95bb685e11057fe439ac2ef18df2efa9dd9a0 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 12 May 2025 08:47:56 +0200 Subject: [PATCH 05/10] rename exclude_dims -> exclude --- xarray/core/indexes.py | 38 +++++++++++++++++++----------------- xarray/tests/test_dataset.py | 10 +++++----- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 7e6e3264285..b3db9af9ef5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -354,7 +354,7 @@ def equals(self, other: Index) -> bool: ... @overload def equals( - self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None + self, other: Index, *, exclude: frozenset[Hashable] | None = None ) -> bool: ... def equals(self, other: Index, **kwargs) -> bool: @@ -366,20 +366,22 @@ def equals(self, other: Index, **kwargs) -> bool: ---------- other : Index The other Index object to compare with this object. - exclude_dims : frozenset of hashable, optional - Dimensions excluded from checking. It is None by default, - (i.e., when this method is not called in the context of alignment). - For a n-dimensional index this option allows an Index to optionally, - ignore any dimension in ``exclude_dims`` when comparing - ``self`` with ``other``. For a 1-dimensional index this kwarg can be safely ignored , - as this method is not called when all of the index's dimensions are also excluded - from alignment (note: the index's dimensions correspond to the union of the dimensions - of all coordinate variables associated with this index). + exclude : frozenset of hashable, optional + Dimensions excluded from checking. It is None by default, (i.e., + when this method is not called in the context of alignment). For a + n-dimensional index this option allows an Index to optionally ignore + any dimension in ``exclude`` when comparing ``self`` with ``other``. + For a 1-dimensional index this kwarg can be safely ignored, as this + method is not called when all of the index's dimensions are also + excluded from alignment (note: the index's dimensions correspond to + the union of the dimensions of all coordinate variables associated + with this index). Returns ------- is_equal : bool ``True`` if the indexes are equal, ``False`` otherwise. + """ raise NotImplementedError() @@ -881,7 +883,7 @@ def sel( return IndexSelResult({self.dim: indexer}) - def equals(self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None): + def equals(self, other: Index, *, exclude: frozenset[Hashable] | None = None): if not isinstance(other, PandasIndex): return False return self.index.equals(other.index) and self.dim == other.dim @@ -1561,7 +1563,7 @@ def sel( return IndexSelResult(results) def equals( - self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None + self, other: Index, *, exclude: frozenset[Hashable] | None = None ) -> bool: if not isinstance(other, CoordinateTransformIndex): return False @@ -1957,18 +1959,18 @@ def _wrap_index_equals( emit_user_level_warning( f"the signature ``{index_cls_name}.equals(self, other)`` is deprecated. " f"Please update it to " - f"``{index_cls_name}.equals(self, other, *, exclude_dims=None)`` " + f"``{index_cls_name}.equals(self, other, *, exclude=None)`` " "or kindly ask the maintainers of ``{index_cls_name}`` to do it. " "See documentation of xarray.Index.equals() for more info.", FutureWarning, ) - exclude_dims_kwarg = False + exclude_kwarg = False else: - exclude_dims_kwarg = True + exclude_kwarg = True - def equals_wrapper(other: Index, exclude_dims: frozenset[Hashable]) -> bool: - if exclude_dims_kwarg: - return index.equals(other, exclude_dims=exclude_dims) + def equals_wrapper(other: Index, exclude: frozenset[Hashable]) -> bool: + if exclude_kwarg: + return index.equals(other, exclude=exclude) else: return index.equals(other) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c001b3b69fc..00ab6730218 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2627,7 +2627,7 @@ def from_variables(cls, variables, *, options): var = next(iter(variables.values())) return cls(int(var.values)) - def equals(self, other, *, exclude_dims=None): + def equals(self, other, *, exclude=None): return isinstance(other, ScalarIndex) and other.value == self.value ds1 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex) @@ -2659,9 +2659,9 @@ def from_variables(cls, variables, *, options): ), ) - def equals(self, other, exclude_dims=None): - x_eq = True if self.x.dim in exclude_dims else self.x.equals(other.x) - y_eq = True if self.y.dim in exclude_dims else self.y.equals(other.y) + def equals(self, other, exclude=None): + x_eq = True if self.x.dim in exclude else self.x.equals(other.x) + y_eq = True if self.y.dim in exclude else self.y.equals(other.y) return x_eq and y_eq ds1 = ( @@ -2691,7 +2691,7 @@ def test_align_index_equals_future_warning(self) -> None: # TODO: remove this test once the deprecation cycle is completed class DeprecatedEqualsSignatureIndex(PandasIndex): def equals(self, other: Index) -> bool: # type: ignore[override] - return super().equals(other, exclude_dims=None) + return super().equals(other, exclude=None) ds = ( Dataset(coords={"x": [1, 2]}) From 299de61a694e1f8ba92fd32d2fbe67f16b1369f9 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 12 May 2025 09:36:47 +0200 Subject: [PATCH 06/10] align: refactor collecting indexes Better readability. --- xarray/structure/alignment.py | 103 ++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 43 deletions(-) diff --git a/xarray/structure/alignment.py b/xarray/structure/alignment.py index 49d2709343e..5d1b7148ef9 100644 --- a/xarray/structure/alignment.py +++ b/xarray/structure/alignment.py @@ -94,10 +94,47 @@ def reindex_variables( return new_variables +def _normalize_indexes( + indexes: Mapping[Any, Any | T_DuckArray], +) -> Indexes: + """Normalize the indexes/indexers given for re-indexing or alignment. + + Wrap any arbitrary array or `pandas.Index` as an Xarray `PandasIndex` + and create the index variable(s). + + """ + xr_indexes: dict[Hashable, Index] = {} + xr_variables: dict[Hashable, Variable] + + if isinstance(indexes, Indexes): + xr_variables = dict(indexes.variables) + else: + xr_variables = {} + + for k, idx in indexes.items(): + if not isinstance(idx, Index): + if getattr(idx, "dims", (k,)) != (k,): + raise AlignmentError( + f"Indexer has dimensions {idx.dims} that are different " + f"from that to be indexed along '{k}'" + ) + data: T_DuckArray = as_compatible_data(idx) + pd_idx = safe_cast_to_index(data) + pd_idx.name = k + if isinstance(pd_idx, pd.MultiIndex): + idx = PandasMultiIndex(pd_idx, k) + else: + idx = PandasIndex(pd_idx, k, coord_dtype=data.dtype) + xr_variables.update(idx.create_variables()) + xr_indexes[k] = idx + + return Indexes(xr_indexes, xr_variables) + + CoordNamesAndDims = tuple[tuple[Hashable, tuple[Hashable, ...]], ...] MatchingIndexKey = tuple[CoordNamesAndDims, type[Index]] -NormalizedIndexes = dict[MatchingIndexKey, Index] -NormalizedIndexVars = dict[MatchingIndexKey, dict[Hashable, Variable]] +IndexesToAlign = dict[MatchingIndexKey, Index] +IndexVarsToAlign = dict[MatchingIndexKey, dict[Hashable, Variable]] class Aligner(Generic[T_Alignable]): @@ -169,7 +206,9 @@ def __init__( if indexes is None: indexes = {} - self.indexes, self.index_vars = self._normalize_indexes(indexes) + self.indexes, self.index_vars = self._collect_indexes( + _normalize_indexes(indexes) + ) self.all_indexes = {} self.all_index_vars = {} @@ -181,43 +220,21 @@ def __init__( self.results = tuple() - def _normalize_indexes( - self, - indexes: Mapping[Any, Any | T_DuckArray], - ) -> tuple[NormalizedIndexes, NormalizedIndexVars]: - """Normalize the indexes/indexers used for re-indexing or alignment. + def _collect_indexes( + self, indexes: Indexes + ) -> tuple[IndexesToAlign, IndexVarsToAlign]: + """Collect input and/or object indexes for alignment. - Return dictionaries of xarray Index objects and coordinate variables - such that we can group matching indexes based on the dictionary keys. + Return new dictionaries of xarray Index objects and coordinate + variables, whose keys are used to later retrieve all the indexes to + compare with each other (based on the name and dimensions of their + associated coordinate variables as well as the Index type). """ - if isinstance(indexes, Indexes): - xr_variables = dict(indexes.variables) - else: - xr_variables = {} - - xr_indexes: dict[Hashable, Index] = {} - for k, idx in indexes.items(): - if not isinstance(idx, Index): - if getattr(idx, "dims", (k,)) != (k,): - raise AlignmentError( - f"Indexer has dimensions {idx.dims} that are different " - f"from that to be indexed along '{k}'" - ) - data: T_DuckArray = as_compatible_data(idx) - pd_idx = safe_cast_to_index(data) - pd_idx.name = k - if isinstance(pd_idx, pd.MultiIndex): - idx = PandasMultiIndex(pd_idx, k) - else: - idx = PandasIndex(pd_idx, k, coord_dtype=data.dtype) - xr_variables.update(idx.create_variables()) - xr_indexes[k] = idx - - normalized_indexes = {} - normalized_index_vars = {} + collected_indexes = {} + collected_index_vars = {} - for idx, idx_vars in Indexes(xr_indexes, xr_variables).group_by_index(): + for idx, idx_vars in indexes.group_by_index(): idx_coord_names_and_dims = [] idx_all_dims: set[Hashable] = set() @@ -226,8 +243,8 @@ def _normalize_indexes( idx_coord_names_and_dims.append((name, dims)) idx_all_dims.update(dims) - # We can ignore an index if all the dimensions it uses are also excluded - # from the alignment (do not ignore the index if it has no related dimension, i.e., + # Do not collect an index if all the dimensions it uses are also excluded + # from the alignment (always collect the index if it has no related dimension, i.e., # it is associated with one or more scalar coordinates). if idx_all_dims: exclude_dims = idx_all_dims & self.exclude_dims @@ -244,11 +261,11 @@ def _normalize_indexes( f"{incl_dims_str}" ) - key = (tuple(idx_coord_names_and_dims), type(idx)) - normalized_indexes[key] = idx - normalized_index_vars[key] = idx_vars + key: MatchingIndexKey = (tuple(idx_coord_names_and_dims), type(idx)) + collected_indexes[key] = idx + collected_index_vars[key] = idx_vars - return normalized_indexes, normalized_index_vars + return collected_indexes, collected_index_vars def find_matching_indexes(self) -> None: all_indexes: dict[MatchingIndexKey, list[Index]] @@ -262,7 +279,7 @@ def find_matching_indexes(self) -> None: objects_matching_indexes = [] for obj in self.objects: - obj_indexes, obj_index_vars = self._normalize_indexes(obj.xindexes) + obj_indexes, obj_index_vars = self._collect_indexes(obj.xindexes) objects_matching_indexes.append(obj_indexes) for key, idx in obj_indexes.items(): all_indexes[key].append(idx) From d185f6f1ee4c8bf5b6c44ee011d76e9f4fbcd28f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 12 May 2025 12:11:24 +0200 Subject: [PATCH 07/10] refactor index assignment in aligned objects --- xarray/structure/alignment.py | 101 +++++++++++++++++++++++++--------- xarray/tests/test_dataset.py | 2 +- 2 files changed, 76 insertions(+), 27 deletions(-) diff --git a/xarray/structure/alignment.py b/xarray/structure/alignment.py index 5d1b7148ef9..88f9835df7b 100644 --- a/xarray/structure/alignment.py +++ b/xarray/structure/alignment.py @@ -100,7 +100,7 @@ def _normalize_indexes( """Normalize the indexes/indexers given for re-indexing or alignment. Wrap any arbitrary array or `pandas.Index` as an Xarray `PandasIndex` - and create the index variable(s). + associated with its corresponding dimension coordinate variable. """ xr_indexes: dict[Hashable, Index] = {} @@ -153,6 +153,9 @@ class Aligner(Generic[T_Alignable]): objects: tuple[T_Alignable, ...] results: tuple[T_Alignable, ...] objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...] + objects_matching_index_vars: tuple[ + dict[MatchingIndexKey, dict[Hashable, Variable]], ... + ] join: str exclude_dims: frozenset[Hashable] exclude_vars: frozenset[Hashable] @@ -166,6 +169,7 @@ class Aligner(Generic[T_Alignable]): aligned_indexes: dict[MatchingIndexKey, Index] aligned_index_vars: dict[MatchingIndexKey, dict[Hashable, Variable]] reindex: dict[MatchingIndexKey, bool] + keep_original_indexes: set[MatchingIndexKey] reindex_kwargs: dict[str, Any] unindexed_dim_sizes: dict[Hashable, set] new_indexes: Indexes[Index] @@ -185,6 +189,7 @@ def __init__( ): self.objects = tuple(objects) self.objects_matching_indexes = () + self.objects_matching_index_vars = () if join not in ["inner", "outer", "override", "exact", "left", "right"]: raise ValueError(f"invalid value for join: {join}") @@ -217,6 +222,7 @@ def __init__( self.aligned_indexes = {} self.aligned_index_vars = {} self.reindex = {} + self.keep_original_indexes = set() self.results = tuple() @@ -243,25 +249,34 @@ def _collect_indexes( idx_coord_names_and_dims.append((name, dims)) idx_all_dims.update(dims) - # Do not collect an index if all the dimensions it uses are also excluded - # from the alignment (always collect the index if it has no related dimension, i.e., - # it is associated with one or more scalar coordinates). + key: MatchingIndexKey = (tuple(idx_coord_names_and_dims), type(idx)) + if idx_all_dims: exclude_dims = idx_all_dims & self.exclude_dims if exclude_dims == idx_all_dims: + # Do not collect an index if all the dimensions it uses are + # also excluded from the alignment continue - elif exclude_dims and self.join != "exact": - excl_dims_str = ", ".join(str(d) for d in exclude_dims) - incl_dims_str = ", ".join( - str(d) for d in idx_all_dims - exclude_dims - ) - raise AlignmentError( - f"cannot exclude dimension(s) {excl_dims_str} from non-exact alignment " - "because these are used by an index together with non-excluded dimensions " - f"{incl_dims_str}" - ) + elif exclude_dims: + # If the dimensions used by index partially overlap with the dimensions + # excluded from alignment, it is possible to check index equality along + # non-excluded dimensions only. However, in this case each of the aligned + # objects must retain (a copy of) their original index. Re-indexing and + # overriding the index are not supported. + if self.join == "override": + excl_dims_str = ", ".join(str(d) for d in exclude_dims) + incl_dims_str = ", ".join( + str(d) for d in idx_all_dims - exclude_dims + ) + raise AlignmentError( + f"cannot exclude dimension(s) {excl_dims_str} from alignment " + "with `join='override` because these are used by an index " + f"together with non-excluded dimensions {incl_dims_str}" + "(cannot safely override the index)." + ) + else: + self.keep_original_indexes.add(key) - key: MatchingIndexKey = (tuple(idx_coord_names_and_dims), type(idx)) collected_indexes[key] = idx collected_index_vars[key] = idx_vars @@ -272,15 +287,20 @@ def find_matching_indexes(self) -> None: all_index_vars: dict[MatchingIndexKey, list[dict[Hashable, Variable]]] all_indexes_dim_sizes: dict[MatchingIndexKey, dict[Hashable, set]] objects_matching_indexes: list[dict[MatchingIndexKey, Index]] + objects_matching_index_vars: list[ + dict[MatchingIndexKey, dict[Hashable, Variable]] + ] all_indexes = defaultdict(list) all_index_vars = defaultdict(list) all_indexes_dim_sizes = defaultdict(lambda: defaultdict(set)) objects_matching_indexes = [] + objects_matching_index_vars = [] for obj in self.objects: obj_indexes, obj_index_vars = self._collect_indexes(obj.xindexes) objects_matching_indexes.append(obj_indexes) + objects_matching_index_vars.append(obj_index_vars) for key, idx in obj_indexes.items(): all_indexes[key].append(idx) for key, index_vars in obj_index_vars.items(): @@ -289,6 +309,7 @@ def find_matching_indexes(self) -> None: all_indexes_dim_sizes[key][dim].add(size) self.objects_matching_indexes = tuple(objects_matching_indexes) + self.objects_matching_index_vars = tuple(objects_matching_index_vars) self.all_indexes = all_indexes self.all_index_vars = all_index_vars @@ -509,6 +530,13 @@ def _get_dim_pos_indexers( if self.reindex[key]: indexers = obj_idx.reindex_like(aligned_idx, **self.reindex_kwargs) for dim, idxer in indexers.items(): + if dim in self.exclude_dims: + raise AlignmentError( + f"cannot reindex or align along dimension {dim!r} because " + "it is explicitly excluded from alignment. This is likely caused by " + "wrong results returned by the `reindex_like` method of this index:\n" + f"{obj_idx!r}" + ) if dim in dim_pos_indexers and not np.array_equal( idxer, dim_pos_indexers[dim] ): @@ -526,22 +554,37 @@ def _get_indexes_and_vars( self, obj: T_Alignable, matching_indexes: dict[MatchingIndexKey, Index], + matching_index_vars: dict[MatchingIndexKey, dict[Hashable, Variable]], ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: new_indexes = {} new_variables = {} for key, aligned_idx in self.aligned_indexes.items(): - index_vars = self.aligned_index_vars[key] + aligned_idx_vars = self.aligned_index_vars[key] obj_idx = matching_indexes.get(key) + obj_idx_vars = matching_index_vars.get(key) + if obj_idx is None: - # add the index if it relates to unindexed dimensions in obj - index_vars_dims = {d for var in index_vars.values() for d in var.dims} - if index_vars_dims <= set(obj.dims): + # add the aligned index if it relates to unindexed dimensions in obj + dims = {d for var in aligned_idx_vars.values() for d in var.dims} + if dims <= set(obj.dims): obj_idx = aligned_idx + if obj_idx is not None: - for name, var in index_vars.items(): - new_indexes[name] = aligned_idx - new_variables[name] = var.copy(deep=self.copy) + # TODO: always copy object's index when no re-indexing is required? + # (instead of assigning the aligned index) + # (need performance assessment) + if key in self.keep_original_indexes: + assert self.reindex[key] is False + new_idx = obj_idx.copy(deep=self.copy) + new_idx_vars = new_idx.create_variables(obj_idx_vars) + else: + new_idx = aligned_idx + new_idx_vars = { + k: v.copy(deep=self.copy) for k, v in aligned_idx_vars.items() + } + new_indexes.update(dict.fromkeys(new_idx_vars, new_idx)) + new_variables.update(new_idx_vars) return new_indexes, new_variables @@ -549,8 +592,11 @@ def _reindex_one( self, obj: T_Alignable, matching_indexes: dict[MatchingIndexKey, Index], + matching_index_vars: dict[MatchingIndexKey, dict[Hashable, Variable]], ) -> T_Alignable: - new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes) + new_indexes, new_variables = self._get_indexes_and_vars( + obj, matching_indexes, matching_index_vars + ) dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes) return obj._reindex_callback( @@ -565,9 +611,12 @@ def _reindex_one( def reindex_all(self) -> None: self.results = tuple( - self._reindex_one(obj, matching_indexes) - for obj, matching_indexes in zip( - self.objects, self.objects_matching_indexes, strict=True + self._reindex_one(obj, matching_indexes, matching_index_vars) + for obj, matching_indexes, matching_index_vars in zip( + self.objects, + self.objects_matching_indexes, + self.objects_matching_index_vars, + strict=True, ) ) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 00ab6730218..7cb1c9e4682 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2685,7 +2685,7 @@ def equals(self, other, exclude=None): xr.align(ds1, ds2, join="exact") with pytest.raises(AlignmentError, match="cannot exclude dimension"): - xr.align(ds1, ds2, join="outer", exclude="y") + xr.align(ds1, ds2, join="override", exclude="y") def test_align_index_equals_future_warning(self) -> None: # TODO: remove this test once the deprecation cycle is completed From 07009fc6dd390993ccc825406b24a0d5cfdf1adb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 13 May 2025 14:00:49 +0200 Subject: [PATCH 08/10] update tests Test a bit more than join="exact". --- xarray/tests/test_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 7cb1c9e4682..77617cd97c7 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2675,9 +2675,10 @@ def equals(self, other, exclude=None): .set_xindex(["x", "y"], XYIndex) ) - actual = xr.align(ds1, ds2, join="exact", exclude="y") - assert_identical(actual[0], ds1, check_default_indexes=False) - assert_identical(actual[1], ds2, check_default_indexes=False) + for join in ("outer", "exact"): + actual = xr.align(ds1, ds2, join=join, exclude="y") + assert_identical(actual[0], ds1, check_default_indexes=False) + assert_identical(actual[1], ds2, check_default_indexes=False) with pytest.raises( AlignmentError, match="cannot align objects.*index.*not equal" From 3bf0738ab5b07ca514a6f725ed81dd1d13d129dc Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 13 May 2025 15:23:41 +0200 Subject: [PATCH 09/10] add exclude kwarg to CoordinateTransform.equals --- xarray/core/coordinate_transform.py | 30 ++++++++++++++++++++--- xarray/core/indexes.py | 2 +- xarray/indexes/range_index.py | 4 ++- xarray/tests/test_coordinate_transform.py | 4 ++- 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index d9e09cea173..94b3b109e1e 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from collections.abc import Hashable, Iterable, Mapping -from typing import Any +from typing import Any, overload import numpy as np @@ -64,8 +66,30 @@ def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: """ raise NotImplementedError - def equals(self, other: "CoordinateTransform") -> bool: - """Check equality with another CoordinateTransform of the same kind.""" + @overload + def equals(self, other: CoordinateTransform) -> bool: ... + + @overload + def equals( + self, other: CoordinateTransform, *, exclude: frozenset[Hashable] | None = None + ) -> bool: ... + + def equals(self, other: CoordinateTransform, **kwargs) -> bool: + """Check equality with another CoordinateTransform of the same kind. + + Parameters + ---------- + other : CoordinateTransform + The other Index object to compare with this object. + exclude : frozenset of hashable, optional + Dimensions excluded from checking. It is None by default, (i.e., + when this method is not called in the context of alignment). For a + n-dimensional transform this option allows a CoordinateTransform to + optionally ignore any dimension in ``exclude`` when comparing + ``self`` with ``other``. For a 1-dimensional transform this kwarg + can be safely ignored, as this method is not called when all of the + transform's dimensions are also excluded from alignment. + """ raise NotImplementedError def generate_coords( diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b3db9af9ef5..b9abf4fef2d 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1567,7 +1567,7 @@ def equals( ) -> bool: if not isinstance(other, CoordinateTransformIndex): return False - return self.transform.equals(other.transform) + return self.transform.equals(other.transform, exclude=exclude) def rename( self, diff --git a/xarray/indexes/range_index.py b/xarray/indexes/range_index.py index 80ab95447d3..230acde828c 100644 --- a/xarray/indexes/range_index.py +++ b/xarray/indexes/range_index.py @@ -65,7 +65,9 @@ def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: positions = (labels - self.start) / self.step return {self.dim: positions} - def equals(self, other: CoordinateTransform) -> bool: + def equals( + self, other: CoordinateTransform, exclude: frozenset[Hashable] | None = None + ) -> bool: if not isinstance(other, RangeCoordinateTransform): return False diff --git a/xarray/tests/test_coordinate_transform.py b/xarray/tests/test_coordinate_transform.py index d3e0d73caab..386ce426998 100644 --- a/xarray/tests/test_coordinate_transform.py +++ b/xarray/tests/test_coordinate_transform.py @@ -32,7 +32,9 @@ def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: return {dim: coord_labels[dim] / self.scale for dim in self.xy_dims} - def equals(self, other: "CoordinateTransform") -> bool: + def equals( + self, other: CoordinateTransform, exclude: frozenset[Hashable] | None = None + ) -> bool: if not isinstance(other, SimpleCoordinateTransform): return False return self.scale == other.scale From ddbf62a2785650c4c13eca41fdf961bc95b95a8e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 13 May 2025 15:30:39 +0200 Subject: [PATCH 10/10] update whats new --- doc/whats-new.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c8fbecf82af..941e52764ae 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,7 +24,11 @@ v2025.05.0 (unreleased) New Features ~~~~~~~~~~~~ - +- Allow an Xarray index that uses multiple dimensions checking equality with another + index for only a subset of those dimensions (i.e., ignoring the dimensions + that are excluded from alignment). + (:issue:`10243`, :pull:`10293`) + By `Benoit Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~