@@ -35,7 +35,7 @@ def make_meta(obj):
35
35
from dask .array .utils import meta_from_array
36
36
37
37
if isinstance (obj , DataArray ):
38
- meta = DataArray (obj .data ._meta , dims = obj .dims )
38
+ meta = DataArray (obj .data ._meta , dims = obj .dims , name = obj . name )
39
39
40
40
if isinstance (obj , Dataset ):
41
41
meta = Dataset ()
@@ -45,9 +45,14 @@ def make_meta(obj):
45
45
else :
46
46
meta_obj = meta_from_array (obj [name ].data )
47
47
meta [name ] = DataArray (meta_obj , dims = obj [name ].dims )
48
+ # meta[name] = DataArray(obj[name].dims, meta_obj)
48
49
else :
49
50
meta = obj
50
51
52
+ # TODO: deal with non-dim coords
53
+ # for coord_name in (set(obj.coords) - set(obj.dims)): # DataArrays should have _coord_names!
54
+ # coord = obj[coord_name]
55
+
51
56
return meta
52
57
53
58
@@ -65,7 +70,7 @@ def infer_template(func, obj, *args, **kwargs):
65
70
return template
66
71
67
72
68
- def _make_dict (x ):
73
+ def make_dict (x ):
69
74
# Dataset.to_dict() is too complicated
70
75
# maps variable name to numpy array
71
76
if isinstance (x , DataArray ):
@@ -93,6 +98,9 @@ def map_blocks(func, obj, *args, **kwargs):
93
98
properties of the returned object such as dtype, variable names,
94
99
new dimensions and new indexes (if any).
95
100
101
+ This function must
102
+ - return either a DataArray or a Dataset
103
+
96
104
This function cannot
97
105
- change size of existing dimensions.
98
106
- add new chunked dimensions.
@@ -101,18 +109,24 @@ def map_blocks(func, obj, *args, **kwargs):
101
109
Chunks of this object will be provided to 'func'. The function must not change
102
110
shape of the provided DataArray.
103
111
args:
104
- Passed on to func.
112
+ Passed on to func. Cannot include chunked xarray objects.
105
113
kwargs:
106
- Passed on to func.
114
+ Passed on to func. Cannot include chunked xarray objects.
107
115
108
116
109
117
Returns
110
118
-------
111
119
DataArray or Dataset
112
120
121
+ Notes
122
+ -----
123
+
124
+ This function is designed to work with dask-backed xarray objects. See apply_ufunc for
125
+ a similar function that works with numpy arrays.
126
+
113
127
See Also
114
128
--------
115
- dask.array.map_blocks
129
+ dask.array.map_blocks, xarray.apply_ufunc
116
130
"""
117
131
118
132
def _wrapper (func , obj , to_array , args , kwargs ):
@@ -129,7 +143,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
129
143
% name
130
144
)
131
145
132
- to_return = _make_dict (result )
146
+ to_return = make_dict (result )
133
147
134
148
return to_return
135
149
@@ -149,26 +163,30 @@ def _wrapper(func, obj, to_array, args, kwargs):
149
163
if isinstance (template , DataArray ):
150
164
result_is_array = True
151
165
template = template ._to_temp_dataset ()
152
- else :
166
+ elif isinstance ( template , Dataset ) :
153
167
result_is_array = False
168
+ else :
169
+ raise ValueError (
170
+ "Function must return an xarray DataArray or Dataset. Instead it returned %r"
171
+ % type (template )
172
+ )
154
173
155
174
# If two different variables have different chunking along the same dim
156
175
# .chunks will raise an error.
157
176
input_chunks = dataset .chunks
158
177
159
- indexes = dict (dataset .indexes )
160
- for dim in template .indexes :
161
- if dim not in indexes :
162
- indexes [dim ] = template .indexes [dim ]
178
+ # TODO: add a test that fails when template and dataset are switched
179
+ indexes = dict (template .indexes )
180
+ indexes .update (dataset .indexes )
163
181
164
182
graph = {}
165
183
gname = "%s-%s" % (dask .utils .funcname (func ), dask .base .tokenize (dataset ))
166
184
167
185
# map dims to list of chunk indexes
168
- ichunk = {dim : range (len (input_chunks [ dim ] )) for dim in input_chunks }
186
+ ichunk = {dim : range (len (chunks_v )) for dim , chunks_v in input_chunks . items () }
169
187
# mapping from chunk index to slice bounds
170
188
chunk_index_bounds = {
171
- dim : np .cumsum ((0 ,) + input_chunks [ dim ] ) for dim in input_chunks
189
+ dim : np .cumsum ((0 ,) + chunks_v ) for dim , chunks_v in input_chunks . items ()
172
190
}
173
191
174
192
# iterate over all possible chunk combinations
@@ -185,17 +203,15 @@ def _wrapper(func, obj, to_array, args, kwargs):
185
203
for name , variable in dataset .variables .items ():
186
204
# make a task that creates tuple of (dims, chunk)
187
205
if dask .is_dask_collection (variable .data ):
188
- var_dask_keys = variable .__dask_keys__ ()
189
-
190
206
# recursively index into dask_keys nested list to get chunk
191
- chunk = var_dask_keys
207
+ chunk = variable . __dask_keys__ ()
192
208
for dim in variable .dims :
193
209
chunk = chunk [chunk_index_dict [dim ]]
194
210
195
- task_name = ("tuple-" + dask .base .tokenize (chunk ),) + v
196
- graph [task_name ] = (tuple , [variable .dims , chunk ])
211
+ chunk_variable_task = ("tuple-" + dask .base .tokenize (chunk ),) + v
212
+ graph [chunk_variable_task ] = (tuple , [variable .dims , chunk ])
197
213
else :
198
- # numpy array with possibly chunked dimensions
214
+ # non-dask array with possibly chunked dimensions
199
215
# index into variable appropriately
200
216
subsetter = dict ()
201
217
for dim in variable .dims :
@@ -207,14 +223,14 @@ def _wrapper(func, obj, to_array, args, kwargs):
207
223
)
208
224
209
225
subset = variable .isel (subsetter )
210
- task_name = (name + dask .base .tokenize (subset ),) + v
211
- graph [task_name ] = (tuple , [subset .dims , subset ])
226
+ chunk_variable_task = (name + dask .base .tokenize (subset ),) + v
227
+ graph [chunk_variable_task ] = (tuple , [subset .dims , subset ])
212
228
213
229
# this task creates dict mapping variable name to above tuple
214
- if name in dataset .data_vars :
215
- data_vars .append ([name , task_name ])
216
- if name in dataset . coords :
217
- coords .append ([name , task_name ])
230
+ if name in dataset ._coord_names :
231
+ coords .append ([name , chunk_variable_task ])
232
+ else :
233
+ data_vars .append ([name , chunk_variable_task ])
218
234
219
235
from_wrapper = (gname ,) + v
220
236
graph [from_wrapper ] = (
@@ -229,14 +245,15 @@ def _wrapper(func, obj, to_array, args, kwargs):
229
245
# mapping from variable name to dask graph key
230
246
var_key_map = {}
231
247
for name , variable in template .variables .items ():
232
- var_dims = variable .dims
248
+ if name in indexes :
249
+ continue
233
250
# cannot tokenize "name" because the hash of <this-array> is not invariant!
234
251
# This happens when the user function does not set a name on the returned DataArray
235
252
gname_l = "%s-%s" % (gname , name )
236
253
var_key_map [name ] = gname_l
237
254
238
255
key = (gname_l ,)
239
- for dim in var_dims :
256
+ for dim in variable . dims :
240
257
if dim in chunk_index_dict :
241
258
key += (chunk_index_dict [dim ],)
242
259
else :
@@ -248,26 +265,30 @@ def _wrapper(func, obj, to_array, args, kwargs):
248
265
graph = HighLevelGraph .from_collections (name , graph , dependencies = [dataset ])
249
266
250
267
result = Dataset ()
251
- for var , key in var_key_map .items ():
268
+ # a quicker way to assign indexes?
269
+ for name in template .indexes :
270
+ result [name ] = indexes [name ]
271
+ for name , key in var_key_map .items ():
252
272
# indexes need to be known
253
273
# otherwise compute is called when DataArray is created
254
- if var in indexes :
255
- result [var ] = indexes [var ]
274
+ if name in indexes :
275
+ result [name ] = indexes [name ]
256
276
continue
257
277
258
- dims = template [var ].dims
278
+ dims = template [name ].dims
259
279
var_chunks = []
260
280
for dim in dims :
261
281
if dim in input_chunks :
262
282
var_chunks .append (input_chunks [dim ])
263
- else :
264
- if dim in indexes :
265
- var_chunks .append ((len (indexes [dim ]),))
283
+ elif dim in indexes :
284
+ var_chunks .append ((len (indexes [dim ]),))
266
285
267
286
data = dask .array .Array (
268
- graph , name = key , chunks = var_chunks , dtype = template [var ].dtype
287
+ graph , name = key , chunks = var_chunks , dtype = template [name ].dtype
269
288
)
270
- result [var ] = DataArray (data = data , dims = dims , name = var )
289
+ result [name ] = (dims , data )
290
+
291
+ result = result .set_coords (template ._coord_names )
271
292
272
293
if result_is_array :
273
294
result = _to_array (result )
0 commit comments