
Error when broadcasting array API compliant class #8665

Closed
@TomNicholas

Description


What happened?

Broadcasting fails for array types that strictly follow the array API standard.

What did you expect to happen?

With a regular NumPy array, the same broadcasting operation works as expected.
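Plain NumPy accepts an index made only of `None`s because it treats any axes the key does not mention as if they were covered by trailing full slices (`:`). The following is a minimal sketch with plain NumPy only, no xarray. It assumes the key has the shape that `Variable.set_dims` builds in the traceback below, `(None,) * k`, and shows that on an `ndarray` this key gives the same result as the explicit form with a trailing ellipsis:

```python
import numpy as np

# A 1-D array, like the result of var.isel(x=0) in the example below
row = np.asarray([1.0, 2.0, 3.0], dtype=np.float32)

# Key shaped like the one built in Variable.set_dims: one None per new leading axis
implicit_key = (None,) * 1
# The same key with an explicit trailing ellipsis covering the existing axis
explicit_key = implicit_key + (...,)

implicit = row[implicit_key]  # NumPy silently pads the unmentioned axis with ':'
explicit = row[explicit_key]

print(implicit.shape, explicit.shape)  # (1, 3) (1, 3)
print(np.array_equal(implicit, explicit))  # True
```

A strict array API implementation does not perform the implicit padding, so only the explicit form is portable.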

Minimal Complete Verifiable Example

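import numpy as np
import numpy.array_api as nxp  # experimental submodule, NumPy 1.x only (removed in NumPy 2.0)
import xarray as xr

arr = nxp.asarray([[1, 2, 3], [4, 5, 6]], dtype=np.dtype('float32'))

var = xr.Variable(data=arr, dims=['x', 'y'])

var.isel(x=0)  # this is fine

var * var.isel(x=0)  # this is not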

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[31], line 1
----> 1 var * var.isel(x=0)

File ~/Documents/Work/Code/xarray/xarray/core/_typed_ops.py:487, in VariableOpsMixin.__mul__(self, other)
    486 def __mul__(self, other: VarCompatible) -> Self | T_DataArray:
--> 487     return self._binary_op(other, operator.mul)

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2406, in Variable._binary_op(self, other, f, reflexive)
   2404     other_data, self_data, dims = _broadcast_compat_data(other, self)
   2405 else:
-> 2406     self_data, other_data, dims = _broadcast_compat_data(self, other)
   2407 keep_attrs = _get_keep_attrs(default=False)
   2408 attrs = self._attrs if keep_attrs else None

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2922, in _broadcast_compat_data(self, other)
   2919 def _broadcast_compat_data(self, other):
   2920     if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]):
   2921         # `other` satisfies the necessary Variable API for broadcast_variables
-> 2922         new_self, new_other = _broadcast_compat_variables(self, other)
   2923         self_data = new_self.data
   2924         other_data = new_other.data

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2899, in _broadcast_compat_variables(*variables)
   2893 """Create broadcast compatible variables, with the same dimensions.
   2894 
   2895 Unlike the result of broadcast_variables(), some variables may have
   2896 dimensions of size 1 instead of the size of the broadcast dimension.
   2897 """
   2898 dims = tuple(_unified_dims(variables))
-> 2899 return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables)

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2899, in <genexpr>(.0)
   2893 """Create broadcast compatible variables, with the same dimensions.
   2894 
   2895 Unlike the result of broadcast_variables(), some variables may have
   2896 dimensions of size 1 instead of the size of the broadcast dimension.
   2897 """
   2898 dims = tuple(_unified_dims(variables))
-> 2899 return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables)

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:1479, in Variable.set_dims(self, dims, shape)
   1477     expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape)
   1478 else:
-> 1479     expanded_data = self.data[(None,) * (len(expanded_dims) - self.ndim)]
   1481 expanded_var = Variable(
   1482     expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True
   1483 )
   1484 return expanded_var.transpose(*dims)

File ~/miniconda3/envs/dev3.11/lib/python3.12/site-packages/numpy/array_api/_array_object.py:555, in Array.__getitem__(self, key)
    550 """
    551 Performs the operation __getitem__.
    552 """
    553 # Note: Only indices required by the spec are allowed. See the
    554 # docstring of _validate_index
--> 555 self._validate_index(key)
    556 if isinstance(key, Array):
    557     # Indexing self._array with array_api arrays can be erroneous
    558     key = key._array

File ~/miniconda3/envs/dev3.11/lib/python3.12/site-packages/numpy/array_api/_array_object.py:348, in Array._validate_index(self, key)
    344 elif n_ellipsis == 0:
    345     # Note boolean masks must be the sole index, which we check for
    346     # later on.
    347     if not key_has_mask and n_single_axes < self.ndim:
--> 348         raise IndexError(
    349             f"{self.ndim=}, but the multi-axes index only specifies "
    350             f"{n_single_axes} dimensions. If this was intentional, "
    351             "add a trailing ellipsis (...) which expands into as many "
    352             "slices (:) as necessary - this is what np.ndarray arrays "
    353             "implicitly do, but such flat indexing behaviour is not "
    354             "specified in the Array API."
    355         )
    357 if n_ellipsis == 0:
    358     indexed_shape = self.shape

IndexError: self.ndim=1, but the multi-axes index only specifies 0 dimensions. If this was intentional, add a trailing ellipsis (...) which expands into as many slices (:) as necessary - this is what np.ndarray arrays implicitly do, but such flat indexing behaviour is not specified in the Array API.
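The traceback pins down the cause. `Variable.set_dims` indexes with `(None,) * (len(expanded_dims) - self.ndim)`, a tuple containing only `None`s. `numpy.array_api` counts only the entries that are neither `None` nor an ellipsis, so for the 1-D operand the key specifies 0 of 1 axes and is rejected. The pure-Python sketch below is a deliberately simplified, hypothetical model of that check. `strict_key_ok` and `expand_key` are illustrative names, not NumPy or xarray API, and the model ignores boolean masks and integer-array indices. It shows that appending `...` to the key would satisfy the rule:

```python
def strict_key_ok(key, ndim):
    """Simplified, hypothetical model of the rule enforced by numpy.array_api.

    Without an ellipsis, the key must give exactly one entry per existing axis.
    ``None`` entries insert new axes and do not count toward that total.
    """
    key = key if isinstance(key, tuple) else (key,)
    n_ellipsis = sum(1 for k in key if k is Ellipsis)
    if n_ellipsis > 1:
        return False  # at most one ellipsis is allowed
    n_single_axes = sum(1 for k in key if k is not None and k is not Ellipsis)
    if n_ellipsis == 0:
        return n_single_axes == ndim  # no implicit trailing slices
    return n_single_axes <= ndim  # the ellipsis expands to the remaining axes


def expand_key(n_new_axes):
    """Hypothetical fix: new leading axes followed by an explicit ellipsis."""
    return (None,) * n_new_axes + (...,)


# The key built when a 1-D variable is broadcast against a 2-D one
old_key = (None,) * 1
new_key = expand_key(1)

print(strict_key_ok(old_key, ndim=1))  # False: 0 of 1 axes specified
print(strict_key_ok(new_key, ndim=1))  # True: the ellipsis covers the axis
```

On a plain `ndarray` the expanded key returns the same result as the original one, so such a change should not alter behaviour for NumPy-backed variables. Another option that avoids indexing entirely is the standard's `expand_dims` function, called once per new axis.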

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

No response

Environment

main branch of xarray, numpy 1.26.0

Metadata


Assignees

No one assigned

    Labels

    array API standard (Support for the Python array API standard), bug, topic-arrays (related to flexible array support)

    Type

    No type

    Projects

    Status

    Done

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
