diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 14d7314e..231e1008 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -26,7 +26,12 @@ from .criteria import coordinate_criteria, regex from .helpers import bounds_to_vertices -from .utils import _is_datetime_like, invert_mappings, parse_cell_methods_attr +from .utils import ( + _is_datetime_like, + always_iterable, + invert_mappings, + parse_cell_methods_attr, +) #: Classes wrapped by cf_xarray. _WRAPPED_CLASSES = ( @@ -68,7 +73,7 @@ def apply_mapper( mappers: Union[Mapper, Tuple[Mapper, ...]], obj: Union[DataArray, Dataset], - key: str, + key: Any, error: bool = True, default: Any = None, ) -> List[Any]: @@ -79,8 +84,13 @@ def apply_mapper( It should return a list in all other cases including when there are no results for a good key. """ - if default is None: - default = [] + + if not isinstance(key, str): + if default is None: + raise ValueError("`default` must be provided when `key` is not a string.") + return list(always_iterable(default)) + + default = [] if default is None else list(always_iterable(default)) def _apply_single_mapper(mapper): @@ -917,8 +927,7 @@ def _rewrite_values( value = kwargs[key] mappers = all_mappers[key] - if isinstance(value, str): - value = [value] + value = always_iterable(value) if isinstance(value, dict): # this for things like isel where **kwargs captures things like T=5 diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index e89a4a25..1c337b30 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -317,8 +317,13 @@ def test_weighted(obj): with raise_if_dask_computes(max_computes=2): # weights are checked for nans expected = obj.weighted(obj["cell_area"]).sum("lat") - actual = obj.cf.weighted("area").sum("Y") - assert_identical(expected, actual) + actuals = [ + obj.cf.weighted("area").sum("Y"), + obj.cf.weighted(obj["cell_area"]).sum("Y"), + obj.cf.weighted(weights=obj["cell_area"]).sum("Y"), + ] + for actual in actuals: + assert_identical(expected, actual) @pytest.mark.parametrize("obj", objects) diff --git a/cf_xarray/utils.py b/cf_xarray/utils.py index e05e3501..70c61afd 100644 --- a/cf_xarray/utils.py +++ b/cf_xarray/utils.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict +from typing import Any, Dict, Iterable from xarray import DataArray @@ -53,3 +53,7 @@ def invert_mappings(*mappings): for name in v: merged[name] |= {k} return merged + + +def always_iterable(obj: Any) -> Iterable: + return [obj] if not isinstance(obj, (tuple, list, set, dict)) else obj diff --git a/doc/examples/introduction.ipynb b/doc/examples/introduction.ipynb index 1a2a38ef..a331c504 100644 --- a/doc/examples/introduction.ipynb +++ b/doc/examples/introduction.ipynb @@ -25,6 +25,22 @@ "import xarray as xr" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`cf_xarray` works best when `xarray` keeps attributes by default.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "xr.set_options(keep_attrs=True)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -914,6 +930,7 @@ " * 110e3\n", ")\n", "# and set proper attributes\n", + "ds[\"cell_area\"].attrs = dict(standard_name=\"cell_area\", units=\"m2\")\n", "ds.air.attrs[\"cell_measures\"] = \"area: cell_area\"" ] }, @@ -1000,7 +1017,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.8.10" }, "toc": { "base_numbering": 1,