Skip to content
67 changes: 61 additions & 6 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import collections.abc
import copy
import inspect
from collections import defaultdict
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -348,7 +349,15 @@ def reindex_like(self, other: Self) -> dict[Hashable, Any]:
"""
raise NotImplementedError(f"{self!r} doesn't support re-indexing labels")

def equals(self, other: Index) -> bool:
@overload
def equals(self, other: Index) -> bool: ...

@overload
def equals(
self, other: Index, *, exclude: frozenset[Hashable] | None = None
) -> bool: ...

def equals(self, other: Index, **kwargs) -> bool:
"""Compare this index with another index of the same type.

Implementation is optional but required in order to support alignment.
Expand All @@ -357,11 +366,22 @@ def equals(self, other: Index) -> bool:
----------
other : Index
The other Index object to compare with this object.
exclude : frozenset of hashable, optional
Dimensions excluded from checking. It is None by default, (i.e.,
when this method is not called in the context of alignment). For a
n-dimensional index this option allows an Index to optionally ignore
any dimension in ``exclude`` when comparing ``self`` with ``other``.
For a 1-dimensional index this kwarg can be safely ignored, as this
method is not called when all of the index's dimensions are also
excluded from alignment (note: the index's dimensions correspond to
the union of the dimensions of all coordinate variables associated
with this index).

Returns
-------
is_equal : bool
``True`` if the indexes are equal, ``False`` otherwise.

"""
raise NotImplementedError()

Expand Down Expand Up @@ -863,7 +883,7 @@ def sel(

return IndexSelResult({self.dim: indexer})

def equals(self, other: Index):
def equals(self, other: Index, *, exclude: frozenset[Hashable] | None = None):
if not isinstance(other, PandasIndex):
return False
return self.index.equals(other.index) and self.dim == other.dim
Expand Down Expand Up @@ -1542,7 +1562,9 @@ def sel(

return IndexSelResult(results)

def equals(self, other: Index) -> bool:
def equals(
self, other: Index, *, exclude: frozenset[Hashable] | None = None
) -> bool:
if not isinstance(other, CoordinateTransformIndex):
return False
return self.transform.equals(other.transform)
Expand Down Expand Up @@ -1925,6 +1947,36 @@ def default_indexes(
return indexes


def _wrap_index_equals(
index: Index,
) -> Callable[[Index, frozenset[Hashable]], bool]:
# TODO: remove this Index.equals() wrapper (backward compatibility)

sig = inspect.signature(index.equals)

if len(sig.parameters) == 1:
index_cls_name = type(index).__module__ + "." + type(index).__qualname__
emit_user_level_warning(
f"the signature ``{index_cls_name}.equals(self, other)`` is deprecated. "
f"Please update it to "
f"``{index_cls_name}.equals(self, other, *, exclude=None)`` "
"or kindly ask the maintainers of ``{index_cls_name}`` to do it. "
"See documentation of xarray.Index.equals() for more info.",
FutureWarning,
)
exclude_kwarg = False
else:
exclude_kwarg = True

def equals_wrapper(other: Index, exclude: frozenset[Hashable]) -> bool:
if exclude_kwarg:
return index.equals(other, exclude=exclude)
else:
return index.equals(other)

return equals_wrapper


def indexes_equal(
index: Index,
other_index: Index,
Expand Down Expand Up @@ -1966,6 +2018,7 @@ def indexes_equal(

def indexes_all_equal(
elements: Sequence[tuple[Index, dict[Hashable, Variable]]],
exclude_dims: frozenset[Hashable],
) -> bool:
"""Check if indexes are all equal.

Expand All @@ -1990,9 +2043,11 @@ def check_variables():

same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:])
if same_type:
index_equals_func = _wrap_index_equals(indexes[0])
try:
not_equal = any(
not indexes[0].equals(other_idx) for other_idx in indexes[1:]
not index_equals_func(other_idx, exclude_dims)
for other_idx in indexes[1:]
)
except NotImplementedError:
not_equal = check_variables()
Expand Down
Loading
Loading