Skip to content

Commit e2ad69e

Browse files
committed
update weighted.py
1 parent c646568 commit e2ad69e

File tree

1 file changed

+36
-55
lines changed

1 file changed

+36
-55
lines changed

xarray/core/weighted.py

Lines changed: 36 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload
2+
13
from .computation import dot
2-
from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Tuple, Union, overload
34

45
if TYPE_CHECKING:
56
from .dataarray import DataArray, Dataset
@@ -50,24 +51,6 @@
5051
"""
5152

5253

53-
def _maybe_get_all_dims(
54-
dims: Optional[Union[Hashable, Iterable[Hashable]]],
55-
dims1: Tuple[Hashable, ...],
56-
dims2: Tuple[Hashable, ...],
57-
):
58-
""" the union of dims1 and dims2 if dims is None
59-
60-
`dims=None` behaves differently in `dot` and `sum`, so we have to apply
61-
`dot` over the union of the dimensions
62-
63-
"""
64-
65-
if dims is None:
66-
dims = tuple(sorted(set(dims1) | set(dims2)))
67-
68-
return dims
69-
70-
7154
class Weighted:
7255
"""A object that implements weighted operations.
7356
@@ -86,13 +69,11 @@ class Weighted:
8669
def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
8770
...
8871

89-
@overload # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated
72+
@overload # noqa: F811
9073
def __init__(self, obj: "Dataset", weights: "DataArray") -> None:
9174
...
9275

93-
def __init__( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated
94-
self, obj, weights
95-
):
76+
def __init__(self, obj, weights): # noqa: F811
9677
"""
9778
Create a Weighted object
9879
@@ -107,7 +88,7 @@ def __init__( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updat
10788
10889
Note
10990
----
110-
Missing values in the weights are replaced with 0. (i.e. no weight).
91+
Weights can not contain missing values.
11192
11293
"""
11394

@@ -117,21 +98,28 @@ def __init__( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updat
11798
assert isinstance(weights, DataArray), msg
11899

119100
self.obj = obj
120-
self.weights = weights.fillna(0)
101+
102+
if weights.isnull().any():
103+
raise ValueError("`weights` cannot contain missing values.")
104+
105+
self.weights = weights
121106

122107
def _sum_of_weights(
123-
self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None
108+
self,
109+
da: "DataArray",
110+
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
124111
) -> "DataArray":
125112
""" Calculate the sum of weights, accounting for missing values """
126113

127-
# we need to mask DATA values that are nan; else the weights are wrong
128-
mask = ~da.isnull()
114+
# we need to mask data values that are nan; else the weights are wrong
115+
mask = da.notnull()
129116

130117
# need to infer dims as we use `dot`
131-
dims = _maybe_get_all_dims(dim, da.dims, self.weights.dims)
118+
if dim is None:
119+
dim = ...
132120

133121
# use `dot` to avoid creating large DataArrays (if da and weights do not share all dims)
134-
sum_of_weights = dot(mask, self.weights, dims=dims)
122+
sum_of_weights = dot(mask, self.weights, dims=dim)
135123

136124
# find all weights that are valid (not 0)
137125
valid_weights = sum_of_weights != 0.0
@@ -148,15 +136,16 @@ def _weighted_sum(
148136
"""Reduce a DataArray by a by a weighted `sum` along some dimension(s)."""
149137

150138
# need to infer dims as we use `dot`
151-
dims = _maybe_get_all_dims(dim, da.dims, self.weights.dims)
139+
if dim is None:
140+
dim = ...
152141

153142
# use `dot` to avoid creating large DataArrays
154143

155144
# need to mask invalid DATA as dot does not implement skipna
156-
if skipna or skipna is None:
157-
return dot(da.fillna(0.0), self.weights, dims=dims)
145+
if skipna or (skipna is None and da.dtype.kind in "cfO"):
146+
return dot(da.fillna(0.0), self.weights, dims=dim)
158147

159-
return dot(da, self.weights, dims=dims)
148+
return dot(da, self.weights, dims=dim)
160149

161150
def _weighted_mean(
162151
self,
@@ -207,7 +196,7 @@ def __repr__(self):
207196

208197
msg = "{klass} with weights along dimensions: {weight_dims}"
209198
return msg.format(
210-
klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims)
199+
klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims),
211200
)
212201

213202

@@ -217,18 +206,6 @@ def _implementation(self, func, **kwargs):
217206
return func(self.obj, **kwargs)
218207

219208

220-
# add docstrings
221-
DataArrayWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(
222-
cls="DataArray"
223-
)
224-
DataArrayWeighted.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
225-
cls="DataArray", fcn="mean"
226-
)
227-
DataArrayWeighted.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
228-
cls="DataArray", fcn="sum"
229-
)
230-
231-
232209
class DatasetWeighted(Weighted):
233210
def _implementation(self, func, **kwargs) -> "Dataset":
234211

@@ -242,11 +219,15 @@ def _implementation(self, func, **kwargs) -> "Dataset":
242219
return Dataset(weighted, coords=self.obj.coords)
243220

244221

245-
# add docstring
246-
DatasetWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls="Dataset")
247-
DatasetWeighted.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
248-
cls="Dataset", fcn="mean"
249-
)
250-
DatasetWeighted.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
251-
cls="Dataset", fcn="sum"
252-
)
222+
def _inject_docstring(cls, cls_name):
223+
224+
cls.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls=cls_name)
225+
226+
for operator in ["sum", "mean"]:
227+
getattr(cls, operator).__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
228+
cls=cls_name, fcn=operator
229+
)
230+
231+
232+
_inject_docstring(DataArrayWeighted, "DataArray")
233+
_inject_docstring(DatasetWeighted, "Dataset")

0 commit comments

Comments
 (0)