
Feature request: Implement interp for interpolating between chunks of data (dask) #4078

Closed

Description

@pums974

In a project of mine I need to interpolate a dask-based xarray between chunks of data.

I made it work using monkey patching. I'm sure it can be written better, but I made it as good as I could.

I hope that what I wrote can help you implement it properly.

from typing import Union, Tuple, Callable, Any, List

import dask.array as da
import numpy as np
import xarray as xr
import xarray.core.missing as m

def interp_func(var: Union[np.ndarray, da.Array],
                x: Tuple[xr.DataArray, ...],
                new_x: Tuple[xr.DataArray, ...],
                method: str,
                kwargs: Any) -> da.Array:
    """
    multi-dimensional interpolation for array-like. Interpolated axes should be
    located in the last position.

    Parameters
    ----------
    var: np.ndarray or dask.array.Array
        Array to be interpolated. The final dimension is interpolated.
    x: a list of 1d arrays
        Original coordinates. Should not contain NaN.
    new_x: a list of 1d arrays
        New coordinates. Should not contain NaN.
    method: string
        {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
        1-dimensional interpolation.
        {'linear', 'nearest'} for multidimensional interpolation
    **kwargs:
        Optional keyword arguments to be passed to scipy.interpolator

    Returns
    -------
    interpolated: array
        Interpolated array

    Note
    ----
    This requires scipy to be installed.

    See Also
    --------
    scipy.interpolate.interp1d
    """

    try:
        # try the official interp_func first
        res = official_interp_func(var, x, new_x, method, kwargs)
        return res
    except NotImplementedError:
        # may fail if interpolating between chunks
        pass

    if len(x) == 1:
        func, _kwargs = m._get_interpolator(method, vectorizeable_only=True,
                                            **kwargs)
    else:
        func, _kwargs = m._get_interpolator_nd(method, **kwargs)

    # reshape new_x (TODO REMOVE ?)
    current_dims = [_x.name for _x in x]
    new_x = tuple([_x.set_dims(current_dims) for _x in new_x])

    # number of non interpolated dimensions
    nconst = var.ndim - len(x)

    # duplicate the ghost cells of the array
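    # (depth=1 adds one ghost cell on each side of every interpolated axis;
    #  depth=0 leaves the non-interpolated axes untouched)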
    bnd = {i: "none" for i in range(len(var.shape))}
    depth = {i: 0 if i < nconst else 1 for i in range(len(var.shape))}
    var_with_ghost = da.overlap.overlap(var, depth=depth, boundary=bnd)

    # chunks x and duplicate the ghost cells of x
    x = tuple(da.from_array(_x, chunks=chunks) for _x, chunks in zip(x, var.chunks[nconst:]))
    x_with_ghost = tuple(da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"})
                         for _x in x)

    # compute final chunks
    chunks_end = [np.cumsum(sizes) - 1 for _x in x
                                       for sizes in _x.chunks]
    chunks_end_with_ghost = [np.cumsum(sizes) - 1 for _x in x_with_ghost
                                                  for sizes in _x.chunks]
    total_chunks = []
    for dim, ce in enumerate(zip(chunks_end, chunks_end_with_ghost)):
        l_new_x_ends: List[np.ndarray] = []
        for iend, iend_with_ghost in zip(*ce):

            arr = np.moveaxis(new_x[dim].data, dim, -1)
            arr = arr[tuple([0] * (len(arr.shape) - 1))]

            n_no_ghost = (arr <= x[dim][iend]).sum()
            n_ghost = (arr <= x_with_ghost[dim][iend_with_ghost]).sum()
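            # place the new chunk boundary roughly halfway between the last new
            # point covered without ghost cells and the last one covered with them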

            equil = np.ceil(0.5 * (n_no_ghost + n_ghost)).astype(int)

            l_new_x_ends.append(equil)

        new_x_ends = np.array(l_new_x_ends)
        chunks = new_x_ends[0], *(new_x_ends[1:] - new_x_ends[:-1])
        total_chunks.append(tuple(chunks))
    final_chunks = var.chunks[:-len(x)] + tuple(total_chunks)

    # chunks new_x
    new_x = tuple(da.from_array(_x, chunks=total_chunks) for _x in new_x)

    # reshape x_with_ghost (TODO REMOVE ?)
    x_with_ghost = da.meshgrid(*x_with_ghost, indexing='ij')

    # compute on chunks (TODO use drop_axis and new_axis ?)
    res = da.map_blocks(_myinterpnd, var_with_ghost, func, _kwargs, len(x_with_ghost), *x_with_ghost, *new_x,
                        dtype=var.dtype, chunks=final_chunks)

    # reshape res and remove empty chunks (TODO REMOVE ?)
    res = res.squeeze()
    new_chunks = tuple([tuple([chunk for chunk in chunks if chunk > 0]) for chunks in res.chunks])
    res = res.rechunk(new_chunks)
    return res


def _myinterpnd(var: da.Array,
                func: Callable[..., Any],
                kwargs: Any,
                nx: int,
                *arrs: da.Array) -> da.Array:
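    """Interpolate a single block with xarray's _interpnd.

    ``arrs`` packs the ``nx`` original coordinate arrays first, followed by
    the new coordinate arrays for this block.
    """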
    _old_x, _new_x = arrs[:nx], arrs[nx:]

    # reshape x (TODO REMOVE ?)
    old_x = tuple([np.moveaxis(tmp, dim, -1)[tuple([0] * (len(tmp.shape) - 1))]
                   for dim, tmp in enumerate(_old_x)])

    new_x = tuple([xr.DataArray(_x) for _x in _new_x])

    return m._interpnd(var, old_x, new_x, func, kwargs)


official_interp_func = m.interp_func
m.interp_func = interp_func
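
With the patch applied, interpolation along a chunked dimension should go through .interp() as usual. A minimal, illustrative sketch (the coordinate values, chunk sizes and method below are arbitrary, and this assumes an xarray version that otherwise raises NotImplementedError when interpolating along a chunked dimension):

import dask.array as da
import numpy as np
import xarray as xr

# a DataArray whose interpolated dimension is split across several chunks
x = np.linspace(0.0, 1.0, 100)
data = xr.DataArray(da.from_array(np.sin(2 * np.pi * x), chunks=25),
                    dims="x", coords={"x": x})

# new coordinates that cross the chunk boundaries
new_x = np.linspace(0.05, 0.95, 37)

# with the patch in place this no longer raises NotImplementedError
result = data.interp(x=new_x, method="linear")
print(result.compute())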
