Skip to content

Commit a36d0a1

Browse files
keewis and shoyer authored
per-variable fill values (#4237)
* implement the fill_value mapping
* get per-variable fill_values to work in DataArray.reindex
* Update xarray/core/dataarray.py

  Co-authored-by: Stephan Hoyer <[email protected]>
* check that the default value is used
* check that merge works with multiple fill values
* check that concat works with multiple fill values
* check that combine_nested works with multiple fill values
* check that Dataset.reindex and DataArray.reindex work
* check that aligning Datasets works
* check that Dataset.unstack works
* allow passing multiple fill values to full_like with datasets
* also allow overriding the dtype by variable
* document the dict fill values in Dataset.reindex
* document the changes to DataArray.reindex
* document the changes to unstack
* document the changes to align
* document the changes to concat and merge
* document the changes to Dataset.shift
* document the changes to combine_*

Co-authored-by: Stephan Hoyer <[email protected]>
1 parent 1a11d24 commit a36d0a1

File tree

12 files changed

+284
-81
lines changed

12 files changed

+284
-81
lines changed

xarray/core/alignment.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ def align(
103103
used in preference to the aligned indexes.
104104
exclude : sequence of str, optional
105105
Dimensions that must be excluded from alignment
106-
fill_value : scalar, optional
107-
Value to use for newly missing values
106+
fill_value : scalar or dict-like, optional
107+
Value to use for newly missing values. If a dict-like, maps
108+
variable names to fill values. Use a data array's name to
109+
refer to its values.
108110
109111
Returns
110112
-------
@@ -581,16 +583,21 @@ def reindex_variables(
581583

582584
for name, var in variables.items():
583585
if name not in indexers:
586+
if isinstance(fill_value, dict):
587+
fill_value_ = fill_value.get(name, dtypes.NA)
588+
else:
589+
fill_value_ = fill_value
590+
584591
if sparse:
585-
var = var._as_sparse(fill_value=fill_value)
592+
var = var._as_sparse(fill_value=fill_value_)
586593
key = tuple(
587594
slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None))
588595
for d in var.dims
589596
)
590597
needs_masking = any(d in masked_dims for d in var.dims)
591598

592599
if needs_masking:
593-
new_var = var._getitem_with_mask(key, fill_value=fill_value)
600+
new_var = var._getitem_with_mask(key, fill_value=fill_value_)
594601
elif all(is_full_slice(k) for k in key):
595602
# no reindexing necessary
596603
# here we need to manually deal with copying data, since

xarray/core/combine.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,10 @@ def combine_nested(
393393
Details are in the documentation of concat
394394
coords : {"minimal", "different", "all"} or list of str, optional
395395
Details are in the documentation of concat
396-
fill_value : scalar, optional
397-
Value to use for newly missing values
396+
fill_value : scalar or dict-like, optional
397+
Value to use for newly missing values. If a dict-like, maps
398+
variable names to fill values. Use a data array's name to
399+
refer to its values.
398400
join : {"outer", "inner", "left", "right", "exact"}, optional
399401
String indicating how to combine differing indexes
400402
(excluding concat_dim) in objects
@@ -569,10 +571,12 @@ def combine_by_coords(
569571
addition to the "minimal" data variables.
570572
571573
If objects are DataArrays, `data_vars` must be "all".
572-
coords : {"minimal", "different", "all" or list of str}, optional
574+
coords : {"minimal", "different", "all"} or list of str, optional
573575
As per the "data_vars" kwarg, but for coordinate variables.
574-
fill_value : scalar, optional
575-
Value to use for newly missing values. If None, raises a ValueError if
576+
fill_value : scalar or dict-like, optional
577+
Value to use for newly missing values. If a dict-like, maps
578+
variable names to fill values. Use a data array's name to
579+
refer to its values. If None, raises a ValueError if
576580
the passed Datasets do not create a complete hypercube.
577581
join : {"outer", "inner", "left", "right", "exact"}, optional
578582
String indicating how to combine differing indexes

xarray/core/common.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,10 +1364,13 @@ def full_like(other, fill_value, dtype: DTypeLike = None):
13641364
----------
13651365
other : DataArray, Dataset or Variable
13661366
The reference object in input
1367-
fill_value : scalar
1368-
Value to fill the new object with before returning it.
1369-
dtype : dtype, optional
1370-
dtype of the new array. If omitted, it defaults to other.dtype.
1367+
fill_value : scalar or dict-like
1368+
Value to fill the new object with before returning it. If
1369+
other is a Dataset, may also be a dict-like mapping data
1370+
variables to fill values.
1371+
dtype : dtype or dict-like of dtype, optional
1372+
dtype of the new array. If a dict-like, maps variable names to
1373+
dtypes. If omitted, it defaults to other.dtype.
13711374
13721375
Returns
13731376
-------
@@ -1427,6 +1430,34 @@ def full_like(other, fill_value, dtype: DTypeLike = None):
14271430
* lat (lat) int64 1 2
14281431
* lon (lon) int64 0 1 2
14291432
1433+
>>> ds = xr.Dataset(
1434+
... {"a": ("x", [3, 5, 2]), "b": ("x", [9, 1, 0])}, coords={"x": [2, 4, 6]}
1435+
... )
1436+
>>> ds
1437+
<xarray.Dataset>
1438+
Dimensions: (x: 3)
1439+
Coordinates:
1440+
* x (x) int64 2 4 6
1441+
Data variables:
1442+
a (x) int64 3 5 2
1443+
b (x) int64 9 1 0
1444+
>>> xr.full_like(ds, fill_value={"a": 1, "b": 2})
1445+
<xarray.Dataset>
1446+
Dimensions: (x: 3)
1447+
Coordinates:
1448+
* x (x) int64 2 4 6
1449+
Data variables:
1450+
a (x) int64 1 1 1
1451+
b (x) int64 2 2 2
1452+
>>> xr.full_like(ds, fill_value={"a": 1, "b": 2}, dtype={"a": bool, "b": float})
1453+
<xarray.Dataset>
1454+
Dimensions: (x: 3)
1455+
Coordinates:
1456+
* x (x) int64 2 4 6
1457+
Data variables:
1458+
a (x) bool True True True
1459+
b (x) float64 2.0 2.0 2.0
1460+
14301461
See also
14311462
--------
14321463
@@ -1438,12 +1469,22 @@ def full_like(other, fill_value, dtype: DTypeLike = None):
14381469
from .dataset import Dataset
14391470
from .variable import Variable
14401471

1441-
if not is_scalar(fill_value):
1442-
raise ValueError(f"fill_value must be scalar. Received {fill_value} instead.")
1472+
if not is_scalar(fill_value) and not (
1473+
isinstance(other, Dataset) and isinstance(fill_value, dict)
1474+
):
1475+
raise ValueError(
1476+
f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead."
1477+
)
14431478

14441479
if isinstance(other, Dataset):
1480+
if not isinstance(fill_value, dict):
1481+
fill_value = {k: fill_value for k in other.data_vars.keys()}
1482+
1483+
if not isinstance(dtype, dict):
1484+
dtype = {k: dtype for k in other.data_vars.keys()}
1485+
14451486
data_vars = {
1446-
k: _full_like_variable(v, fill_value, dtype)
1487+
k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype.get(k, None))
14471488
for k, v in other.data_vars.items()
14481489
}
14491490
return Dataset(data_vars, coords=other.coords, attrs=other.attrs)
@@ -1466,6 +1507,9 @@ def _full_like_variable(other, fill_value, dtype: DTypeLike = None):
14661507
"""
14671508
from .variable import Variable
14681509

1510+
if fill_value is dtypes.NA:
1511+
fill_value = dtypes.get_fill_value(dtype if dtype is not None else other.dtype)
1512+
14691513
if isinstance(other.data, dask_array_type):
14701514
import dask.array
14711515

xarray/core/concat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,10 @@ def concat(
125125
List of integer arrays which specifies the integer positions to which
126126
to assign each dataset along the concatenated dimension. If not
127127
supplied, objects are concatenated in the provided order.
128-
fill_value : scalar, optional
129-
Value to use for newly missing values
128+
fill_value : scalar or dict-like, optional
129+
Value to use for newly missing values. If a dict-like, maps
130+
variable names to fill values. Use a data array's name to
131+
refer to its values.
130132
join : {"outer", "inner", "left", "right", "exact"}, optional
131133
String indicating how to combine differing indexes
132134
(excluding dim) in objects

xarray/core/dataarray.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,8 +1308,10 @@ def reindex_like(
13081308
``copy=False`` and reindexing is unnecessary, or can be performed
13091309
with only slice operations, then the output may share memory with
13101310
the input. In either case, a new xarray object is always returned.
1311-
fill_value : scalar, optional
1312-
Value to use for newly missing values
1311+
fill_value : scalar or dict-like, optional
1312+
Value to use for newly missing values. If a dict-like, maps
1313+
variable names (including coordinates) to fill values. Use this
1314+
data array's name to refer to the data array's values.
13131315
13141316
Returns
13151317
-------
@@ -1368,8 +1370,10 @@ def reindex(
13681370
Maximum distance between original and new labels for inexact
13691371
matches. The values of the index at the matching locations must
13701372
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
1371-
fill_value : scalar, optional
1372-
Value to use for newly missing values
1373+
fill_value : scalar or dict-like, optional
1374+
Value to use for newly missing values. If a dict-like, maps
1375+
variable names (including coordinates) to fill values. Use this
1376+
data array's name to refer to the data array's values.
13731377
**indexers_kwargs : {dim: indexer, ...}, optional
13741378
The keyword arguments form of ``indexers``.
13751379
One of indexers or indexers_kwargs must be provided.
@@ -1386,6 +1390,13 @@ def reindex(
13861390
align
13871391
"""
13881392
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")
1393+
if isinstance(fill_value, dict):
1394+
fill_value = fill_value.copy()
1395+
sentinel = object()
1396+
value = fill_value.pop(self.name, sentinel)
1397+
if value is not sentinel:
1398+
fill_value[_THIS_ARRAY] = value
1399+
13891400
ds = self._to_temp_dataset().reindex(
13901401
indexers=indexers,
13911402
method=method,
@@ -1867,8 +1878,11 @@ def unstack(
18671878
dim : hashable or sequence of hashable, optional
18681879
Dimension(s) over which to unstack. By default unstacks all
18691880
MultiIndexes.
1870-
fill_value : scalar, default: nan
1871-
value to be filled.
1881+
fill_value : scalar or dict-like, default: nan
1882+
value to be filled. If a dict-like, maps variable names to
1883+
fill values. Use the data array's name to refer to its
1884+
name. If not provided or if the dict-like does not contain
1885+
all variables, the dtype's NA value will be used.
18721886
sparse : bool, default: False
18731887
use sparse-array if True
18741888

xarray/core/dataset.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2313,8 +2313,9 @@ def reindex_like(
23132313
``copy=False`` and reindexing is unnecessary, or can be performed
23142314
with only slice operations, then the output may share memory with
23152315
the input. In either case, a new xarray object is always returned.
2316-
fill_value : scalar, optional
2317-
Value to use for newly missing values
2316+
fill_value : scalar or dict-like, optional
2317+
Value to use for newly missing values. If a dict-like, maps
2318+
variable names to fill values.
23182319
23192320
Returns
23202321
-------
@@ -2373,8 +2374,9 @@ def reindex(
23732374
``copy=False`` and reindexing is unnecessary, or can be performed
23742375
with only slice operations, then the output may share memory with
23752376
the input. In either case, a new xarray object is always returned.
2376-
fill_value : scalar, optional
2377-
Value to use for newly missing values
2377+
fill_value : scalar or dict-like, optional
2378+
Value to use for newly missing values. If a dict-like,
2379+
maps variable names (including coordinates) to fill values.
23782380
sparse : bool, default: False
23792381
use sparse-array.
23802382
**indexers_kwargs : {dim: indexer, ...}, optional
@@ -2441,6 +2443,19 @@ def reindex(
24412443
temperature (station) float64 18.84 0.0 19.22 0.0
24422444
pressure (station) float64 324.1 0.0 122.8 0.0
24432445
2446+
We can also use different fill values for each variable.
2447+
2448+
>>> x.reindex(
2449+
... {"station": new_index}, fill_value={"temperature": 0, "pressure": 100}
2450+
... )
2451+
<xarray.Dataset>
2452+
Dimensions: (station: 4)
2453+
Coordinates:
2454+
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
2455+
Data variables:
2456+
temperature (station) float64 18.84 0.0 19.22 0.0
2457+
pressure (station) float64 324.1 100.0 122.8 100.0
2458+
24442459
Because the index is not monotonically increasing or decreasing, we cannot use arguments
24452460
to the keyword method to fill the `NaN` values.
24462461
@@ -3544,8 +3559,10 @@ def unstack(
35443559
dim : hashable or iterable of hashable, optional
35453560
Dimension(s) over which to unstack. By default unstacks all
35463561
MultiIndexes.
3547-
fill_value : scalar, default: nan
3548-
value to be filled
3562+
fill_value : scalar or dict-like, default: nan
3563+
value to be filled. If a dict-like, maps variable names to
3564+
fill values. If not provided or if the dict-like does not
3565+
contain all variables, the dtype's NA value will be used.
35493566
sparse : bool, default: False
35503567
use sparse-array if True
35513568
@@ -3663,8 +3680,9 @@ def merge(
36633680
- 'left': use indexes from ``self``
36643681
- 'right': use indexes from ``other``
36653682
- 'exact': error instead of aligning non-equal indexes
3666-
fill_value : scalar, optional
3667-
Value to use for newly missing values
3683+
fill_value : scalar or dict-like, optional
3684+
Value to use for newly missing values. If a dict-like, maps
3685+
variable names (including coordinates) to fill values.
36683686
36693687
Returns
36703688
-------
@@ -5117,8 +5135,9 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
51175135
Integer offset to shift along each of the given dimensions.
51185136
Positive offsets shift to the right; negative offsets shift to the
51195137
left.
5120-
fill_value : scalar, optional
5121-
Value to use for newly missing values
5138+
fill_value : scalar or dict-like, optional
5139+
Value to use for newly missing values. If a dict-like, maps
5140+
data variable names to fill values; coordinates are not shifted.
51225141
**shifts_kwargs
51235142
The keyword arguments form of ``shifts``.
51245143
One of shifts or shifts_kwargs must be provided.
@@ -5153,8 +5172,14 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
51535172
variables = {}
51545173
for name, var in self.variables.items():
51555174
if name in self.data_vars:
5175+
fill_value_ = (
5176+
fill_value.get(name, dtypes.NA)
5177+
if isinstance(fill_value, dict)
5178+
else fill_value
5179+
)
5180+
51565181
var_shifts = {k: v for k, v in shifts.items() if k in var.dims}
5157-
variables[name] = var.shift(fill_value=fill_value, shifts=var_shifts)
5182+
variables[name] = var.shift(fill_value=fill_value_, shifts=var_shifts)
51585183
else:
51595184
variables[name] = var
51605185

xarray/core/merge.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -666,8 +666,10 @@ def merge(
666666
- "override": if indexes are of same size, rewrite indexes to be
667667
those of the first object with that dimension. Indexes for the same
668668
dimension must have the same size in all objects.
669-
fill_value : scalar, optional
670-
Value to use for newly missing values
669+
fill_value : scalar or dict-like, optional
670+
Value to use for newly missing values. If a dict-like, maps
671+
variable names to fill values. Use a data array's name to
672+
refer to its values.
671673
combine_attrs : {"drop", "identical", "no_conflicts", "override"}, \
672674
default: "drop"
673675
String indicating how to combine attrs of the objects being merged:

xarray/tests/test_combine.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -601,18 +601,26 @@ def test_combine_concat_over_redundant_nesting(self):
601601
expected = Dataset({"x": [0]})
602602
assert_identical(expected, actual)
603603

604-
@pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0])
604+
@pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"a": 2, "b": 1}])
605605
def test_combine_nested_fill_value(self, fill_value):
606606
datasets = [
607-
Dataset({"a": ("x", [2, 3]), "x": [1, 2]}),
608-
Dataset({"a": ("x", [1, 2]), "x": [0, 1]}),
607+
Dataset({"a": ("x", [2, 3]), "b": ("x", [-2, 1]), "x": [1, 2]}),
608+
Dataset({"a": ("x", [1, 2]), "b": ("x", [3, -1]), "x": [0, 1]}),
609609
]
610610
if fill_value == dtypes.NA:
611611
# if we supply the default, we expect the missing value for a
612612
# float array
613-
fill_value = np.nan
613+
fill_value_a = fill_value_b = np.nan
614+
elif isinstance(fill_value, dict):
615+
fill_value_a = fill_value["a"]
616+
fill_value_b = fill_value["b"]
617+
else:
618+
fill_value_a = fill_value_b = fill_value
614619
expected = Dataset(
615-
{"a": (("t", "x"), [[fill_value, 2, 3], [1, 2, fill_value]])},
620+
{
621+
"a": (("t", "x"), [[fill_value_a, 2, 3], [1, 2, fill_value_a]]),
622+
"b": (("t", "x"), [[fill_value_b, -2, 1], [3, -1, fill_value_b]]),
623+
},
616624
{"x": [0, 1, 2]},
617625
)
618626
actual = combine_nested(datasets, concat_dim="t", fill_value=fill_value)

xarray/tests/test_concat.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -349,18 +349,26 @@ def test_concat_multiindex(self):
349349
assert expected.equals(actual)
350350
assert isinstance(actual.x.to_index(), pd.MultiIndex)
351351

352-
@pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0])
352+
@pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"a": 2, "b": 1}])
353353
def test_concat_fill_value(self, fill_value):
354354
datasets = [
355-
Dataset({"a": ("x", [2, 3]), "x": [1, 2]}),
356-
Dataset({"a": ("x", [1, 2]), "x": [0, 1]}),
355+
Dataset({"a": ("x", [2, 3]), "b": ("x", [-2, 1]), "x": [1, 2]}),
356+
Dataset({"a": ("x", [1, 2]), "b": ("x", [3, -1]), "x": [0, 1]}),
357357
]
358358
if fill_value == dtypes.NA:
359359
# if we supply the default, we expect the missing value for a
360360
# float array
361-
fill_value = np.nan
361+
fill_value_a = fill_value_b = np.nan
362+
elif isinstance(fill_value, dict):
363+
fill_value_a = fill_value["a"]
364+
fill_value_b = fill_value["b"]
365+
else:
366+
fill_value_a = fill_value_b = fill_value
362367
expected = Dataset(
363-
{"a": (("t", "x"), [[fill_value, 2, 3], [1, 2, fill_value]])},
368+
{
369+
"a": (("t", "x"), [[fill_value_a, 2, 3], [1, 2, fill_value_a]]),
370+
"b": (("t", "x"), [[fill_value_b, -2, 1], [3, -1, fill_value_b]]),
371+
},
364372
{"x": [0, 1, 2]},
365373
)
366374
actual = concat(datasets, dim="t", fill_value=fill_value)

0 commit comments

Comments
 (0)