Closed
Description
What happened?
Broadcasting fails for array types that strictly follow the array API standard.
What did you expect to happen?
Broadcasting should work the same way it does for a plain NumPy array, where the example below succeeds.
Minimal Complete Verifiable Example
# Minimal reproducer: broadcasting an xarray.Variable backed by a strict
# array-API array fails, while plain indexing works.
#
# `numpy.array_api` is the experimental strict array-API namespace shipped
# with NumPy 1.x (this report used numpy 1.26.0). It rejects multi-axis
# indices that do not cover every dimension unless a trailing `...` is given.
import numpy as np
import numpy.array_api as nxp
import xarray as xr

arr = nxp.asarray([[1, 2, 3], [4, 5, 6]], dtype=np.dtype('float32'))
var = xr.Variable(data=arr, dims=['x', 'y'])
var.isel(x=0)  # this is fine
# This is not: broadcasting goes through Variable.set_dims, which indexes the
# data with `(None,) * n` and no trailing ellipsis. That raises IndexError
# on strict array-API arrays.
var * var.isel(x=0)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[31], line 1
----> 1 var * var.isel(x=0)
File ~/Documents/Work/Code/xarray/xarray/core/_typed_ops.py:487, in VariableOpsMixin.__mul__(self, other)
486 def __mul__(self, other: VarCompatible) -> Self | T_DataArray:
--> 487 return self._binary_op(other, operator.mul)
File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2406, in Variable._binary_op(self, other, f, reflexive)
2404 other_data, self_data, dims = _broadcast_compat_data(other, self)
2405 else:
-> 2406 self_data, other_data, dims = _broadcast_compat_data(self, other)
2407 keep_attrs = _get_keep_attrs(default=False)
2408 attrs = self._attrs if keep_attrs else None
File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2922, in _broadcast_compat_data(self, other)
2919 def _broadcast_compat_data(self, other):
2920 if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]):
2921 # `other` satisfies the necessary Variable API for broadcast_variables
-> 2922 new_self, new_other = _broadcast_compat_variables(self, other)
2923 self_data = new_self.data
2924 other_data = new_other.data
File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2899, in _broadcast_compat_variables(*variables)
2893 """Create broadcast compatible variables, with the same dimensions.
2894
2895 Unlike the result of broadcast_variables(), some variables may have
2896 dimensions of size 1 instead of the size of the broadcast dimension.
2897 """
2898 dims = tuple(_unified_dims(variables))
-> 2899 return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables)
File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2899, in <genexpr>(.0)
2893 """Create broadcast compatible variables, with the same dimensions.
2894
2895 Unlike the result of broadcast_variables(), some variables may have
2896 dimensions of size 1 instead of the size of the broadcast dimension.
2897 """
2898 dims = tuple(_unified_dims(variables))
-> 2899 return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables)
File ~/Documents/Work/Code/xarray/xarray/core/variable.py:1479, in Variable.set_dims(self, dims, shape)
1477 expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape)
1478 else:
-> 1479 expanded_data = self.data[(None,) * (len(expanded_dims) - self.ndim)]
1481 expanded_var = Variable(
1482 expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True
1483 )
1484 return expanded_var.transpose(*dims)
File ~/miniconda3/envs/dev3.11/lib/python3.12/site-packages/numpy/array_api/_array_object.py:555, in Array.__getitem__(self, key)
550 """
551 Performs the operation __getitem__.
552 """
553 # Note: Only indices required by the spec are allowed. See the
554 # docstring of _validate_index
--> 555 self._validate_index(key)
556 if isinstance(key, Array):
557 # Indexing self._array with array_api arrays can be erroneous
558 key = key._array
File ~/miniconda3/envs/dev3.11/lib/python3.12/site-packages/numpy/array_api/_array_object.py:348, in Array._validate_index(self, key)
344 elif n_ellipsis == 0:
345 # Note boolean masks must be the sole index, which we check for
346 # later on.
347 if not key_has_mask and n_single_axes < self.ndim:
--> 348 raise IndexError(
349 f"{self.ndim=}, but the multi-axes index only specifies "
350 f"{n_single_axes} dimensions. If this was intentional, "
351 "add a trailing ellipsis (...) which expands into as many "
352 "slices (:) as necessary - this is what np.ndarray arrays "
353 "implicitly do, but such flat indexing behaviour is not "
354 "specified in the Array API."
355 )
357 if n_ellipsis == 0:
358 indexed_shape = self.shape
IndexError: self.ndim=1, but the multi-axes index only specifies 0 dimensions. If this was intentional, add a trailing ellipsis (...) which expands into as many slices (:) as necessary - this is what np.ndarray arrays implicitly do, but such flat indexing behaviour is not specified in the Array API.
MVCE confirmation
- Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- Complete example — the example is self-contained, including all data and the text of any traceback.
- Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- New issue — a search of GitHub Issues suggests this is not a duplicate.
- Recent environment — the issue occurs with the latest version of xarray and its dependencies.
Relevant log output
No response
Anything else we need to know?
No response
Environment
main branch of xarray, numpy 1.26.0
Metadata
Assignees
Type
Projects
Status
Done