Skip to content

Commit a0f2028

Browse files
committed
Rewrite asyncmap with one task per work item
1 parent b52d2b0 commit a0f2028

File tree

2 files changed

+113
-78
lines changed

2 files changed

+113
-78
lines changed

base/asyncmap.jl

Lines changed: 110 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22

33
using Base.Iterators: Enumerate
44

5+
# Internal control-flow exception: `do_asyncmap_task!` throws it when the error
# channel already holds a failed task, and the `do_asyncmap` drivers catch it to
# stop mapping early.
struct AbortMapException <: Exception end
6+
57
"""
68
asyncmap(f, c...; ntasks=0, batch_size=nothing)
79
@@ -104,30 +106,120 @@ julia> asyncmap(batch_func, 1:5; ntasks=2, batch_size=2)
104106
worker invocation, etc. See [`Threads`](@ref) and `Distributed` for alternatives.
105107
106108
"""
107-
function asyncmap(f, c...; ntasks=0, batch_size=nothing)
108-
return async_usemap(f, c...; ntasks=ntasks, batch_size=batch_size)
109-
end
109+
function asyncmap(f, c...; kwargs...)
    # Thin public entry point: every positional and keyword argument is handed
    # unchanged to `do_asyncmap`.
    return do_asyncmap(f, c...; kwargs...)
end
110110

111-
function async_usemap(f, c...; ntasks=0, batch_size=nothing)
112-
ntasks = verify_ntasks(c[1], ntasks)
111+
# Keyword-argument driver behind `asyncmap`: normalizes `ntasks` and `batch_size`,
# runs the matching positional `do_asyncmap` method, waits for every started task,
# rethrows collected task failures, and finally unwraps the results.
function do_asyncmap(f, c...; ntasks=0, batch_size=nothing)
112+
# From here on `ntasks` is always callable: a numeric result of `verify_ntasks`
# is wrapped in a zero-argument closure, anything else is passed through as is.
ntasks = let n = verify_ntasks(c[1], ntasks)
113+
n isa Number ? () -> n : n
114+
end
113115
batch_size = verify_batch_size(batch_size)
114116

115-
if batch_size !== nothing
116-
exec_func = batch -> begin
117-
# extract the Refs from the input tuple
118-
batch_refs = map(x->x[1], batch)
117+
# Both positional methods return the started tasks, the error channel, and the
# value of the `try` block wrapped around the driving `map` call.
(tasks, echnl, res) = if batch_size === nothing
118+
do_asyncmap(f, c, ntasks)
119+
else
120+
do_asyncmap(f, c, ntasks, batch_size)
121+
end
119122

120-
# and the args tuple....
121-
batched_args = map(x->x[2], batch)
123+
# NOTE(review): `_wait` is defined outside this view; presumably it blocks
# without rethrowing a task's failure — confirm.
foreach(_wait, tasks)
124+
close(echnl)
122125

123-
results = f(batched_args)
124-
foreach(x -> (batch_refs[x[1]].x = x[2]), enumerate(results))
125-
end
126+
# Failing tasks put themselves on `echnl` (see `do_asyncmap_task!`); if any are
# queued, wrap each in a `TaskFailedException` and throw them together.
isready(echnl) &&
127+
throw(CompositeException(TaskFailedException.(echnl)))
128+
129+
if res isa Ref
130+
# scalar case: map(f, x...) = f(x...)
131+
res.x
126132
else
127-
exec_func = (r,args) -> (r.x = f(args...))
133+
# collection case: read the value out of every result Ref
map(ref -> ref.x, res)
134+
end
135+
end
136+
137+
"""
    AsyncmapState(ntasks)

Mutable bookkeeping shared by one `asyncmap` run.

# Fields
- `tasks`: every task started so far, in start order.
- `ntasks`: zero-argument function returning the current concurrency limit.
- `concur`: number of tasks started and not yet finished.
- `echnl`: channel onto which a failing task puts itself.
"""
mutable struct AsyncmapState
    tasks::Vector{Task}
    ntasks::Function
    concur::Int
    echnl::Channel{Task}
end

# Fresh state: no tasks yet, nothing running, and an error channel large enough
# that a failing task's `put!` never has to wait for a reader.
function AsyncmapState(ntasks)
    return AsyncmapState(Task[], ntasks, 0, Channel{Task}(typemax(Int)))
