Skip to content

Commit ccbffbc

Browse files
committed
Deliver errors in chronological order and update AsyncCollector
1 parent 8383552 commit ccbffbc

File tree

3 files changed

+85
-56
lines changed

3 files changed

+85
-56
lines changed

base/asyncmap.jl

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ return a vector of results. The input vector will have a length of `batch_size`
2424
2525
# Exception handling
2626
27-
Individual exceptions thrown by `f` will be wrapped in a `TaskFailedException`.
27+
Individual exceptions thrown by `f` will be wrapped in a `CapturedException`.
2828
As multiple tasks are used, more than one exception may be thrown. Exceptions
2929
are combined into a `CompositeException`. Even if only a single exception is
3030
thrown, it is still wrapped in a `CompositeException`.
3131
32-
However, when an exception is thrown `asyncmap` will fail-fast, canceling any
33-
remaining work. If you need `asyncmap` to be error resistant then wrap the
34-
body of 'f' in a `try... catch` statement. Below is one possible approach to
35-
error handling:
32+
However, when an exception is thrown `asyncmap` will fail-fast. Any remaining
33+
iterations, which are not already in progress, will be cancelled. If you need
34+
`asyncmap` to be error resistant then wrap the body of 'f' in a `try... catch`
35+
statement. Below is one possible approach to error handling:
3636
3737
```
3838
julia> result = asyncmap(1:2) do x
@@ -99,9 +99,9 @@ julia> asyncmap(batch_func, 1:5; ntasks=2, batch_size=2)
9999
```
100100
101101
!!! note
102-
Currently, all tasks in Julia are executed in a single OS thread co-operatively. Consequently,
102+
The tasks created by `asyncmap` are executed in a single OS thread co-operatively. Consequently,
103103
`asyncmap` is beneficial only when the mapping function involves any I/O - disk, network, remote
104-
worker invocation, etc.
104+
worker invocation, etc. See [`Threads`](@ref) and [`Distributed`](@ref) for alternatives.
105105
106106
"""
107107
function asyncmap(f, c...; ntasks=0, batch_size=nothing)
@@ -126,8 +126,8 @@ function async_usemap(f, c...; ntasks=0, batch_size=nothing)
126126
else
127127
exec_func = (r,args) -> (r.x = f(args...))
128128
end
129-
chnl, worker_tasks = setup_chnl_and_tasks(exec_func, ntasks, batch_size)
130-
return wrap_n_exec_twice(chnl, worker_tasks, ntasks, exec_func, c...)
129+
chnl, err_chnl, worker_tasks = setup_chnl_and_tasks(exec_func, ntasks, batch_size)
130+
return wrap_n_exec_twice(chnl, err_chnl, worker_tasks, ntasks, exec_func, c...)
131131
end
132132

133133
batch_size_err_str(batch_size) = string("batch_size must be specified as a positive integer. batch_size=", batch_size)
@@ -161,7 +161,7 @@ function verify_ntasks(iterable, ntasks)
161161
return ntasks
162162
end
163163

164-
function wrap_n_exec_twice(chnl, worker_tasks, ntasks, exec_func, c...)
164+
function wrap_n_exec_twice(chnl, err_chnl, worker_tasks, ntasks, exec_func, c...)
165165
# The driver task, creates a Ref object and writes it and the args tuple to
166166
# the communication channel for processing by a free worker task.
167167
push_arg_to_channel = (x...) -> (r=Ref{Any}(nothing); put!(chnl,(r,x));r)
@@ -171,41 +171,32 @@ function wrap_n_exec_twice(chnl, worker_tasks, ntasks, exec_func, c...)
171171
# check number of tasks every time, and start one if required.
172172
# number_tasks > optimal_number is fine, the other way around is inefficient.
173173
if length(worker_tasks) < ntasks()
174-
start_worker_task!(worker_tasks, exec_func, chnl)
174+
start_worker_task!(worker_tasks, exec_func, chnl, err_chnl)
175175
end
176176
push_arg_to_channel(x...)
177177
end
178178
else
179179
map_f = push_arg_to_channel
180180
end
181-
maptwice(map_f, chnl, worker_tasks, c...)
181+
maptwice(map_f, chnl, err_chnl, worker_tasks, c...)
182182
end
183183

184-
function maptwice(wrapped_f, chnl, worker_tasks, c...)
184+
function maptwice(wrapped_f, chnl, err_chnl, worker_tasks, c...)
185185
# first run, returns a collection of Refs
186-
asyncrun_excp = nothing
187186
local asyncrun
188187
try
189188
asyncrun = map(wrapped_f, c...)
190189
catch ex
191-
if isa(ex,InvalidStateException)
192-
# channel could be closed due to exceptions in the async tasks,
193-
# we propagate those errors, if any, over the `put!` failing
194-
# in asyncrun due to a closed channel.
195-
asyncrun_excp = ex
196-
else
197-
rethrow()
198-
end
190+
put!(err_chnl, CapturedException(ex, catch_backtrace()))
199191
end
200192

201193
# close channel and wait for all worker tasks to finish
202194
close(chnl)
203-
204-
# check and throw any exceptions from the worker tasks
205195
@sync foreach(t -> @sync_add(t), worker_tasks)
206196

207-
# check if there was a genuine problem with asyncrun
208-
(asyncrun_excp !== nothing) && throw(asyncrun_excp)
197+
# check for errors and throw them in chronological order
198+
close(err_chnl)
199+
isready(err_chnl) && throw(CompositeException(collect(err_chnl)))
209200

210201
if isa(asyncrun, Ref)
211202
# scalar case
@@ -231,35 +222,36 @@ function setup_chnl_and_tasks(exec_func, ntasks, batch_size=nothing)
231222
# of an error in any of the worker tasks, the channel is closed. This
232223
# results in the `put!` in the driver task failing immediately.
233224
chnl = Channel(0)
225+
err_chnl = Channel(nt)
234226
worker_tasks = []
235-
foreach(_ -> start_worker_task!(worker_tasks, exec_func, chnl, batch_size), 1:nt)
227+
foreach(_ -> start_worker_task!(worker_tasks, exec_func, chnl, err_chnl, batch_size), 1:nt)
236228
yield()
237-
return (chnl, worker_tasks)
229+
return (chnl, err_chnl, worker_tasks)
238230
end
239231

240-
function start_worker_task!(worker_tasks, exec_func, chnl, batch_size=nothing)
232+
start_worker_task(exec_func, chnl) = foreach(exec_data -> exec_func(exec_data...), chnl)
233+
start_worker_task(exec_func, chnl, ::Nothing) = start_worker_task(exec_func, chnl)
234+
start_worker_task(exec_func, chnl, batch_size) = while isopen(chnl)
235+
# The mapping function expects an array of input args, as it processes
236+
# elements in a batch.
237+
batch_collection=Any[]
238+
n = 0
239+
for exec_data in chnl
240+
push!(batch_collection, exec_data)
241+
n += 1
242+
(n == batch_size) && break
243+
end
244+
if n > 0
245+
exec_func(batch_collection)
246+
end
247+
end
248+
249+
function start_worker_task!(worker_tasks, exec_func, chnl, err_chnl, batch_size=nothing)
241250
t = @async begin
242251
try
243-
if isa(batch_size, Number)
244-
while isopen(chnl)
245-
# The mapping function expects an array of input args, as it processes
246-
# elements in a batch.
247-
batch_collection=Any[]
248-
n = 0
249-
for exec_data in chnl
250-
push!(batch_collection, exec_data)
251-
n += 1
252-
(n == batch_size) && break
253-
end
254-
if n > 0
255-
exec_func(batch_collection)
256-
end
257-
end
258-
else
259-
for exec_data in chnl
260-
exec_func(exec_data...)
261-
end
262-
end
252+
start_worker_task(exec_func, chnl, batch_size)
253+
catch ex
254+
put!(err_chnl, CapturedException(ex, catch_backtrace()))
263255
finally
264256
close(chnl)
265257
end
@@ -320,10 +312,11 @@ end
320312

321313
mutable struct AsyncCollectorState
322314
chnl::Channel
315+
err_chnl::Channel
323316
worker_tasks::Array{Task,1}
324317
enum_state # enumerator state
325-
AsyncCollectorState(chnl::Channel, worker_tasks::Vector) =
326-
new(chnl, convert(Vector{Task}, worker_tasks))
318+
AsyncCollectorState(chnl::Channel, err_chnl::Channel, worker_tasks::Vector) =
319+
new(chnl, err_chnl, convert(Vector{Task}, worker_tasks))
327320
end
328321

329322
function iterate(itr::AsyncCollector)
@@ -343,16 +336,23 @@ function iterate(itr::AsyncCollector)
343336
else
344337
exec_func = (i,args) -> (itr.results[i]=itr.f(args...))
345338
end
346-
chnl, worker_tasks = setup_chnl_and_tasks((i,args) -> (itr.results[i]=itr.f(args...)), itr.ntasks, itr.batch_size)
347-
return iterate(itr, AsyncCollectorState(chnl, worker_tasks))
339+
340+
chnl, err_chnl, worker_tasks = setup_chnl_and_tasks(itr.ntasks, itr.batch_size) do i, args
341+
itr.results[i]=itr.f(args...)
342+
end
343+
344+
return iterate(itr, AsyncCollectorState(chnl, err_chnl, worker_tasks))
348345
end
349346

350347
function wait_done(itr::AsyncCollector, state::AsyncCollectorState)
351348
close(state.chnl)
352349

353350
# wait for all tasks to finish
354-
foreach(x->(v=fetch(x); isa(v, Exception) && throw(v)), state.worker_tasks)
351+
@sync foreach(t -> @sync_add(t), state.worker_tasks)
355352
empty!(state.worker_tasks)
353+
354+
close(state.err_chnl)
355+
isready(state.err_chnl) && throw(CompositeException(collect(state.err_chnl)))
356356
end
357357

358358
function iterate(itr::AsyncCollector, state::AsyncCollectorState)
@@ -369,7 +369,14 @@ function iterate(itr::AsyncCollector, state::AsyncCollectorState)
369369
return nothing
370370
end
371371
(i, args), state.enum_state = y
372-
put!(state.chnl, (i, args))
372+
373+
try
374+
put!(state.chnl, (i, args))
375+
catch ex
376+
put!(state.err_chnl, ex)
377+
wait_done(itr, state)
378+
rethrow() # Should never reach here
379+
end
373380

374381
return (nothing, state)
375382
end

stdlib/Distributed/test/distributed_exec.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,13 +589,18 @@ generic_map_tests(pmap_fallback)
589589
run_map_equivalence_tests(pmap)
590590
@test pmap(uppercase, "Hello World!") == map(uppercase, "Hello World!")
591591

592+
unpack(ex::CompositeException) = unpack(ex.exceptions[1])
593+
unpack(ex::CapturedException) = unpack(ex.ex)
594+
unpack(ex::RemoteException) = unpack(ex.captured)
595+
unpack(ex::TaskFailedException) = unpack(ex.task.exception)
596+
unpack(ex) = ex
592597

593598
# Simple test for pmap throws error
594599
let error_thrown = false
595600
try
596601
pmap(x -> x == 50 ? error("foobar") : x, 1:100)
597602
catch e
598-
@test e.exceptions[1].task.exception.captured.ex.msg == "foobar"
603+
@test unpack(e).msg == "foobar"
599604
error_thrown = true
600605
end
601606
@test error_thrown

test/asyncmap.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,20 @@ using Base.Unicode: uppercase
6969
@test_throws CompositeException asyncmap(1:4, batch_size=2) do v
7070
map(u -> iseven(u) && error("foo"), v)
7171
end
72+
73+
unpack(ex::CapturedException) = unpack(ex.ex)
74+
unpack(ex::TaskFailedException) = unpack(ex.task.exception)
75+
unpack(ex) = ex
76+
77+
# Make sure exceptions are thrown in the order they happen
78+
chnl = Channel()
79+
try
80+
asyncmap(1:2) do i
81+
i == 1 && take!(chnl)
82+
close(chnl)
83+
error("first")
84+
end
85+
catch ex
86+
@test length(ex.exceptions) == 2
87+
@test unpack(ex.exceptions[1]).msg == "first"
88+
end

0 commit comments

Comments
 (0)