
[FEATURE]: Add a replace method #6377

Open
@Huite

Description


Is your feature request related to a problem?

If I have a DataArray of values:

da = xr.DataArray([0, 1, 2, 3, 4, 5])

and I'd like to replace to_replace=[1, 3, 5] with value=[10, 30, 50], then there's no method da.replace(to_replace, value) to do this.

There's no easy way to do this like there is in pandas (https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.replace.html); the pandas call is shown below for reference.

(Apologies if I've missed related issues; searching for "replace" gives many hits, as the word is obviously used quite often.)
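
For reference, the equivalent operation in pandas, using list-like arguments to replace:

import pandas as pd

s = pd.Series([0, 1, 2, 3, 4, 5])
print(s.replace([1, 3, 5], [10, 30, 50]).tolist())
# [0, 10, 2, 30, 4, 50]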

Describe the solution you'd like

da = xr.DataArray([0, 1, 2, 3, 4, 5])
replaced = da.replace([1, 3, 5], [10, 30, 50])
print(replaced)
<xarray.DataArray (dim_0: 6)>
array([ 0, 10,  2, 30,  4, 50])
Dimensions without coordinates: dim_0
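
The closest workaround with existing xarray methods that I can see is a chain of where calls, which scales poorly with the number of replacement values (a sketch for comparison):

import xarray as xr

da = xr.DataArray([0, 1, 2, 3, 4, 5])
replaced = da.copy()
for t, v in zip([1, 3, 5], [10, 30, 50]):
    # Condition on the original da so earlier replacements don't cascade.
    replaced = replaced.where(da != t, other=v)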

I've had a try at a relatively efficient implementation below; I'm wondering whether it would be a worthwhile addition to xarray.

Describe alternatives you've considered

Ignoring issues such as dealing with NaNs, chunks, etc., a simple dict lookup:

def dict_replace(da, to_replace, value):
    # Map each value to its replacement; leave everything else untouched.
    d = dict(zip(to_replace, value))
    out = np.vectorize(lambda x: d.get(x, x))(da.values)
    return da.copy(data=out)

Alternatively, leveraging pandas:

def pandas_replace(da, to_replace, value):
    df = pd.DataFrame({"values": da.values.ravel()})
    # Avoid chained inplace replacement, which breaks under pandas copy-on-write.
    df["values"] = df["values"].replace(to_replace, value)
    return da.copy(data=df["values"].values.reshape(da.shape))

But I also tried my hand at a custom implementation, letting np.unique do the heavy lifting:

def custom_replace(da, to_replace, value):
    to_replace = np.asarray(to_replace)
    value = np.asarray(value)

    # Use np.unique to create an inverse index
    flat = da.values.ravel()
    uniques, index = np.unique(flat, return_inverse=True)
    replaceable = np.isin(flat, to_replace)

    # Remove to_replace values that are not present in da. If no overlap
    # exists between to_replace and the values in da, just return a copy.
    valid = np.isin(to_replace, uniques, assume_unique=True)
    if not valid.any():
        return da.copy()
    to_replace = to_replace[valid]
    value = value[valid]

    # Create a replacement array in which there is a 1:1 relation between
    # uniques and the replacement values, so that we can use the inverse index
    # to select replacement values.
    replacement = np.zeros_like(uniques)
    replacement[np.searchsorted(uniques, to_replace)] = value

    out = flat.copy()
    out[replaceable] = replacement[index[replaceable]]
    return da.copy(data=out.reshape(da.shape))
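
To make the bookkeeping concrete, here is what the intermediate arrays look like on a toy input (numbers chosen purely for illustration):

import numpy as np

flat = np.array([2, 0, 2, 1])
to_replace = np.array([1, 2])
value = np.array([10, 20])

uniques, index = np.unique(flat, return_inverse=True)
# uniques == [0, 1, 2]; index == [2, 0, 2, 1], so uniques[index] rebuilds flat.

replacement = np.zeros_like(uniques)
replacement[np.searchsorted(uniques, to_replace)] = value
# replacement == [0, 10, 20]: one slot per unique value, holding its replacement.

replaceable = np.isin(flat, to_replace)
flat[replaceable] = replacement[index[replaceable]]
# flat == [20, 0, 20, 10]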

Such an approach seems to be consistently the fastest:

da = xr.DataArray(np.random.randint(0, 100, 100_000))
to_replace = np.random.choice(np.arange(100), 10, replace=False)
value = to_replace * 200

test1 = custom_replace(da, to_replace, value)
test2 = pandas_replace(da, to_replace, value)
test3 = dict_replace(da, to_replace, value)

assert test1.equals(test2)
assert test1.equals(test3)

# 6.93 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit custom_replace(da, to_replace, value) 

# 9.37 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit pandas_replace(da, to_replace, value) 

# 26.8 ms ± 1.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit dict_replace(da, to_replace, value) 

With the advantage growing with the number of values involved:

da = xr.DataArray(np.random.randint(0, 10_000, 100_000))
to_replace = np.random.choice(np.arange(10_000), 10_000, replace=False)
value = to_replace * 200

test1 = custom_replace(da, to_replace, value)
test2 = pandas_replace(da, to_replace, value)
test3 = dict_replace(da, to_replace, value)

assert test1.equals(test2)
assert test1.equals(test3)


# 21.6 ms ± 990 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit custom_replace(da, to_replace, value)
# 3.12 s ± 574 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit pandas_replace(da, to_replace, value)
# 42.7 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit dict_replace(da, to_replace, value)

In my real-life example, a DataArray of approx. 110 000 elements with 60 000 values to replace, the custom implementation takes 33 ms and the dict one 135 ms, while pandas takes 26 s (!).

Additional context

In all cases, we need to deal with NaNs, validate the input, etc.:

def replace(da: xr.DataArray, to_replace: Any, value: Any):
    from xarray.core.utils import is_scalar

    if is_scalar(to_replace):
        if not is_scalar(value):
            raise TypeError("if to_replace is scalar, then value must be a scalar")
        if np.isnan(to_replace):
            return da.fillna(value) 
        else:
            return da.where(da != to_replace, other=value)
    else:
        to_replace = np.asarray(to_replace)
        if to_replace.ndim != 1:
            raise ValueError("to_replace must be 1D or scalar")
        if is_scalar(value):
            value = np.full_like(to_replace, value)
        else:
            value = np.asarray(value)
            if to_replace.shape != value.shape:
                raise ValueError(
                    f"Replacement arrays must match in shape. "
                    f"Expecting {to_replace.shape} got {value.shape} "
                )
    
    _, counts = np.unique(to_replace, return_counts=True)
    if (counts > 1).any():
        raise ValueError("to_replace contains duplicates")
    
    # Replace NaN values separately, as they will show up as separate values
    # from numpy.unique.
    isnan = np.isnan(to_replace)
    if isnan.any():
        # Duplicates were rejected above, so at most one entry is NaN.
        da = da.fillna(value[isnan][0])
        to_replace = to_replace[~isnan]
        value = value[~isnan]

    # Use np.unique to create an inverse index
    flat = da.values.ravel()
    uniques, index = np.unique(flat, return_inverse=True)    
    replaceable = np.isin(flat, to_replace)

    # Create a replacement array in which there is a 1:1 relation between
    # uniques and the replacement values, so that we can use the inverse index
    # to select replacement values. 
    valid = np.isin(to_replace, uniques, assume_unique=True)
    # Remove to_replace values that are not present in da. If no overlap
    # exists between to_replace and the values in da, just return a copy.
    if not valid.any():
        return da.copy()
    to_replace = to_replace[valid]
    value = value[valid]

    replacement = np.zeros_like(uniques)
    replacement[np.searchsorted(uniques, to_replace)] = value

    out = flat.copy()
    out[replaceable] = replacement[index[replaceable]]
    return da.copy(data=out.reshape(da.shape))

I think it should be easy to make this work with e.g. apply_ufunc by letting the core logic operate on plain numpy arrays (a sketch of the wiring follows below).
The primary issue is whether the values can be sorted, since np.unique relies on sorting; when they cannot, the dict lookup might be an okay fallback?
I've had a peek at the pandas implementation, but didn't become much wiser.
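
For instance, a minimal sketch of that wiring, assuming the numpy logic is factored into a helper (_replace_values is a hypothetical name, not an existing xarray function):

import numpy as np
import xarray as xr

def _replace_values(values, to_replace, value):
    # Pure-numpy version of custom_replace above, operating on a bare array.
    flat = values.ravel()
    uniques, index = np.unique(flat, return_inverse=True)
    replaceable = np.isin(flat, to_replace)
    valid = np.isin(to_replace, uniques, assume_unique=True)
    if not valid.any():
        return values.copy()
    replacement = np.zeros_like(uniques)
    replacement[np.searchsorted(uniques, to_replace[valid])] = value[valid]
    out = flat.copy()
    out[replaceable] = replacement[index[replaceable]]
    return out.reshape(values.shape)

def replace(da, to_replace, value):
    return xr.apply_ufunc(
        _replace_values,
        da,
        kwargs={
            "to_replace": np.asarray(to_replace),
            "value": np.asarray(value),
        },
        dask="parallelized",
        output_dtypes=[da.dtype],
    )

With dask="parallelized", each chunk is processed independently, which should be fine here because the replacement is elementwise.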

Anyway, for your consideration! I'd be happy to submit a PR.
