Skip to content

remove threadedregion and move jl_threading_run to julia #32477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,28 @@ function _threadsfor(iter,lbody)
lidx = iter.args[1] # index
range = iter.args[2]
quote
local threadsfor_fun
let range = $(esc(range))
function threadsfor_fun(onethread=false)
function threadsfor_fun(grain)
r = range # Load into local variable
lenr = length(r)
# divide loop iterations among threads
if onethread
tid = 1
len, rem = lenr, 0
else
tid = threadid()
len, rem = divrem(lenr, nthreads())
end
# divide loop iterations among tasks
ngrains = min(nthreads(), lenr)
len, rem = divrem(lenr, ngrains)
# not enough iterations for all the threads?
if len == 0
if tid > rem
if grain > rem
return
end
len, rem = 1, 0
end
# compute this thread's iterations
f = firstindex(r) + ((tid-1) * len)
f = firstindex(r) + ((grain-1) * len)
l = f + len - 1
# distribute remaining iterations evenly
if rem > 0
if tid <= rem
f = f + (tid-1)
l = l + tid
if grain <= rem
f = f + (grain-1)
l = l + grain
else
f = f + rem
l = l + rem
Expand All @@ -61,17 +55,25 @@ function _threadsfor(iter,lbody)
$(esc(lbody))
end
end
end
if threadid() != 1
# only thread 1 can enter/exit _threadedregion
Base.invokelatest(threadsfor_fun, true)
else
ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun)
threading_run(threadsfor_fun, length(range))
end
nothing
end
end

"""
    threading_run(func, len)

Call `func(grain)` once for every `grain` in `1:min(nthreads(), len)`, each call
wrapped in its own `Task`.

Every task has its `sticky` field cleared before it is scheduled. After all tasks
have been scheduled the whole collection is handed to `Base.sync_end`
(presumably this waits for the tasks and surfaces their failures; confirm against
the `Base.sync_end` definition in use). Always returns `nothing`.
"""
function threading_run(func, len)
    # Never create more tasks than there are iterations to hand out.
    ntasks = min(nthreads(), len)
    spawned = Vector{Task}(undef, ntasks)
    for slot in 1:ntasks
        job = Task(() -> func(slot))
        job.sticky = false
        spawned[slot] = job
        schedule(job)
    end
    Base.sync_end(spawned)
    return nothing
end

