Skip to content

Commit 4d5237b

Browse files
mathausemax-sixty
authored andcommitted
enable xr.ALL_DIMS in xr.dot (#3424)
* enable xr.ALL_DIMS in xr.dot * trailing whitespace * move whats new to other ellipsis work * xr.ALL_DIMS -> Ellipsis
1 parent 80e4e89 commit 4d5237b

File tree

5 files changed

+48
-8
lines changed

5 files changed

+48
-8
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ New Features
3939
to reduce over all dimensions. While we have no plans to remove `xr.ALL_DIMS`, we suggest
4040
using `...`.
4141
By `Maximilian Roos <https://github.com/max-sixty>`_
42+
- :py:func:`~xarray.dot`, and :py:func:`~xarray.DataArray.dot` now support the
43+
`dims=...` option to sum over the union of dimensions of all input arrays
44+
(:issue:`3423`) by `Mathias Hauser <https://github.com/mathause>`_.
4245
- Added integration tests against `pint <https://pint.readthedocs.io/>`_.
4346
(:pull:`3238`) by `Justus Magin <https://github.com/keewis>`_.
4447

xarray/core/computation.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,9 +1055,9 @@ def dot(*arrays, dims=None, **kwargs):
10551055
----------
10561056
arrays: DataArray (or Variable) objects
10571057
Arrays to compute.
1058-
dims: str or tuple of strings, optional
1059-
Which dimensions to sum over.
1060-
If not speciified, then all the common dimensions are summed over.
1058+
dims: '...', str or tuple of strings, optional
1059+
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
1060+
If not specified, then all the common dimensions are summed over.
10611061
**kwargs: dict
10621062
Additional keyword arguments passed to numpy.einsum or
10631063
dask.array.einsum
@@ -1070,7 +1070,7 @@ def dot(*arrays, dims=None, **kwargs):
10701070
--------
10711071
10721072
>>> import numpy as np
1073-
>>> import xarray as xp
1073+
>>> import xarray as xr
10741074
>>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=['a', 'b'])
10751075
>>> da_b = xr.DataArray(np.arange(3 * 2 * 2).reshape(3, 2, 2),
10761076
... dims=['a', 'b', 'c'])
@@ -1117,6 +1117,14 @@ def dot(*arrays, dims=None, **kwargs):
11171117
[273, 446, 619]])
11181118
Dimensions without coordinates: a, d
11191119
1120+
>>> xr.dot(da_a, da_b)
1121+
<xarray.DataArray (c: 2)>
1122+
array([110, 125])
1123+
Dimensions without coordinates: c
1124+
1125+
>>> xr.dot(da_a, da_b, dims=...)
1126+
<xarray.DataArray ()>
1127+
array(235)
11201128
"""
11211129
from .dataarray import DataArray
11221130
from .variable import Variable
@@ -1141,7 +1149,9 @@ def dot(*arrays, dims=None, **kwargs):
11411149
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
11421150
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}
11431151

1144-
if dims is None:
1152+
if dims is ...:
1153+
dims = all_dims
1154+
elif dims is None:
11451155
# find dimensions that occur more than one times
11461156
dim_counts = Counter()
11471157
for arr in arrays:

xarray/core/dataarray.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2742,9 +2742,9 @@ def dot(
27422742
----------
27432743
other : DataArray
27442744
The other array with which the dot product is performed.
2745-
dims: hashable or sequence of hashables, optional
2746-
Along which dimensions to be summed over. Default all the common
2747-
dimensions are summed over.
2745+
dims: '...', hashable or sequence of hashables, optional
2746+
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
2747+
If not specified, then all the common dimensions are summed over.
27482748
27492749
Returns
27502750
-------

xarray/tests/test_computation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,23 @@ def test_dot(use_dask):
998998
assert actual.dims == ("b",)
999999
assert (actual.data == np.zeros(actual.shape)).all()
10001000

1001+
# Ellipsis (...) sums over all dimensions
1002+
actual = xr.dot(da_a, da_b, dims=...)
1003+
assert actual.dims == ()
1004+
assert (actual.data == np.einsum("ij,ijk->", a, b)).all()
1005+
1006+
actual = xr.dot(da_a, da_b, da_c, dims=...)
1007+
assert actual.dims == ()
1008+
assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all()
1009+
1010+
actual = xr.dot(da_a, dims=...)
1011+
assert actual.dims == ()
1012+
assert (actual.data == np.einsum("ij-> ", a)).all()
1013+
1014+
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...)
1015+
assert actual.dims == ()
1016+
assert (actual.data == np.zeros(actual.shape)).all()
1017+
10011018
# Invalid cases
10021019
if not use_dask:
10031020
with pytest.raises(TypeError):

xarray/tests/test_dataarray.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3925,6 +3925,16 @@ def test_dot(self):
39253925
expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"])
39263926
assert_equal(expected, actual)
39273927

3928+
# Ellipsis: all dims are shared
3929+
actual = da.dot(da, dims=...)
3930+
expected = da.dot(da)
3931+
assert_equal(expected, actual)
3932+
3933+
# Ellipsis: not all dims are shared
3934+
actual = da.dot(dm, dims=...)
3935+
expected = da.dot(dm, dims=("j", "x", "y", "z"))
3936+
assert_equal(expected, actual)
3937+
39283938
with pytest.raises(NotImplementedError):
39293939
da.dot(dm.to_dataset(name="dm"))
39303940
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)