|
2 | 2 |
|
3 | 3 | using Base.Iterators: Enumerate
|
4 | 4 |
|
| 5 | +struct AbortMapException <: Exception end |
| 6 | + |
5 | 7 | """
|
6 | 8 | asyncmap(f, c...; ntasks=0, batch_size=nothing)
|
7 | 9 |
|
@@ -104,30 +106,120 @@ julia> asyncmap(batch_func, 1:5; ntasks=2, batch_size=2)
|
104 | 106 | worker invocation, etc. See [`Threads`](@ref) and `Distributed` for alternatives.
|
105 | 107 |
|
106 | 108 | """
|
107 |
| -function asyncmap(f, c...; ntasks=0, batch_size=nothing) |
108 |
| - return async_usemap(f, c...; ntasks=ntasks, batch_size=batch_size) |
109 |
| -end |
| 109 | +asyncmap(f, c...; kwargs...) = do_asyncmap(f, c...; kwargs...) |
110 | 110 |
|
111 |
| -function async_usemap(f, c...; ntasks=0, batch_size=nothing) |
112 |
| - ntasks = verify_ntasks(c[1], ntasks) |
| 111 | +function do_asyncmap(f, c...; ntasks=0, batch_size=nothing) |
| 112 | + ntasks = let n = verify_ntasks(c[1], ntasks) |
| 113 | + n isa Number ? () -> n : n |
| 114 | + end |
113 | 115 | batch_size = verify_batch_size(batch_size)
|
114 | 116 |
|
115 |
| - if batch_size !== nothing |
116 |
| - exec_func = batch -> begin |
117 |
| - # extract the Refs from the input tuple |
118 |
| - batch_refs = map(x->x[1], batch) |
| 117 | + (tasks, echnl, res) = if batch_size === nothing |
| 118 | + do_asyncmap(f, c, ntasks) |
| 119 | + else |
| 120 | + do_asyncmap(f, c, ntasks, batch_size) |
| 121 | + end |
119 | 122 |
|
120 |
| - # and the args tuple.... |
121 |
| - batched_args = map(x->x[2], batch) |
| 123 | + foreach(_wait, tasks) |
| 124 | + close(echnl) |
122 | 125 |
|
123 |
| - results = f(batched_args) |
124 |
| - foreach(x -> (batch_refs[x[1]].x = x[2]), enumerate(results)) |
125 |
| - end |
| 126 | + isready(echnl) && |
| 127 | + throw(CompositeException(TaskFailedException.(echnl))) |
| 128 | + |
| 129 | + if res isa Ref |
| 130 | + # scalar case: map(f, x...) = f(x...) |
| 131 | + res.x |
126 | 132 | else
|
127 |
| - exec_func = (r,args) -> (r.x = f(args...)) |
| 133 | + map(ref -> ref.x, res) |
| 134 | + end |
| 135 | +end |
| 136 | + |
| 137 | +mutable struct AsyncmapState |
| 138 | + tasks::Vector{Task} |
| 139 | + ntasks::Function |
| 140 | + concur::Int |
| 141 | + echnl::Channel{Task} |
| 142 | +end |
| 143 | + |
| 144 | +AsyncmapState(ntasks) = |
| 145 | + AsyncmapState([], ntasks, 0, Channel{Task}(typemax(Int))) |
| 146 | + |
| 147 | +function do_asyncmap_task!(f::Function, s::AsyncmapState, xs...) |
| 148 | + t = @task try |
| 149 | + f(xs...) |
| 150 | + catch |
| 151 | + put!(s.echnl, current_task()) |
| 152 | + rethrow() |
| 153 | + finally |
| 154 | + s.concur -= 1 |
| 155 | + end |
| 156 | + |
| 157 | + while s.concur >= s.ntasks() && !isready(s.echnl) |
| 158 | + yield() |
| 159 | + end |
| 160 | + |
| 161 | + isready(s.echnl) && throw(AbortMapException()); |
| 162 | + |
| 163 | + s.concur += 1 |
| 164 | + schedule(t) |
| 165 | + push!(s.tasks, t) |
| 166 | +end |
| 167 | + |
| 168 | +function do_asyncmap(f, c::Tuple, ntasks::Function) |
| 169 | + s = AsyncmapState(ntasks) |
| 170 | + |
| 171 | + res = try |
| 172 | + map(c...) do x... |
| 173 | + r = Ref{Any}(nothing) |
| 174 | + |
| 175 | + do_asyncmap_task!(s) do |
| 176 | + r.x = f(x...) |
| 177 | + end |
| 178 | + |
| 179 | + r |
| 180 | + end |
| 181 | + catch ex |
| 182 | + ex isa AbortMapException || rethrow() |
128 | 183 | end
|
129 |
| - chnl, err_chnl, worker_tasks = setup_chnl_and_tasks(exec_func, ntasks, batch_size) |
130 |
| - return wrap_n_exec_twice(chnl, err_chnl, worker_tasks, ntasks, exec_func, c...) |
| 184 | + |
| 185 | + s.tasks, s.echnl, res |
| 186 | +end |
| 187 | + |
| 188 | +function do_asyncmap(f, c::Tuple, ntasks::Function, batch_size) |
| 189 | + s = AsyncmapState(ntasks) |
| 190 | + local batch_in |
| 191 | + local batch_out |
| 192 | + |
| 193 | + function new_batch() |
| 194 | + batch_in = sizehint!(Vector{Any}(), batch_size) |
| 195 | + batch_out = sizehint!(Vector{Ref{Any}}(), batch_size) |
| 196 | + end |
| 197 | + |
| 198 | + exec_batch(args, res) = for (i, x) in enumerate(f(args)) |
| 199 | + res[i].x = x |
| 200 | + end |
| 201 | + |
| 202 | + new_batch() |
| 203 | + res = try |
| 204 | + map(c...) do x |
| 205 | + r = Ref{Any}(undef) |
| 206 | + |
| 207 | + push!(batch_in, x) |
| 208 | + push!(batch_out, r) |
| 209 | + |
| 210 | + if length(batch_in) >= batch_size |
| 211 | + do_asyncmap_task!(exec_batch, s, batch_in, batch_out) |
| 212 | + new_batch() |
| 213 | + end |
| 214 | + |
| 215 | + r |
| 216 | + end |
| 217 | + catch ex |
| 218 | + ex isa AbortMapException || rethrow() |
| 219 | + end |
| 220 | + do_asyncmap_task!(exec_batch, s, batch_in, batch_out) |
| 221 | + |
| 222 | + s.tasks, s.echnl, res |
131 | 223 | end
|
132 | 224 |
|
133 | 225 | batch_size_err_str(batch_size) = string("batch_size must be specified as a positive integer. batch_size=", batch_size)
|
@@ -161,63 +253,6 @@ function verify_ntasks(iterable, ntasks)
|
161 | 253 | return ntasks
|
162 | 254 | end
|
163 | 255 |
|
164 |
| -function wrap_n_exec_twice(chnl, err_chnl, worker_tasks, ntasks, exec_func, c...) |
165 |
| - # The driver task, creates a Ref object and writes it and the args tuple to |
166 |
| - # the communication channel for processing by a free worker task. |
167 |
| - push_arg_to_channel = (x...) -> (r=Ref{Any}(nothing); put!(chnl,(r,x));r) |
168 |
| - |
169 |
| - if isa(ntasks, Function) |
170 |
| - map_f = (x...) -> begin |
171 |
| - # check number of tasks every time, and start one if required. |
172 |
| - # number_tasks > optimal_number is fine, the other way around is inefficient. |
173 |
| - if length(worker_tasks) < ntasks() |
174 |
| - start_worker_task!(worker_tasks, exec_func, chnl, err_chnl) |
175 |
| - end |
176 |
| - push_arg_to_channel(x...) |
177 |
| - end |
178 |
| - else |
179 |
| - map_f = push_arg_to_channel |
180 |
| - end |
181 |
| - maptwice(map_f, chnl, err_chnl, worker_tasks, c...) |
182 |
| -end |
183 |
| - |
184 |
| -function maptwice(wrapped_f, chnl, err_chnl, worker_tasks, c...) |
185 |
| - # first run, returns a collection of Refs |
186 |
| - asyncrun = try |
187 |
| - map(wrapped_f, c...) |
188 |
| - catch ex |
189 |
| - # the work channel was closed early and we have the error |
190 |
| - ex isa InvalidStateException && isready(err_chnl) || rethrow() |
191 |
| - finally |
192 |
| - close(chnl) |
193 |
| - end |
194 |
| - |
195 |
| - # wait for all worker tasks to finish. We try not to throw |
196 |
| - # task errors here because they would be in task creation |
197 |
| - # order. |
198 |
| - try |
199 |
| - @sync foreach(t -> @sync_add(t), worker_tasks) |
200 |
| - catch |
201 |
| - isready(err_chnl) || rethrow() |
202 |
| - finally |
203 |
| - close(err_chnl) |
204 |
| - end |
205 |
| - |
206 |
| - # throw task errors in chronological order |
207 |
| - if isready(err_chnl) |
208 |
| - exs = [TaskFailedException(task) for task in err_chnl] |
209 |
| - throw(CompositeException(exs)) |
210 |
| - end |
211 |
| - |
212 |
| - if isa(asyncrun, Ref) |
213 |
| - # scalar case |
214 |
| - return asyncrun.x |
215 |
| - else |
216 |
| - # second run, extract values from the Refs and return |
217 |
| - return map(ref->ref.x, asyncrun) |
218 |
| - end |
219 |
| -end |
220 |
| - |
221 | 256 | function setup_chnl_and_tasks(exec_func, ntasks, batch_size=nothing)
|
222 | 257 | if isa(ntasks, Function)
|
223 | 258 | nt = ntasks()
|
|
279 | 314 |
|
280 | 315 | # map on a single BitArray returns a BitArray if the mapping function is boolean.
|
281 | 316 | function asyncmap(f, b::BitArray; kwargs...)
|
282 |
| - b2 = async_usemap(f, b; kwargs...) |
| 317 | + b2 = do_asyncmap(f, b; kwargs...) |
283 | 318 | if eltype(b2) == Bool
|
284 | 319 | return BitArray(b2)
|
285 | 320 | end
|
|
0 commit comments