1
+ from typing import TYPE_CHECKING , Hashable , Iterable , Optional , Union , overload
2
+
1
3
from .computation import dot
2
- from typing import TYPE_CHECKING , Hashable , Iterable , Optional , Tuple , Union , overload
3
4
4
5
if TYPE_CHECKING :
5
6
from .dataarray import DataArray , Dataset
50
51
"""
51
52
52
53
53
- def _maybe_get_all_dims (
54
- dims : Optional [Union [Hashable , Iterable [Hashable ]]],
55
- dims1 : Tuple [Hashable , ...],
56
- dims2 : Tuple [Hashable , ...],
57
- ):
58
- """ the union of dims1 and dims2 if dims is None
59
-
60
- `dims=None` behaves differently in `dot` and `sum`, so we have to apply
61
- `dot` over the union of the dimensions
62
-
63
- """
64
-
65
- if dims is None :
66
- dims = tuple (sorted (set (dims1 ) | set (dims2 )))
67
-
68
- return dims
69
-
70
-
71
54
class Weighted :
72
55
"""A object that implements weighted operations.
73
56
@@ -86,13 +69,11 @@ class Weighted:
86
69
def __init__ (self , obj : "DataArray" , weights : "DataArray" ) -> None :
87
70
...
88
71
89
- @overload # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated
72
+ @overload # noqa: F811
90
73
def __init__ (self , obj : "Dataset" , weights : "DataArray" ) -> None :
91
74
...
92
75
93
- def __init__ ( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated
94
- self , obj , weights
95
- ):
76
+ def __init__ (self , obj , weights ): # noqa: F811
96
77
"""
97
78
Create a Weighted object
98
79
@@ -107,7 +88,7 @@ def __init__( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updat
107
88
108
89
Note
109
90
----
110
- Missing values in the weights are replaced with 0. (i.e. no weight) .
91
+ Weights cannot contain missing values.
111
92
112
93
"""
113
94
@@ -117,21 +98,28 @@ def __init__( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updat
117
98
assert isinstance (weights , DataArray ), msg
118
99
119
100
self .obj = obj
120
- self .weights = weights .fillna (0 )
101
+
102
+ if weights .isnull ().any ():
103
+ raise ValueError ("`weights` cannot contain missing values." )
104
+
105
+ self .weights = weights
121
106
122
107
def _sum_of_weights (
123
- self , da : "DataArray" , dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None
108
+ self ,
109
+ da : "DataArray" ,
110
+ dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
124
111
) -> "DataArray" :
125
112
""" Calculate the sum of weights, accounting for missing values """
126
113
127
- # we need to mask DATA values that are nan; else the weights are wrong
128
- mask = ~ da .isnull ()
114
+ # we need to mask data values that are nan; else the weights are wrong
115
+ mask = da .notnull ()
129
116
130
117
# need to infer dims as we use `dot`
131
- dims = _maybe_get_all_dims (dim , da .dims , self .weights .dims )
118
+ if dim is None :
119
+ dim = ...
132
120
133
121
# use `dot` to avoid creating large DataArrays (if da and weights do not share all dims)
134
- sum_of_weights = dot (mask , self .weights , dims = dims )
122
+ sum_of_weights = dot (mask , self .weights , dims = dim )
135
123
136
124
# find all weights that are valid (not 0)
137
125
valid_weights = sum_of_weights != 0.0
@@ -148,15 +136,16 @@ def _weighted_sum(
148
136
"""Reduce a DataArray by a by a weighted `sum` along some dimension(s)."""
149
137
150
138
# need to infer dims as we use `dot`
151
- dims = _maybe_get_all_dims (dim , da .dims , self .weights .dims )
139
+ if dim is None :
140
+ dim = ...
152
141
153
142
# use `dot` to avoid creating large DataArrays
154
143
155
144
# need to mask invalid DATA as dot does not implement skipna
156
- if skipna or skipna is None :
157
- return dot (da .fillna (0.0 ), self .weights , dims = dims )
145
+ if skipna or ( skipna is None and da . dtype . kind in "cfO" ) :
146
+ return dot (da .fillna (0.0 ), self .weights , dims = dim )
158
147
159
- return dot (da , self .weights , dims = dims )
148
+ return dot (da , self .weights , dims = dim )
160
149
161
150
def _weighted_mean (
162
151
self ,
@@ -207,7 +196,7 @@ def __repr__(self):
207
196
208
197
msg = "{klass} with weights along dimensions: {weight_dims}"
209
198
return msg .format (
210
- klass = self .__class__ .__name__ , weight_dims = ", " .join (self .weights .dims )
199
+ klass = self .__class__ .__name__ , weight_dims = ", " .join (self .weights .dims ),
211
200
)
212
201
213
202
@@ -217,18 +206,6 @@ def _implementation(self, func, **kwargs):
217
206
return func (self .obj , ** kwargs )
218
207
219
208
220
- # add docstrings
221
- DataArrayWeighted .sum_of_weights .__doc__ = _SUM_OF_WEIGHTS_DOCSTRING .format (
222
- cls = "DataArray"
223
- )
224
- DataArrayWeighted .mean .__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE .format (
225
- cls = "DataArray" , fcn = "mean"
226
- )
227
- DataArrayWeighted .sum .__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE .format (
228
- cls = "DataArray" , fcn = "sum"
229
- )
230
-
231
-
232
209
class DatasetWeighted (Weighted ):
233
210
def _implementation (self , func , ** kwargs ) -> "Dataset" :
234
211
@@ -242,11 +219,15 @@ def _implementation(self, func, **kwargs) -> "Dataset":
242
219
return Dataset (weighted , coords = self .obj .coords )
243
220
244
221
245
- # add docstring
246
- DatasetWeighted .sum_of_weights .__doc__ = _SUM_OF_WEIGHTS_DOCSTRING .format (cls = "Dataset" )
247
- DatasetWeighted .mean .__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE .format (
248
- cls = "Dataset" , fcn = "mean"
249
- )
250
- DatasetWeighted .sum .__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE .format (
251
- cls = "Dataset" , fcn = "sum"
252
- )
222
+ def _inject_docstring (cls , cls_name ):
223
+
224
+ cls .sum_of_weights .__doc__ = _SUM_OF_WEIGHTS_DOCSTRING .format (cls = cls_name )
225
+
226
+ for operator in ["sum" , "mean" ]:
227
+ getattr (cls , operator ).__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE .format (
228
+ cls = cls_name , fcn = operator
229
+ )
230
+
231
+
232
+ _inject_docstring (DataArrayWeighted , "DataArray" )
233
+ _inject_docstring (DatasetWeighted , "Dataset" )
0 commit comments