"""
Threads.@threads

Expand Down
4 changes: 1 addition & 3 deletions src/jl_uv.c
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,11 @@ JL_DLLEXPORT void jl_uv_req_set_data(uv_req_t *req, void *data) { req->data = da
// Accessor: return the `data` pointer stored on the libuv handle `handle`.
JL_DLLEXPORT void *jl_uv_handle_data(uv_handle_t *handle)
{
    void *payload = handle->data;
    return payload;
}
// Accessor: return the `handle` field of the libuv write request `req`,
// as an untyped pointer.
JL_DLLEXPORT void *jl_uv_write_handle(uv_write_t *req)
{
    return (void*)req->handle;
}

extern volatile unsigned _threadedregion;

JL_DLLEXPORT int jl_process_events(void)
{
jl_ptls_t ptls = jl_get_ptls_states();
uv_loop_t *loop = jl_io_loop;
if (loop && (_threadedregion || ptls->tid == 0)) {
if (loop) {
jl_gc_safepoint_(ptls);
if (jl_mutex_trylock(&jl_uv_mutex)) {
loop->stop_flag = 0;
Expand Down
17 changes: 3 additions & 14 deletions src/partr.c
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,6 @@ static int may_sleep(jl_ptls_t ptls)
return jl_atomic_load(&sleep_check_state) == sleeping && jl_atomic_load(&ptls->sleep_check_state) == sleeping;
}

extern volatile unsigned _threadedregion;

JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)
{
jl_ptls_t ptls = jl_get_ptls_states();
Expand All @@ -413,7 +411,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)
}

jl_cpu_pause();
if (sleep_check_after_threshold(&start_cycles) || (!_threadedregion && ptls->tid == 0)) {
if (sleep_check_after_threshold(&start_cycles)) {
if (!sleep_check_now(ptls->tid))
continue;
jl_atomic_store(&ptls->sleep_check_state, sleeping); // acquire sleep-check lock
Expand All @@ -425,14 +423,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)
// although none are allowed to create new ones
// outside of threaded regions, all IO is permitted,
// but only on thread 1
int uvlock = 0;
if (_threadedregion) {
uvlock = jl_mutex_trylock(&jl_uv_mutex);
}
else if (ptls->tid == 0) {
uvlock = 1;
JL_UV_LOCK();
}
int uvlock = jl_mutex_trylock(&jl_uv_mutex);
if (uvlock) {
int active = 1;
if (jl_atomic_load(&jl_uv_n_waiters) != 0) {
Expand Down Expand Up @@ -462,9 +453,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)
// to the last thread to do an explicit operation,
// which may starve other threads of critical work
}
if (!_threadedregion && active && ptls->tid == 0) {
// thread 0 is the only thread permitted to run the event loop
// so it needs to stay alive
if (active) {
start_cycles = 0;
continue;
}
Expand Down
68 changes: 0 additions & 68 deletions src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -475,74 +475,6 @@ void jl_start_threads(void)
uv_barrier_wait(&thread_init_done);
}

unsigned volatile _threadedregion; // HACK: keep track of whether it is safe to do IO

// Simple fork/join mode code.
//
// Creates one task per thread (jl_n_threads tasks in total) that runs `func`,
// marks each task sticky with `tid` set to its index, schedules it through
// Base.schedule, and then waits on every task in order through Base.wait.
//
// `_threadedregion` is incremented while the tasks are outstanding and
// decremented again on both the normal and the exceptional exit; an exception
// raised while waiting is rethrown after that bookkeeping.
//
// If `func` compiles to a constant-return method there is nothing to run and
// the function returns immediately.
JL_DLLEXPORT void jl_threading_run(jl_value_t *func)
{
    jl_ptls_t ptls = jl_get_ptls_states();
    int8_t gc_state = jl_gc_unsafe_enter(ptls);
    size_t world = jl_world_counter;
    jl_method_instance_t *mfunc = jl_lookup_generic(&func, 1, jl_int32hash_fast(jl_return_address()), world);
    // Ignore constant return value for now.
    jl_code_instance_t *fptr = jl_compile_method_internal(mfunc, world);
    if (fptr->invoke == jl_fptr_const_return) {
        // Restore the GC state saved by jl_gc_unsafe_enter above so that this
        // early exit is balanced the same way as the normal exit below.
        jl_gc_unsafe_leave(ptls, gc_state);
        return;
    }

    size_t nthreads = jl_n_threads;
    jl_svec_t *ts = jl_alloc_svec(nthreads);
    JL_GC_PUSH1(&ts);
    jl_value_t *wait_func = jl_get_global(jl_base_module, jl_symbol("wait"));
    jl_value_t *schd_func = jl_get_global(jl_base_module, jl_symbol("schedule"));
    // create and schedule all tasks
    _threadedregion += 1;
    for (int i = 0; i < nthreads; i++) {
        jl_value_t *args2[2];
        // first call: Task(func)
        args2[0] = (jl_value_t*)jl_task_type;
        args2[1] = func;
        jl_task_t *t = (jl_task_t*)jl_apply(args2, 2);
        // store the task in `ts` (the vector rooted by JL_GC_PUSH1 above)
        jl_svecset(ts, i, t);
        t->sticky = 1;
        t->tid = i;
        // second call: schedule(t)
        args2[0] = schd_func;
        args2[1] = (jl_value_t*)t;
        jl_apply(args2, 2);
        if (i == 1) {
            // let threads know work is coming (optimistic)
            jl_wakeup_thread(-1);
        }
    }
    if (nthreads > 2) {
        // let threads know work is ready (guaranteed)
        jl_wakeup_thread(-1);
    }
    // join with all tasks
    JL_TRY {
        for (int i = 0; i < nthreads; i++) {
            jl_value_t *t = jl_svecref(ts, i);
            jl_value_t *args[2] = { wait_func, t };
            jl_apply(args, 2);
        }
    }
    JL_CATCH {
        // same region teardown as the normal path below, then propagate
        _threadedregion -= 1;
        jl_wake_libuv();
        JL_UV_LOCK();
        JL_UV_UNLOCK();
        jl_rethrow();
    }
    // make sure no threads are sitting in the event loop
    _threadedregion -= 1;
    jl_wake_libuv();
    // make sure no more callbacks will run while user code continues
    // outside thread region and might touch an I/O object.
    JL_UV_LOCK();
    JL_UV_UNLOCK();
    JL_GC_POP();
    jl_gc_unsafe_leave(ptls, gc_state);
}


// Make gc alignment available for threading
// see threads.jl alignment
JL_DLLEXPORT int jl_alignment(size_t sz)
Expand Down
6 changes: 2 additions & 4 deletions test/threads_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,9 @@ for period in (0.06, Dates.Millisecond(60))
t = Timer(period)
wait(t)
ccall(:uv_async_send, Cvoid, (Ptr{Cvoid},), async)
ccall(:uv_async_send, Cvoid, (Ptr{Cvoid},), async)
wait(c)
sleep(period)
ccall(:uv_async_send, Cvoid, (Ptr{Cvoid},), async)
ccall(:uv_async_send, Cvoid, (Ptr{Cvoid},), async)
end))
wait(c)
notify(c)
Expand Down Expand Up @@ -700,8 +698,8 @@ function _atthreads_with_error(a, err)
end
a
end
@test_throws TaskFailedException _atthreads_with_error(zeros(nthreads()), true)
@test_throws CompositeException _atthreads_with_error(zeros(nthreads()), true)
let a = zeros(nthreads())
_atthreads_with_error(a, false)
@test a == [1:nthreads();]
@test all(n->(1 <= n <= nthreads()), a)
end