diff --git a/doc/api.rst b/doc/api.rst index 256a1dbf3af..40f9add3c57 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -30,6 +30,7 @@ Top-level functions zeros_like ones_like dot + map_blocks Dataset ======= @@ -499,6 +500,8 @@ Dataset methods Dataset.persist Dataset.load Dataset.chunk + Dataset.unify_chunks + Dataset.map_blocks Dataset.filter_by_attrs Dataset.info @@ -529,6 +532,8 @@ DataArray methods DataArray.persist DataArray.load DataArray.chunk + DataArray.unify_chunks + DataArray.map_blocks GroupBy objects =============== @@ -629,6 +634,7 @@ Testing testing.assert_equal testing.assert_identical testing.assert_allclose + testing.assert_chunks_equal Exceptions ========== diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 520f1f79870..b6e12f01f4b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -49,6 +49,11 @@ Breaking changes New functions/methods ~~~~~~~~~~~~~~~~~~~~~ +- Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks`. + Also added :py:meth:`Dataset.unify_chunks`, :py:meth:`DataArray.unify_chunks` and + :py:meth:`testing.assert_chunks_equal`. By `Deepak Cherian `_ + and `Guido Imperiale `_. + Enhancements ~~~~~~~~~~~~ diff --git a/xarray/__init__.py b/xarray/__init__.py index cdca708e28c..394dd0f80bc 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -17,6 +17,7 @@ from .core.dataarray import DataArray from .core.merge import merge, MergeError from .core.options import set_options +from .core.parallel import map_blocks from .backends.api import ( open_dataset, diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 1508fb50b38..ed6908117a2 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -22,7 +22,6 @@ unpack_for_encoding, ) - # standard calendars recognized by cftime _STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"} diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py new file mode 100644 index 00000000000..c3dbdd27098 --- /dev/null +++ b/xarray/core/dask_array_compat.py @@ -0,0 +1,91 @@ +from distutils.version import LooseVersion + +import dask.array as da +import numpy as np +from dask import __version__ as dask_version + +if LooseVersion(dask_version) >= LooseVersion("2.0.0"): + meta_from_array = da.utils.meta_from_array +else: + # Copied from dask v2.4.0 + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. + import numbers + + def meta_from_array(x, ndim=None, dtype=None): + """ Normalize an array to appropriate meta object + + Parameters + ---------- + x: array-like, callable + Either an object that looks sufficiently like a Numpy array, + or a callable that accepts shape and dtype keywords + ndim: int + Number of dimensions of the array + dtype: Numpy dtype + A valid input for ``np.dtype`` + + Returns + ------- + array-like with zero elements of the correct dtype + """ + # If using x._meta, x must be a Dask Array, some libraries (e.g. 
zarr) + # implement a _meta attribute that are incompatible with Dask Array._meta + if hasattr(x, "_meta") and isinstance(x, da.Array): + x = x._meta + + if dtype is None and x is None: + raise ValueError("You must specify the meta or dtype of the array") + + if np.isscalar(x): + x = np.array(x) + + if x is None: + x = np.ndarray + + if isinstance(x, type): + x = x(shape=(0,) * (ndim or 0), dtype=dtype) + + if ( + not hasattr(x, "shape") + or not hasattr(x, "dtype") + or not isinstance(x.shape, tuple) + ): + return x + + if isinstance(x, list) or isinstance(x, tuple): + ndims = [ + 0 + if isinstance(a, numbers.Number) + else a.ndim + if hasattr(a, "ndim") + else len(a) + for a in x + ] + a = [a if nd == 0 else meta_from_array(a, nd) for a, nd in zip(x, ndims)] + return a if isinstance(x, list) else tuple(x) + + if ndim is None: + ndim = x.ndim + + try: + meta = x[tuple(slice(0, 0, None) for _ in range(x.ndim))] + if meta.ndim != ndim: + if ndim > x.ndim: + meta = meta[ + (Ellipsis,) + tuple(None for _ in range(ndim - meta.ndim)) + ] + meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))] + elif ndim == 0: + meta = meta.sum() + else: + meta = meta.reshape((0,) * ndim) + except Exception: + meta = np.empty((0,) * ndim, dtype=dtype or x.dtype) + + if np.isscalar(meta): + meta = np.array(meta) + + if dtype and meta.dtype != dtype: + meta = meta.astype(dtype) + + return meta diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d536d0de2c5..1b1d23bc2fc 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -14,6 +14,7 @@ Optional, Sequence, Tuple, + TypeVar, Union, cast, overload, @@ -63,6 +64,8 @@ ) if TYPE_CHECKING: + T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) + try: from dask.delayed import Delayed except ImportError: @@ -3038,6 +3041,79 @@ def integrate( ds = self._to_temp_dataset().integrate(dim, datetime_unit) return self._from_temp_dataset(ds) + def unify_chunks(self) -> "DataArray": + """ Unify chunk size along all chunked dimensions of this DataArray. + + Returns + ------- + + DataArray with consistent chunk sizes for all dask-array variables + + See Also + -------- + + dask.array.core.unify_chunks + """ + ds = self._to_temp_dataset().unify_chunks() + return self._from_temp_dataset(ds) + + def map_blocks( + self, + func: "Callable[..., T_DSorDA]", + args: Sequence[Any] = (), + kwargs: Mapping[str, Any] = None, + ) -> "T_DSorDA": + """ + Apply a function to each chunk of this DataArray. This method is experimental + and its signature may change. + + Parameters + ---------- + func: callable + User-provided function that accepts a DataArray as its first parameter. The + function will receive a subset of this DataArray, corresponding to one chunk + along each chunked dimension. ``func`` will be executed as + ``func(obj_subset, *args, **kwargs)``. + + The function will be first run on mocked-up data, that looks like this array + but has sizes 0, to determine properties of the returned object such as + dtype, variable names, new dimensions and new indexes (if any). + + This function must return either a single DataArray or a single Dataset. + + This function cannot change size of existing dimensions, or add new chunked + dimensions. + args: Sequence + Passed verbatim to func after unpacking, after the sliced DataArray. xarray + objects, if any, will not be split by chunks. Passing dask collections is + not allowed. + kwargs: Mapping + Passed verbatim to func after unpacking. xarray objects, if any, will not be + split by chunks. 
Passing dask collections is not allowed. + + Returns + ------- + A single DataArray or Dataset with dask backend, reassembled from the outputs of + the function. + + Notes + ----- + This method is designed for when one needs to manipulate a whole xarray object + within each chunk. In the more common case where one can work on numpy arrays, + it is recommended to use apply_ufunc. + + If none of the variables in this DataArray is backed by dask, calling this + method is equivalent to calling ``func(self, *args, **kwargs)``. + + See Also + -------- + dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, + xarray.Dataset.map_blocks + """ + from .parallel import map_blocks + + return map_blocks(func, self, args, kwargs) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = property(StringAccessor) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7b4c7b441bd..42990df6f65 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -21,6 +21,7 @@ Sequence, Set, Tuple, + TypeVar, Union, cast, overload, @@ -85,6 +86,8 @@ from .dataarray import DataArray from .merge import CoercibleMapping + T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset") + try: from dask.delayed import Delayed except ImportError: @@ -1670,7 +1673,10 @@ def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: if v.chunks is not None: for dim, c in zip(v.dims, v.chunks): if dim in chunks and c != chunks[dim]: - raise ValueError("inconsistent chunks") + raise ValueError( + f"Object has inconsistent chunks along dimension {dim}. " + "This can be fixed by calling unify_chunks()." + ) chunks[dim] = c return Frozen(SortedKeysDict(chunks)) @@ -1855,7 +1861,7 @@ def isel( self, indexers: Mapping[Hashable, Any] = None, drop: bool = False, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with each array indexed along the specified dimension(s). @@ -1938,7 +1944,7 @@ def sel( method: str = None, tolerance: Number = None, drop: bool = False, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with each array indexed by tick labels along the specified dimension(s). @@ -2011,7 +2017,7 @@ def sel( def head( self, indexers: Union[Mapping[Hashable, int], int] = None, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with the first `n` values of each array for the specified dimension(s). @@ -2058,7 +2064,7 @@ def head( def tail( self, indexers: Union[Mapping[Hashable, int], int] = None, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with the last `n` values of each array for the specified dimension(s). @@ -2108,7 +2114,7 @@ def tail( def thin( self, indexers: Union[Mapping[Hashable, int], int] = None, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with each array indexed along every `n`th value for the specified dimension(s) @@ -2246,7 +2252,7 @@ def reindex( tolerance: Number = None, copy: bool = True, fill_value: Any = dtypes.NA, - **indexers_kwargs: Any + **indexers_kwargs: Any, ) -> "Dataset": """Conform this object onto a new set of indexes, filling in missing values with ``fill_value``. The default fill value is NaN. 
@@ -2447,7 +2453,7 @@ def interp( method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, - **coords_kwargs: Any + **coords_kwargs: Any, ) -> "Dataset": """ Multidimensional interpolation of Dataset. @@ -2673,7 +2679,7 @@ def rename( self, name_dict: Mapping[Hashable, Hashable] = None, inplace: bool = None, - **names: Hashable + **names: Hashable, ) -> "Dataset": """Returns a new object with renamed variables and dimensions. @@ -2876,7 +2882,7 @@ def expand_dims( self, dim: Union[None, Hashable, Sequence[Hashable], Mapping[Hashable, Any]] = None, axis: Union[None, int, Sequence[int]] = None, - **dim_kwargs: Any + **dim_kwargs: Any, ) -> "Dataset": """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a @@ -3022,7 +3028,7 @@ def set_index( indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]] = None, append: bool = False, inplace: bool = None, - **indexes_kwargs: Union[Hashable, Sequence[Hashable]] + **indexes_kwargs: Union[Hashable, Sequence[Hashable]], ) -> "Dataset": """Set Dataset (multi-)indexes using one or more existing coordinates or variables. @@ -3124,7 +3130,7 @@ def reorder_levels( self, dim_order: Mapping[Hashable, Sequence[int]] = None, inplace: bool = None, - **dim_order_kwargs: Sequence[int] + **dim_order_kwargs: Sequence[int], ) -> "Dataset": """Rearrange index levels using input order. @@ -3190,7 +3196,7 @@ def _stack_once(self, dims, new_dim): def stack( self, dimensions: Mapping[Hashable, Sequence[Hashable]] = None, - **dimensions_kwargs: Sequence[Hashable] + **dimensions_kwargs: Sequence[Hashable], ) -> "Dataset": """ Stack any number of existing dimensions into a single new dimension. @@ -3891,7 +3897,7 @@ def interpolate_na( method: str = "linear", limit: int = None, use_coordinate: Union[bool, Hashable] = True, - **kwargs: Any + **kwargs: Any, ) -> "Dataset": """Interpolate values according to different methods. @@ -3942,7 +3948,7 @@ def interpolate_na( method=method, limit=limit, use_coordinate=use_coordinate, - **kwargs + **kwargs, ) return new @@ -4023,7 +4029,7 @@ def reduce( keepdims: bool = False, numeric_only: bool = False, allow_lazy: bool = False, - **kwargs: Any + **kwargs: Any, ) -> "Dataset": """Reduce this dataset by applying `func` along some dimension(s). @@ -4098,7 +4104,7 @@ def reduce( keep_attrs=keep_attrs, keepdims=keepdims, allow_lazy=allow_lazy, - **kwargs + **kwargs, ) coord_names = {k for k in self.coords if k in variables} @@ -4113,7 +4119,7 @@ def apply( func: Callable, keep_attrs: bool = None, args: Iterable[Any] = (), - **kwargs: Any + **kwargs: Any, ) -> "Dataset": """Apply a function over the data variables in this dataset. @@ -5366,5 +5372,108 @@ def filter_by_attrs(self, **kwargs): selection.append(var_name) return self[selection] + def unify_chunks(self) -> "Dataset": + """ Unify chunk size along all chunked dimensions of this Dataset. 
+ + Returns + ------- + + Dataset with consistent chunk sizes for all dask-array variables + + See Also + -------- + + dask.array.core.unify_chunks + """ + + try: + self.chunks + except ValueError: # "inconsistent chunks" + pass + else: + # No variables with dask backend, or all chunks are already aligned + return self.copy() + + # import dask is placed after the quick exit test above to allow + # running this method if dask isn't installed and there are no chunks + import dask.array + + ds = self.copy() + + dims_pos_map = {dim: index for index, dim in enumerate(ds.dims)} + + dask_array_names = [] + dask_unify_args = [] + for name, variable in ds.variables.items(): + if isinstance(variable.data, dask.array.Array): + dims_tuple = [dims_pos_map[dim] for dim in variable.dims] + dask_array_names.append(name) + dask_unify_args.append(variable.data) + dask_unify_args.append(dims_tuple) + + _, rechunked_arrays = dask.array.core.unify_chunks(*dask_unify_args) + + for name, new_array in zip(dask_array_names, rechunked_arrays): + ds.variables[name]._data = new_array + + return ds + + def map_blocks( + self, + func: "Callable[..., T_DSorDA]", + args: Sequence[Any] = (), + kwargs: Mapping[str, Any] = None, + ) -> "T_DSorDA": + """ + Apply a function to each chunk of this Dataset. This method is experimental and + its signature may change. + + Parameters + ---------- + func: callable + User-provided function that accepts a Dataset as its first parameter. The + function will receive a subset of this Dataset, corresponding to one chunk + along each chunked dimension. ``func`` will be executed as + ``func(obj_subset, *args, **kwargs)``. + + The function will be first run on mocked-up data, that looks like this + Dataset but has sizes 0, to determine properties of the returned object such + as dtype, variable names, new dimensions and new indexes (if any). + + This function must return either a single DataArray or a single Dataset. + + This function cannot change size of existing dimensions, or add new chunked + dimensions. + args: Sequence + Passed verbatim to func after unpacking, after the sliced DataArray. xarray + objects, if any, will not be split by chunks. Passing dask collections is + not allowed. + kwargs: Mapping + Passed verbatim to func after unpacking. xarray objects, if any, will not be + split by chunks. Passing dask collections is not allowed. + + Returns + ------- + A single DataArray or Dataset with dask backend, reassembled from the outputs of + the function. + + Notes + ----- + This method is designed for when one needs to manipulate a whole xarray object + within each chunk. In the more common case where one can work on numpy arrays, + it is recommended to use apply_ufunc. + + If none of the variables in this Dataset is backed by dask, calling this method + is equivalent to calling ``func(self, *args, **kwargs)``. 
+ + See Also + -------- + dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, + xarray.DataArray.map_blocks + """ + from .parallel import map_blocks + + return map_blocks(func, self, args, kwargs) + ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py new file mode 100644 index 00000000000..48bb9ccfc3d --- /dev/null +++ b/xarray/core/parallel.py @@ -0,0 +1,339 @@ +try: + import dask + import dask.array + from dask.highlevelgraph import HighLevelGraph + from .dask_array_compat import meta_from_array + +except ImportError: + pass + +import itertools +import operator +from typing import ( + Any, + Callable, + Dict, + Hashable, + Mapping, + Sequence, + Tuple, + TypeVar, + Union, +) + +import numpy as np + +from .dataarray import DataArray +from .dataset import Dataset + +T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) + + +def dataset_to_dataarray(obj: Dataset) -> DataArray: + if not isinstance(obj, Dataset): + raise TypeError("Expected Dataset, got %s" % type(obj)) + + if len(obj.data_vars) > 1: + raise TypeError( + "Trying to convert Dataset with more than one data variable to DataArray" + ) + + return next(iter(obj.data_vars.values())) + + +def make_meta(obj): + """If obj is a DataArray or Dataset, return a new object of the same type and with + the same variables and dtypes, but where all variables have size 0 and numpy + backend. + If obj is neither a DataArray nor Dataset, return it unaltered. + """ + if isinstance(obj, DataArray): + obj_array = obj + obj = obj._to_temp_dataset() + elif isinstance(obj, Dataset): + obj_array = None + else: + return obj + + meta = Dataset() + for name, variable in obj.variables.items(): + meta_obj = meta_from_array(variable.data, ndim=variable.ndim) + meta[name] = (variable.dims, meta_obj, variable.attrs) + meta.attrs = obj.attrs + meta = meta.set_coords(obj.coords) + + if obj_array is not None: + return obj_array._from_temp_dataset(meta) + return meta + + +def infer_template( + func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs +) -> T_DSorDA: + """Infer return object by running the function on meta objects. + """ + meta_args = [make_meta(arg) for arg in (obj,) + args] + + try: + template = func(*meta_args, **kwargs) + except Exception as e: + raise Exception( + "Cannot infer object returned from running user provided function." + ) from e + + if not isinstance(template, (Dataset, DataArray)): + raise TypeError( + "Function must return an xarray DataArray or Dataset. Instead it returned " + f"{type(template)}" + ) + + return template + + +def make_dict(x: Union[DataArray, Dataset]) -> Dict[Hashable, Any]: + """Map variable name to numpy(-like) data + (Dataset.to_dict() is too complicated). + """ + if isinstance(x, DataArray): + x = x._to_temp_dataset() + + return {k: v.data for k, v in x.variables.items()} + + +def map_blocks( + func: Callable[..., T_DSorDA], + obj: Union[DataArray, Dataset], + args: Sequence[Any] = (), + kwargs: Mapping[str, Any] = None, +) -> T_DSorDA: + """Apply a function to each chunk of a DataArray or Dataset. This function is + experimental and its signature may change. + + Parameters + ---------- + func: callable + User-provided function that accepts a DataArray or Dataset as its first + parameter. The function will receive a subset of 'obj' (see below), + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(obj_subset, *args, **kwargs)``. 
+ + The function will be first run on mocked-up data, that looks like 'obj' but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, new dimensions and new indexes (if any). + + This function must return either a single DataArray or a single Dataset. + + This function cannot change size of existing dimensions, or add new chunked + dimensions. + obj: DataArray, Dataset + Passed to the function as its first argument, one dask chunk at a time. + args: Sequence + Passed verbatim to func after unpacking, after the sliced obj. xarray objects, + if any, will not be split by chunks. Passing dask collections is not allowed. + kwargs: Mapping + Passed verbatim to func after unpacking. xarray objects, if any, will not be + split by chunks. Passing dask collections is not allowed. + + Returns + ------- + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. + + Notes + ----- + This function is designed for when one needs to manipulate a whole xarray object + within each chunk. In the more common case where one can work on numpy arrays, it is + recommended to use apply_ufunc. + + If none of the variables in obj is backed by dask, calling this function is + equivalent to calling ``func(obj, *args, **kwargs)``. + + See Also + -------- + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, + xarray.DataArray.map_blocks + """ + + def _wrapper(func, obj, to_array, args, kwargs): + if to_array: + obj = dataset_to_dataarray(obj) + + result = func(obj, *args, **kwargs) + + for name, index in result.indexes.items(): + if name in obj.indexes: + if len(index) != len(obj.indexes[name]): + raise ValueError( + "Length of the %r dimension has changed. This is not allowed." + % name + ) + + return make_dict(result) + + if not isinstance(args, Sequence): + raise TypeError("args must be a sequence (for example, a list or tuple).") + if kwargs is None: + kwargs = {} + elif not isinstance(kwargs, Mapping): + raise TypeError("kwargs must be a mapping (for example, a dict)") + + for value in list(args) + list(kwargs.values()): + if dask.is_dask_collection(value): + raise TypeError( + "Cannot pass dask collections in args or kwargs yet. Please compute or " + "load values before passing to map_blocks." + ) + + if not dask.is_dask_collection(obj): + return func(obj, *args, **kwargs) + + if isinstance(obj, DataArray): + # only using _to_temp_dataset would break + # func = lambda x: x.to_dataset() + # since that relies on preserving name. 
+ if obj.name is None: + dataset = obj._to_temp_dataset() + else: + dataset = obj.to_dataset() + input_is_array = True + else: + dataset = obj + input_is_array = False + + input_chunks = dataset.chunks + + template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs) + if isinstance(template, DataArray): + result_is_array = True + template_name = template.name + template = template._to_temp_dataset() + elif isinstance(template, Dataset): + result_is_array = False + else: + raise TypeError( + f"func output must be DataArray or Dataset; got {type(template)}" + ) + + template_indexes = set(template.indexes) + dataset_indexes = set(dataset.indexes) + preserved_indexes = template_indexes & dataset_indexes + new_indexes = template_indexes - dataset_indexes + indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} + indexes.update({k: template.indexes[k] for k in new_indexes}) + + graph: Dict[Any, Any] = {} + gname = "%s-%s" % ( + dask.utils.funcname(func), + dask.base.tokenize(dataset, args, kwargs), + ) + + # map dims to list of chunk indexes + ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()} + # mapping from chunk index to slice bounds + chunk_index_bounds = { + dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items() + } + + # iterate over all possible chunk combinations + for v in itertools.product(*ichunk.values()): + chunk_index_dict = dict(zip(dataset.dims, v)) + + # this will become [[name1, variable1], + # [name2, variable2], + # ...] + # which is passed to dict and then to Dataset + data_vars = [] + coords = [] + + for name, variable in dataset.variables.items(): + # make a task that creates tuple of (dims, chunk) + if dask.is_dask_collection(variable.data): + # recursively index into dask_keys nested list to get chunk + chunk = variable.__dask_keys__() + for dim in variable.dims: + chunk = chunk[chunk_index_dict[dim]] + + chunk_variable_task = ("%s-%s" % (gname, chunk[0]),) + v + graph[chunk_variable_task] = ( + tuple, + [variable.dims, chunk, variable.attrs], + ) + else: + # non-dask array with possibly chunked dimensions + # index into variable appropriately + subsetter = {} + for dim in variable.dims: + if dim in chunk_index_dict: + which_chunk = chunk_index_dict[dim] + subsetter[dim] = slice( + chunk_index_bounds[dim][which_chunk], + chunk_index_bounds[dim][which_chunk + 1], + ) + + subset = variable.isel(subsetter) + chunk_variable_task = ( + "%s-%s" % (gname, dask.base.tokenize(subset)), + ) + v + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset, subset.attrs], + ) + + # this task creates dict mapping variable name to above tuple + if name in dataset._coord_names: + coords.append([name, chunk_variable_task]) + else: + data_vars.append([name, chunk_variable_task]) + + from_wrapper = (gname,) + v + graph[from_wrapper] = ( + _wrapper, + func, + (Dataset, (dict, data_vars), (dict, coords), dataset.attrs), + input_is_array, + args, + kwargs, + ) + + # mapping from variable name to dask graph key + var_key_map: Dict[Hashable, str] = {} + for name, variable in template.variables.items(): + if name in indexes: + continue + gname_l = "%s-%s" % (gname, name) + var_key_map[name] = gname_l + + key: Tuple[Any, ...] 
= (gname_l,) + for dim in variable.dims: + if dim in chunk_index_dict: + key += (chunk_index_dict[dim],) + else: + # unchunked dimensions in the input have one chunk in the result + key += (0,) + + graph[key] = (operator.getitem, from_wrapper, name) + + graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) + + result = Dataset(coords=indexes, attrs=template.attrs) + for name, gname_l in var_key_map.items(): + dims = template[name].dims + var_chunks = [] + for dim in dims: + if dim in input_chunks: + var_chunks.append(input_chunks[dim]) + elif dim in indexes: + var_chunks.append((len(indexes[dim]),)) + + data = dask.array.Array( + graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype + ) + result[name] = (dims, data, template[name].attrs) + + result = result.set_coords(template._coord_names) + + if result_is_array: + da = dataset_to_dataarray(result) + da.name = template_name + return da # type: ignore + return result # type: ignore diff --git a/xarray/core/pdcompat.py b/xarray/core/pdcompat.py index 7591fff3abe..f2e4518e0dc 100644 --- a/xarray/core/pdcompat.py +++ b/xarray/core/pdcompat.py @@ -41,7 +41,6 @@ import pandas as pd - # allow ourselves to type checks for Panel even after it's removed if LooseVersion(pd.__version__) < "0.25.0": Panel = pd.Panel diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6d7a07c6791..24865d62666 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -3,7 +3,7 @@ from collections import OrderedDict, defaultdict from datetime import timedelta from distutils.version import LooseVersion -from typing import Any, Hashable, Mapping, Union, TypeVar +from typing import Any, Hashable, Mapping, TypeVar, Union import numpy as np import pandas as pd diff --git a/xarray/testing.py b/xarray/testing.py index f01cbe896b9..95e41cfb10c 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -142,6 +142,26 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): raise TypeError("{} not supported by assertion comparison".format(type(a))) +def assert_chunks_equal(a, b): + """ + Assert that chunksizes along chunked dimensions are equal. + + Parameters + ---------- + a : xarray.Dataset or xarray.DataArray + The first object to compare. + b : xarray.Dataset or xarray.DataArray + The second object to compare. 
+ """ + + if isinstance(a, DataArray) != isinstance(b, DataArray): + raise TypeError("a and b have mismatched types") + + left = a.unify_chunks() + right = b.unify_chunks() + assert left.chunks == right.chunks + + def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): assert isinstance(indexes, OrderedDict), indexes assert all(isinstance(v, pd.Index) for v in indexes.values()), { diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 8b4d3073e1c..acf8b67effa 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -9,6 +9,7 @@ import numpy as np import pytest from numpy.testing import assert_array_equal # noqa: F401 +from pandas.testing import assert_frame_equal # noqa: F401 import xarray.testing from xarray.core import utils @@ -17,8 +18,6 @@ from xarray.core.options import set_options from xarray.plot.utils import import_seaborn -from pandas.testing import assert_frame_equal # noqa: F401 - # import mpl and change the backend before other mpl imports try: import matplotlib as mpl diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0120e2ca0fe..b5421a6bc9f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -14,8 +14,8 @@ import numpy as np import pandas as pd -from pandas.errors import OutOfBoundsDatetime import pytest +from pandas.errors import OutOfBoundsDatetime import xarray as xr from xarray import ( diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 406b9c1ba69..33a409e6f45 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -6,7 +6,6 @@ import pytest from pandas.errors import OutOfBoundsDatetime - from xarray import DataArray, Dataset, Variable, coding, decode_cf from xarray.coding.times import ( _import_cftime, @@ -30,7 +29,6 @@ requires_cftime_or_netCDF4, ) - _NON_STANDARD_CALENDARS_SET = { "noleap", "365_day", diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index c142ca7643b..3e2e132825f 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1,3 +1,4 @@ +import operator import pickle from collections import OrderedDict from contextlib import suppress @@ -11,6 +12,7 @@ import xarray as xr import xarray.ufuncs as xu from xarray import DataArray, Dataset, Variable +from xarray.testing import assert_chunks_equal from xarray.tests import mock from . 
import ( @@ -894,3 +896,243 @@ def test_dask_layers_and_dependencies(): assert set(x.foo.__dask_graph__().dependencies).issuperset( ds.__dask_graph__().dependencies ) + + +def make_da(): + da = xr.DataArray( + np.ones((10, 20)), + dims=["x", "y"], + coords={"x": np.arange(10), "y": np.arange(100, 120)}, + name="a", + ).chunk({"x": 4, "y": 5}) + da.attrs["test"] = "test" + da.coords["c2"] = 0.5 + da.coords["ndcoord"] = da.x * 2 + da.coords["cxy"] = (da.x * da.y).chunk({"x": 4, "y": 5}) + + return da + + +def make_ds(): + map_ds = xr.Dataset() + map_ds["a"] = make_da() + map_ds["b"] = map_ds.a + 50 + map_ds["c"] = map_ds.x + 20 + map_ds = map_ds.chunk({"x": 4, "y": 5}) + map_ds["d"] = ("z", [1, 1, 1, 1]) + map_ds["z"] = [0, 1, 2, 3] + map_ds["e"] = map_ds.x + map_ds.y + map_ds.coords["c1"] = 0.5 + map_ds.coords["cx"] = ("x", np.arange(len(map_ds.x))) + map_ds.coords["cx"].attrs["test2"] = "test2" + map_ds.attrs["test"] = "test" + map_ds.coords["xx"] = map_ds["a"] * map_ds.y + + return map_ds + + +# fixtures cannot be used in parametrize statements +# instead use this workaround +# https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly +@pytest.fixture +def map_da(): + return make_da() + + +@pytest.fixture +def map_ds(): + return make_ds() + + +def test_unify_chunks(map_ds): + ds_copy = map_ds.copy() + ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) + + with raises_regex(ValueError, "inconsistent chunks"): + ds_copy.chunks + + expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5), "z": (4,)} + with raise_if_dask_computes(): + actual_chunks = ds_copy.unify_chunks().chunks + expected_chunks == actual_chunks + assert_identical(map_ds, ds_copy.unify_chunks()) + + +@pytest.mark.parametrize("obj", [make_ds(), make_da()]) +@pytest.mark.parametrize( + "transform", [lambda x: x.compute(), lambda x: x.unify_chunks()] +) +def test_unify_chunks_shallow_copy(obj, transform): + obj = transform(obj) + unified = obj.unify_chunks() + assert_identical(obj, unified) and obj is not obj.unify_chunks() + + +def test_map_blocks_error(map_da, map_ds): + def bad_func(darray): + return (darray * darray.x + 5 * darray.y)[:1, :1] + + with raises_regex(ValueError, "Length of the.* has changed."): + xr.map_blocks(bad_func, map_da).compute() + + def returns_numpy(darray): + return (darray * darray.x + 5 * darray.y).values + + with raises_regex(TypeError, "Function must return an xarray DataArray"): + xr.map_blocks(returns_numpy, map_da) + + with raises_regex(TypeError, "args must be"): + xr.map_blocks(operator.add, map_da, args=10) + + with raises_regex(TypeError, "kwargs must be"): + xr.map_blocks(operator.add, map_da, args=[10], kwargs=[20]) + + def really_bad_func(darray): + raise ValueError("couldn't do anything.") + + with raises_regex(Exception, "Cannot infer"): + xr.map_blocks(really_bad_func, map_da) + + ds_copy = map_ds.copy() + ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) + + with raises_regex(ValueError, "inconsistent chunks"): + xr.map_blocks(bad_func, ds_copy) + + with raises_regex(TypeError, "Cannot pass dask collections"): + xr.map_blocks(bad_func, map_da, args=[map_da.chunk()]) + + with raises_regex(TypeError, "Cannot pass dask collections"): + xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk())) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks(obj): + def func(obj): + result = obj + obj.x + 5 * obj.y + return result + + with raise_if_dask_computes(): + actual = xr.map_blocks(func, obj) + expected = func(obj) + 
assert_chunks_equal(expected.chunk(), actual) + xr.testing.assert_identical(actual.compute(), expected.compute()) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_convert_args_to_list(obj): + expected = obj + 10 + with raise_if_dask_computes(): + actual = xr.map_blocks(operator.add, obj, [10]) + assert_chunks_equal(expected.chunk(), actual) + xr.testing.assert_identical(actual.compute(), expected.compute()) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_add_attrs(obj): + def add_attrs(obj): + obj = obj.copy(deep=True) + obj.attrs["new"] = "new" + obj.cxy.attrs["new2"] = "new2" + return obj + + expected = add_attrs(obj) + with raise_if_dask_computes(): + actual = xr.map_blocks(add_attrs, obj) + + xr.testing.assert_identical(actual.compute(), expected.compute()) + + +def test_map_blocks_change_name(map_da): + def change_name(obj): + obj = obj.copy(deep=True) + obj.name = "new" + return obj + + expected = change_name(map_da) + with raise_if_dask_computes(): + actual = xr.map_blocks(change_name, map_da) + + xr.testing.assert_identical(actual.compute(), expected.compute()) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_kwargs(obj): + expected = xr.full_like(obj, fill_value=np.nan) + with raise_if_dask_computes(): + actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) + assert_chunks_equal(expected.chunk(), actual) + xr.testing.assert_identical(actual.compute(), expected.compute()) + + +def test_map_blocks_to_array(map_ds): + with raise_if_dask_computes(): + actual = xr.map_blocks(lambda x: x.to_array(), map_ds) + + # to_array does not preserve name, so cannot use assert_identical + assert_equal(actual.compute(), map_ds.to_array().compute()) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x, + lambda x: x.to_dataset(), + lambda x: x.drop("x"), + lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.assign_coords(new_coord=("y", x.y * 2)), + lambda x: x.astype(np.int32), + # TODO: [lambda x: x.isel(x=1).drop("x"), map_da], + ], +) +def test_map_blocks_da_transformations(func, map_da): + with raise_if_dask_computes(): + actual = xr.map_blocks(func, map_da) + + assert_identical(actual.compute(), func(map_da).compute()) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x, + lambda x: x.drop("cxy"), + lambda x: x.drop("a"), + lambda x: x.drop("x"), + lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.rename({"a": "new1", "b": "new2"}), + # TODO: [lambda x: x.isel(x=1)], + ], +) +def test_map_blocks_ds_transformations(func, map_ds): + with raise_if_dask_computes(): + actual = xr.map_blocks(func, map_ds) + + assert_identical(actual.compute(), func(map_ds).compute()) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_object_method(obj): + def func(obj): + result = obj + obj.x + 5 * obj.y + return result + + with raise_if_dask_computes(): + expected = xr.map_blocks(func, obj) + actual = obj.map_blocks(func) + + assert_identical(expected.compute(), actual.compute()) + + +def test_make_meta(map_ds): + from ..core.parallel import make_meta + + meta = make_meta(map_ds) + + for variable in map_ds._coord_names: + assert variable in meta._coord_names + assert meta.coords[variable].shape == (0,) * meta.coords[variable].ndim + + for variable in map_ds.data_vars: + assert variable in meta.data_vars + assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim
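
A minimal usage sketch of the map_blocks API introduced in this patch, modeled on the new tests above. It assumes xarray with dask installed; the array sizes, chunk sizes and the helper name ``add_x`` are illustrative, not taken from the patch.

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(200.0).reshape(10, 20),
    dims=["x", "y"],
    coords={"x": np.arange(10), "y": np.arange(20)},
    name="a",
).chunk({"x": 4, "y": 5})

def add_x(block):
    # Receives one chunk of ``da`` as a DataArray and must return a DataArray
    # or Dataset whose existing dimensions keep their sizes.
    return block + block.x

# Method and top-level forms are equivalent; both are lazy, and ``add_x`` is
# first run on size-0 "meta" data to infer the result's dtype and variables.
lazy_method = da.map_blocks(add_x)
lazy_func = xr.map_blocks(add_x, da)
xr.testing.assert_identical(lazy_method.compute(), add_x(da).compute())

# args/kwargs are forwarded verbatim (dask collections are not allowed), e.g.:
filled = da.map_blocks(xr.full_like, kwargs=dict(fill_value=0))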
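
A second sketch illustrating Dataset.unify_chunks and the new testing.assert_chunks_equal helper, mirroring test_unify_chunks above. It again assumes dask is installed; the chunk sizes shown are illustrative.

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": (("x", "y"), np.ones((10, 20)))}).chunk({"x": 4, "y": 5})
ds["b"] = ds.a.chunk({"x": 10})  # "a" and "b" now disagree along x

try:
    ds.chunks  # raises the improved "inconsistent chunks" error added above
except ValueError as err:
    print(err)

unified = ds.unify_chunks()  # rechunks dask variables to a common grid
assert unified.chunks["x"] == (4, 4, 2)

# assert_chunks_equal unifies both arguments before comparing chunk sizes.
xr.testing.assert_chunks_equal(unified, ds.unify_chunks())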