end
146+
147+
# Start `f(xs...)` in a new task once the concurrency limit allows it.
#
# Yields while `s.concur` is at or above `s.ntasks()` and no task has failed.
# Throws `AbortMapException` instead of starting the task when `s.echnl` already
# holds a failure. Otherwise bumps `s.concur`, schedules the task and records it
# in `s.tasks` (the updated vector is returned).
function do_asyncmap_task!(f::Function, s::AsyncmapState, xs...)
    worker = Task() do
        try
            f(xs...)
        catch
            # Report the failure to the driver, then let the task itself fail.
            put!(s.echnl, current_task())
            rethrow()
        finally
            # Free the concurrency slot whether `f` succeeded or not.
            s.concur -= 1
        end
    end

    # Wait for a free slot, giving up as soon as any task has reported an error.
    while true
        s.concur >= s.ntasks() || break
        isready(s.echnl) && break
        yield()
    end

    if isready(s.echnl)
        throw(AbortMapException())
    end

    s.concur += 1
    schedule(worker)
    return push!(s.tasks, worker)
end
167+
168+
# Non-batch driver: maps over the collections in `c`, starting one task per
# element through `do_asyncmap_task!`. Returns the started tasks, the error
# channel, and the value of the `try` block below.
function do_asyncmap(f, c::Tuple, ntasks::Function)
169+
s = AsyncmapState(ntasks)
170+
171+
res = try
172+
map(c...) do x...
173+
# placeholder that the task fills with the result of `f`
r = Ref{Any}(nothing)
174+
175+
do_asyncmap_task!(s) do
176+
r.x = f(x...)
177+
end
178+
179+
r
180+
end
181+
catch ex
182+
# An `AbortMapException` only ends the mapping early; `res` is then `true`
# rather than the mapped Refs. Any other exception is rethrown.
ex isa AbortMapException || rethrow()
128183
end
129-
chnl, err_chnl, worker_tasks = setup_chnl_and_tasks(exec_func, ntasks, batch_size)
130-
return wrap_n_exec_twice(chnl, err_chnl, worker_tasks, ntasks, exec_func, c...)
184+
185+
s.tasks, s.echnl, res
186+
end
187+
188+
# Batch-mode driver: maps over the collections in `c`, grouping up to
# `batch_size` argument tuples per task.
#
# `f` is called once per batch with a vector of argument tuples and must return
# an iterable of results in the same order; each result is stored in the `Ref`
# created for the corresponding element.
#
# Returns the started tasks, the error channel, and the value of the `try` block
# (the mapped Refs, or `true` when mapping was aborted after a task failure).
function do_asyncmap(f, c::Tuple, ntasks::Function, batch_size)
    s = AsyncmapState(ntasks)
    local batch_in
    local batch_out

    # Start a fresh pair of buffers; the previous pair stays with the task that
    # was handed it.
    function new_batch()
        batch_in = sizehint!(Vector{Any}(), batch_size)
        batch_out = sizehint!(Vector{Ref{Any}}(), batch_size)
        return nothing
    end

    # Runs inside a task: call `f` once on the whole batch, then scatter the
    # results into the matching Refs.
    function exec_batch(args, refs)
        for (i, x) in enumerate(f(args))
            refs[i].x = x
        end
        return nothing
    end

    new_batch()
    res = try
        refs = map(c...) do x...
            # `x` is the tuple of arguments for this element, so several
            # collections can be mapped together.
            r = Ref{Any}(nothing)

            push!(batch_in, x)
            push!(batch_out, r)

            if length(batch_in) >= batch_size
                do_asyncmap_task!(exec_batch, s, batch_in, batch_out)
                new_batch()
            end

            r
        end

        # Flush the trailing partial batch inside the `try`, so an abort raised
        # by `do_asyncmap_task!` is handled like one raised during the map. Skip
        # it when nothing is pending so `f` is never called on an empty batch.
        isempty(batch_in) || do_asyncmap_task!(exec_batch, s, batch_in, batch_out)

        refs
    catch ex
        ex isa AbortMapException || rethrow()
    end

    return s.tasks, s.echnl, res
end
132224

