diff --git a/doc/api.rst b/doc/api.rst index 0e34b13f9a5..0f3f8ae2861 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -19,6 +19,7 @@ Top-level functions apply_ufunc align broadcast + call_on_dataset concat merge combine_by_coords diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c05fb9cd555..6da2780e307 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,9 @@ v0.19.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add :py:func:`call_on_dataset` as a way to apply functions expecting + :py:class:`Dataset` objects to :py:class:`DataArray` objects (:issue:`4837`, :pull:`4863`). + By `Justus Magin `_. Breaking changes diff --git a/xarray/__init__.py b/xarray/__init__.py index 8321aba4b46..8583f30b3c8 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -18,7 +18,16 @@ from .core.alignment import align, broadcast from .core.combine import combine_by_coords, combine_nested from .core.common import ALL_DIMS, full_like, ones_like, zeros_like -from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where +from .core.computation import ( + apply_ufunc, + call_on_dataset, + corr, + cov, + dot, + polyval, + unify_chunks, + where, +) from .core.concat import concat from .core.dataarray import DataArray from .core.dataset import Dataset @@ -46,6 +55,7 @@ # Top-level functions "align", "apply_ufunc", + "call_on_dataset", "as_variable", "broadcast", "cftime_range", diff --git a/xarray/core/computation.py b/xarray/core/computation.py index cd9e22d90db..6236ba0c237 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1187,6 +1187,94 @@ def apply_ufunc( return apply_array_ufunc(func, *args, dask=dask) +def call_on_dataset(func, obj, name, *args, **kwargs): + """apply a function expecting a Dataset to a xarray object + + Parameters + ---------- + func : callable + A function expecting a Dataset as its first parameter. + obj : DataArray or Dataset + The dataset to apply ``func`` to. If a ``DataArray``, convert it to a single + variable ``Dataset`` first. + name : hashable + A intermediate name to use as the name of the data variable. If the DataArray + already had a name, it will be restored after converting back. + *args, **kwargs + Additional arguments to ``func`` + + Returns + ------- + DataArray or Dataset + The result of ``func(obj, *args, **kwargs)`` with the same type as ``obj``. + + See Also + -------- + Dataset.map + Dataset.pipe + DataArray.pipe + + Examples + -------- + >>> def f(ds): + ... return xr.Dataset( + ... { + ... name: var * var.attrs.get("scale", 1) + ... for name, var in ds.data_vars.items() + ... }, + ... coords=ds.coords, + ... attrs=ds.attrs, + ... ) + ... + >>> ds = xr.Dataset( + ... {"a": ("x", [3, 4], {"scale": 0.5}), "b": ("x", [-1, 1], {"scale": 1.5})}, + ... coords={"x": [0, 1]}, + ... attrs={"attr": "value"}, + ... ) + >>> ds + + Dimensions: (x: 2) + Coordinates: + * x (x) int64 0 1 + Data variables: + a (x) int64 3 4 + b (x) int64 -1 1 + Attributes: + attr: value + >>> xr.call_on_dataset(f, ds, name="") + + Dimensions: (x: 2) + Coordinates: + * x (x) int64 0 1 + Data variables: + a (x) float64 1.5 2.0 + b (x) float64 -1.5 1.5 + Attributes: + attr: value + >>> xr.call_on_dataset(f, ds.a, name="") + + array([1.5, 2. ]) + Coordinates: + * x (x) int64 0 1 + >>> xr.call_on_dataset(lambda ds: list(ds.variables.keys()), ds.a, name="data") + ['x', 'data'] + """ + from .dataarray import DataArray, Dataset + from .parallel import dataset_to_dataarray + + if isinstance(obj, DataArray): + ds = obj.to_dataset(name=name) + else: + ds = obj + + result = func(ds, *args, **kwargs) + + if isinstance(obj, DataArray) and isinstance(result, Dataset): + result = dataset_to_dataarray(result).rename(obj.name) + + return result + + def cov(da_a, da_b, dim=None, ddof=1): """ Compute covariance between two DataArray objects along a shared dimension. diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 2439ea30b4b..95618d21c86 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -469,6 +469,67 @@ def test_apply_groupby_add(): add(data_array.groupby("y"), data_array.groupby("x")) +@pytest.mark.parametrize( + ["obj", "attrs", "expected"], + ( + pytest.param( + xr.DataArray( + [0, 1], + dims="x", + ), + {None: {"a": 1}}, + xr.DataArray([0, 1], dims="x", attrs={"a": 1}), + id="unnamed DataArray", + ), + pytest.param( + xr.DataArray( + [0, 1], + dims="x", + name="b", + ), + {None: {"a": 1}}, + xr.DataArray([0, 1], dims="x", attrs={"a": 1}, name="b"), + id="named DataArray", + ), + pytest.param( + xr.Dataset( + {"a": ("x", [1, 2]), "b": ("x", [0, 1])}, + coords={ + "x": ("x", [-1, 1]), + "u": ("x", [2, 3]), + }, + ), + {"a": {"a": 1}, "u": {"b": 2}}, + xr.Dataset( + {"a": ("x", [1, 2], {"a": 1}), "b": ("x", [0, 1])}, + coords={"x": [-1, 1], "u": ("x", [2, 3], {"b": 2})}, + ), + id="Dataset", + ), + ), +) +def test_call_on_dataset(obj, attrs, expected): + temporary_name = "" + + def attach_attrs(ds, attrs): + new_ds = ds.copy() + for n, v in new_ds.variables.items(): + if n == temporary_name: + n = None + + if n not in attrs: + continue + + v.attrs.update(attrs[n]) + + return new_ds + + actual = xr.call_on_dataset( + lambda ds: attach_attrs(ds, attrs), obj, name=temporary_name + ) + assert_identical(actual, expected) + + def test_unified_dim_sizes(): assert unified_dim_sizes([xr.Variable((), 0)]) == {} assert unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("x", [1])]) == {"x": 1}