133225
function batch_size_err_str(batch_size)
    # Build the message for an invalid `batch_size`, echoing the offending value.
    prefix = "batch_size must be specified as a positive integer. batch_size="
    return string(prefix, batch_size)
end
@@ -161,63 +253,6 @@ function verify_ntasks(iterable, ntasks)
161253
return ntasks
162254
end
163255

164-
function wrap_n_exec_twice(chnl, err_chnl, worker_tasks, ntasks, exec_func, c...)
165-
# The driver task, creates a Ref object and writes it and the args tuple to
166-
# the communication channel for processing by a free worker task.
167-
push_arg_to_channel = (x...) -> (r=Ref{Any}(nothing); put!(chnl,(r,x));r)
168-
169-
if isa(ntasks, Function)
170-
map_f = (x...) -> begin
171-
# check number of tasks every time, and start one if required.
172-
# number_tasks > optimal_number is fine, the other way around is inefficient.
173-
if length(worker_tasks) < ntasks()
174-
start_worker_task!(worker_tasks, exec_func, chnl, err_chnl)
175-
end
176-
push_arg_to_channel(x...)
177-
end
178-
else
179-
map_f = push_arg_to_channel
180-
end
181-
maptwice(map_f, chnl, err_chnl, worker_tasks, c...)
182-
end
183-
184-
function maptwice(wrapped_f, chnl, err_chnl, worker_tasks, c...)
185-
# first run, returns a collection of Refs
186-
asyncrun = try
187-
map(wrapped_f, c...)
188-
catch ex
189-
# the work channel was closed early and we have the error
190-
ex isa InvalidStateException && isready(err_chnl) || rethrow()
191-
finally
192-
close(chnl)
193-
end
194-
195-
# wait for all worker tasks to finish. We try not to throw
196-
# task errors here because they would be in task creation
197-
# order.
198-
try
199-
@sync foreach(t -> @sync_add(t), worker_tasks)
200-
catch
201-
isready(err_chnl) || rethrow()
202-
finally
203-
close(err_chnl)
204-
end
205-
206-
# throw task errors in chronological order
207-
if isready(err_chnl)
208-
exs = [TaskFailedException(task) for task in err_chnl]
209-
throw(CompositeException(exs))
210-
end
211-
212-
if isa(asyncrun, Ref)
213-
# scalar case
214-
return asyncrun.x
215-
else
216-
# second run, extract values from the Refs and return
217-
return map(ref->ref.x, asyncrun)
218-
end
219-
end
220-
221256
function setup_chnl_and_tasks(exec_func, ntasks, batch_size=nothing)
222257
if isa(ntasks, Function)
223258
nt = ntasks()
@@ -279,7 +314,7 @@ end
279314

280315
# map on a single BitArray returns a BitArray if the mapping function is boolean.
281316
function asyncmap(f, b::BitArray; kwargs...)
282-
b2 = async_usemap(f, b; kwargs...)
317+
b2 = do_asyncmap(f, b; kwargs...)
283318
if eltype(b2) == Bool
284319
return BitArray(b2)
285320
end

test/asyncmap.jl

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,10 @@ using Random
66
@test allunique(asyncmap(x->(sleep(1.0);objectid(current_task())), 1:10))
77

88
# num tasks
9-
@test length(unique(asyncmap(x->(yield();objectid(current_task())), 1:20; ntasks=5))) == 5
9+
@test length(unique(asyncmap(x->(yield();objectid(current_task())), 1:20; ntasks=5))) == 20
1010

1111
# default num tasks
12-
@test length(unique(asyncmap(x->(yield();objectid(current_task())), 1:200))) == 100
12+
@test length(unique(asyncmap(x->(yield();objectid(current_task())), 1:200))) == 200
1313

1414
# ntasks as a function
1515
let nt=0
@@ -18,7 +18,7 @@ let nt=0
1818
# nt_func() will be called initially once and then for every
1919
# iteration
2020
end
21-
@test length(unique(asyncmap(x->(yield();objectid(current_task())), 1:200; ntasks=nt_func))) == 7
21+
@test length(unique(asyncmap(x->(yield();objectid(current_task())), 1:200; ntasks=nt_func))) == 200
2222

2323
# batch mode tests
2424
let ctr=0

0 commit comments

Comments
 (0)