-
-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Batch Channel Operations #56473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Batch Channel Operations #56473
Changes from all commits
83461ce
40e28dc
934f05d
7955d0b
405e0c6
2ced667
320d948
b1ddd99
2955539
4069be4
6706d6d
3c8b1b4
8feea82
6a037ce
6e69939
201cad2
e1e7d26
3e70004
fdf1839
0b91367
8203cd8
e518495
5eca478
0ac0512
69f2d7a
c2bec3e
487491e
b696077
825b164
93f40c4
f717594
5326ca0
5c812ff
8752eec
32fcba5
0c853fe
d5512f6
e4b247f
1ef4ddb
ccc6afd
26ed2ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -189,6 +189,54 @@ function check_channel_state(c::Channel) | |
throw(closed_exception()) | ||
end | ||
end | ||
|
||
# Call `f(args...; kwargs...)` and report whether it failed because a channel was closed.
#
# Returns `false` when `f` completes normally and `true` when `f` throws an
# `InvalidStateException` whose `state` is `:closed`. Any other exception is rethrown
# unchanged. The return value of `f` itself is discarded.
function _ignore_closed_err(f, args...; kwargs...)
    try
        f(args...; kwargs...)
    catch err
        # Only the "channel is closed" signal is swallowed; everything else propagates.
        (err isa InvalidStateException && err.state === :closed) || rethrow()
        return true
    end
    return false
end
|
||
# Validate a user-supplied count.
#
# * Integers are returned unchanged when they are not negative.
# * `Inf` maps to `typemax(Int)`.
# * Any other float is converted to `Int`; a float that cannot be converted exactly
#   (fractional, `NaN`, `-Inf`, out of range) is treated like a negative value.
#
# Throws an `ArgumentError` whose message starts with `purpose` for rejected input.
function _positive_int(x::T, purpose::String) where {T<:Union{AbstractFloat, Integer}}
    value = x
    if T <: AbstractFloat
        x == Inf && return typemax(Int)
        value = try
            convert(Int, x)
        catch err
            err isa InexactError || rethrow()
            # sentinel that fails the sign check below
            -1
        end
    end
    value < 0 &&
        throw(ArgumentError("$purpose must be either 0, a positive integer or Inf"))
    return value
end
|
||
# Block on `c.cond_put` for as long as the buffer of `c` holds exactly `c.sz_max` items.
# `check_channel_state(c)` runs before every wait so that a channel which can no longer
# be used raises instead of blocking. Callers in this file hold the channel lock.
function _wait_for_space(c::Channel)
    while c.sz_max == length(c.data)
        check_channel_state(c)
        wait(c.cond_put)
    end
    return nothing
end
|
||
# Block on `c.cond_take` until the buffer of `c` holds at least `min_len` items
# (one item by default). `check_channel_state(c)` runs before every wait so that a
# channel which can no longer be used raises instead of blocking. Callers in this file
# hold the channel lock.
function _wait_for_data(c::Channel, min_len=1)
    while min_len > length(c.data)
        check_channel_state(c)
        wait(c.cond_take)
    end
    return nothing
end
|
||
""" | ||
close(c::Channel[, excp::Exception]) | ||
|
||
|
@@ -415,10 +463,7 @@ function put_buffered(c::Channel, v) | |
# Increment channel n_avail eagerly (before push!) to count data in the | ||
# buffer as well as offers from tasks which are blocked in wait(). | ||
_increment_n_avail(c, 1) | ||
while length(c.data) == c.sz_max | ||
check_channel_state(c) | ||
wait(c.cond_put) | ||
end | ||
_wait_for_space(c) | ||
check_channel_state(c) | ||
push!(c.data, v) | ||
did_buffer = true | ||
|
@@ -454,6 +499,144 @@ function put_unbuffered(c::Channel, v) | |
return v | ||
end | ||
|
||
"""
    append!(c::Channel, iter)

Append all items in `iter` to the channel `c`. If the channel is buffered, this operation requires
fewer `lock` operations than individual `put!`s. If the length of the iterator
is greater than the available buffer space, the operation will block until enough elements
are taken from the channel to make space for the new elements.

Returns `c`.

# Examples

```jldoctest
julia> c = Channel(3);

julia> append!(c, 1:3);

julia> take!(c)
1

julia> take!(c)
2

julia> take!(c)
3
```

!!! compat "Julia 1.12"
    Requires at least Julia 1.12.
"""
function append!(c::Channel, iter)
    if IteratorSize(iter) isa Union{HasLength, HasShape}
        len = length(iter)
        # short circuit for small iters
        len == 0 && return c
        if len == 1
            # `first` works for any iterable, whether or not it supports indexing
            put!(c, first(iter))
            return c
        end
    end

    return isbuffered(c) ? append_buffered(c, iter) : append_unbuffered(c, iter)
end
# Move items from the channel `c_src` into the channel `c_dst` until `c_src` is closed
# and drained, then return `c_dst`.
#
# When both channels have a buffer of at least two slots, items are transferred in
# batches directly between the two buffers while both channel locks are held.
# Otherwise the generic iterator path is used.
#
# Throws an `ArgumentError` when `c_dst` and `c_src` are the same channel: every item
# taken would be put straight back, so the transfer could never finish.
function append!(c_dst::Channel, c_src::Channel)
    c_dst === c_src && throw(ArgumentError("cannot append a Channel to itself"))

    # if either channel has a buffer size of 0 or 1, fallback to a naive `put!` loop
    if min(c_dst.sz_max, c_src.sz_max) < 2
        return isbuffered(c_dst) ? append_buffered(c_dst, c_src) : append_unbuffered(c_dst, c_src)
    end

    lock(c_src)
    try
        while isopen(c_src) || isready(c_src)
            # Wait for data to become available. A `true` result means the source was
            # closed with nothing left to move, so there is no reason to go on and wait
            # for space in the destination.
            _ignore_closed_err(_wait_for_data, c_src) && break

            lock(c_dst)
            try
                _wait_for_space(c_dst)

                # move as many items as currently fit into the destination buffer
                dst_len = length(c_dst.data)
                n = min(c_dst.sz_max - dst_len, length(c_src.data))
                _increment_n_avail(c_dst, n)

                resize!(c_dst.data, dst_len + n)
                copyto!(c_dst.data, dst_len + 1, c_src.data, 1, n)
                deleteat!(c_src.data, 1:n)

                _increment_n_avail(c_src, -n)
                notify(c_dst.cond_take, nothing, true, false)
                # a single notify-all already wakes every task waiting to put
                notify(c_src.cond_put, nothing, true, false)
            finally
                unlock(c_dst)
            end
        end
    finally
        unlock(c_src)
    end
    return c_dst
end
|
||
# Feed every element of `iter` to `c` with one `put!` per element, then return `c`.
function append_unbuffered(c::Channel, iter)
    foreach(item -> put!(c, item), iter)
    return c
end
|
||
# Append the elements of `iter`, converted to the channel's element type, to the
# buffered channel `c` in chunks that fit into the free buffer space. The channel lock
# is taken once per chunk. Returns `c`.
#
# Bookkeeping of `n_avail`:
# * When `iter` has a known length, all items are counted eagerly up front. `pending`
#   tracks how many of them are counted but not yet stored in the buffer, so that
#   exactly that many can be removed from the count again if an exception interrupts
#   the operation.
# * Otherwise each item is counted right before it is stored, so nothing is pending.
function append_buffered(c::Channel{T}, iter) where {T}
    converted_items = Iterators.Stateful(Iterators.map(x -> convert(T, x), iter))

    has_length = IteratorSize(iter) isa HasLength
    pending = 0
    if has_length
        # when we know the length, we can eagerly increment for all items
        pending = length(iter)
        _increment_n_avail(c, pending)
    end

    while !isempty(converted_items)
        lock(c)
        try
            _wait_for_space(c)
            check_channel_state(c)

            # Grab a chunk of items that will fit in the channel's buffer
            len_before = length(c.data)
            chunk = Iterators.take(converted_items, c.sz_max - len_before)

            # for iterators without length, we increment the available items one item at a time
            if !has_length
                chunk = Iterators.map(chunk) do x
                    _increment_n_avail(c, 1)
                    x
                end
            end

            try
                append!(c.data, chunk)
            finally
                # Account for what actually reached the buffer, even when the append
                # was interrupted part-way through.
                if has_length
                    pending -= length(c.data) - len_before
                end
            end

            # notify all, since some of the waiters may be on a "fetch" call.
            notify(c.cond_take, nothing, true, false)
        catch
            # Remove the items that were counted eagerly but never stored
            # (e.g. after an exception during `wait(c.cond_put)`):
            pending > 0 && _increment_n_avail(c, -pending)
            rethrow()
        finally
            unlock(c)
        end
    end
    return c
end
|
||
""" | ||
fetch(c::Channel) | ||
|
||
|
@@ -482,16 +665,13 @@ fetch(c::Channel) = isbuffered(c) ? fetch_buffered(c) : fetch_unbuffered(c) | |
# Wait, with the channel lock held, until the buffer of `c` contains at least one item
# and return the first buffered item without removing it.
function fetch_buffered(c::Channel)
    lock(c)
    try
        _wait_for_data(c)
        return c.data[1]
    finally
        unlock(c)
    end
end
# `fetch` is rejected for unbuffered channels: always throws an `ErrorException`.
fetch_unbuffered(::Channel) = throw(ErrorException("`fetch` is not supported on an unbuffered Channel."))
|
||
|
||
""" | ||
|
@@ -528,10 +708,7 @@ take!(c::Channel) = isbuffered(c) ? take_buffered(c) : take_unbuffered(c) | |
function take_buffered(c::Channel) | ||
lock(c) | ||
try | ||
while isempty(c.data) | ||
check_channel_state(c) | ||
wait(c.cond_take) | ||
end | ||
_wait_for_data(c) | ||
v = popfirst!(c.data) | ||
_increment_n_avail(c, -1) | ||
notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!. | ||
|
@@ -553,6 +730,134 @@ function take_unbuffered(c::Channel{T}) where T | |
end | ||
end | ||
|
||
"""
    take!(c::Channel, n::Union{Integer, AbstractFloat}[, buffer::AbstractVector])

Remove up to `n` items from the [`Channel`](@ref) `c` and return them in a vector.
`n` must be 0, a positive integer, or `Inf`; any other value throws an `ArgumentError`.
When `buffer` is given, the items are stored in it and it is returned; otherwise a new
`Vector` with the channel's element type is created.

If the channel closes before `n` items have been taken, the result holds only the items
that were available. For buffered channels, this operation requires fewer `lock`
operations than individual `take!(c)`s.

# Examples

```jldoctest
julia> c = Channel(3);

julia> append!(c, 1:3);

julia> take!(c, 3, zeros(Int, 3))
3-element Vector{Int64}:
 1
 2
 3
```

!!! compat "Julia 1.12"
    Requires at least Julia 1.12.
"""
function take!(
    c::Channel{T},
    n::Union{Integer, AbstractFloat},
    b::AbstractVector = Vector{T}(),
) where {T}
    n_items = _positive_int(n, "Number of elements to take")
    return _take!(c, n_items, b)
end
# Take up to `n` items from `c` into `buffer` and return `buffer`.
#
# `n` is expected to be validated already. The trivial counts 0 and 1 are answered
# directly; larger counts are handed to the buffered or unbuffered batch routine,
# depending on the kind of channel.
function _take!(c::Channel, n::Integer, buffer::AbstractVector)
    # for small-ish n, the buffer is sized once up front instead of growing as we go
    n < 2^16 && resize!(buffer, n)

    n == 0 && return buffer

    if n == 1
        hit_closed = _ignore_closed_err() do
            buffer[begin] = take!(c)
        end
        # nothing was stored when the channel turned out to be closed, so the
        # pre-sized slot has to be dropped again before returning
        hit_closed && empty!(buffer)
        return buffer
    end

    if isbuffered(c)
        return take_buffered(c, n, buffer)
    else
        return take_unbuffered(c, n, buffer)
    end
end
|
||
# Grow `buffer` while items are being taken from a channel and return the
# `resize_count` to use for the next call.
#
# The buffer grows by `max(n_to_take, 2^resize_count)` slots, but never beyond a total
# length of `n`. `resize_count` is advanced by one per call until it reaches 16, so the
# growth step doubles each time up to `2^16` and then stays there.
function _resize_buffer(buffer, n, resize_count, n_to_take=1)
    growth = max(n_to_take, 2^resize_count)
    resize!(buffer, min(length(buffer) + growth, n))
    return resize_count >= 16 ? resize_count : resize_count + 1
end
|
||
# Drop every slot of `buffer` that lies after index `last_idx`. Nothing happens when
# `last_idx` is already past the end of `buffer`.
function _rm_unused_slots(buffer, last_idx)
    tail_end = lastindex(buffer)
    last_idx > tail_end && return nothing
    return deleteat!(buffer, (last_idx + 1):tail_end)
end
|
||
# Take up to `n` items from the buffered channel `c` into `buffer`, holding the channel
# lock for the whole operation, and return `buffer`.
#
# Items are moved in batches: each round waits until the channel holds either
# everything that is still needed or a full buffer, whichever is smaller, and then
# moves as much as possible. The loop ends once `n` items were taken or the channel is
# closed and drained. Slots of `buffer` that were not filled are removed at the end.
function take_buffered(c::Channel, n, buffer::AbstractVector)
    elements_taken = 0 # number of elements taken so far
    idx1 = firstindex(buffer)
    resize_count = 1

    lock(c)
    try
        while elements_taken < n && (isopen(c) || isready(c))
            remaining = n - elements_taken
            # Never wait for more items than are still needed: waiting for a full
            # buffer when fewer items are outstanding would block forever if the
            # producer has nothing more to send and leaves the channel open.
            _ignore_closed_err(_wait_for_data, c, min(remaining, c.sz_max))

            # take as many elements as possible from the buffer
            n_to_take = min(remaining, length(c.data))
            idx_start = idx1 + elements_taken
            idx_end = idx_start + n_to_take - 1
            if idx_end > lastindex(buffer)
                resize_count = _resize_buffer(buffer, n, resize_count, n_to_take)
            end
            # idx_start is relative to `firstindex(buffer)`, so it is a valid
            # destination offset
            copyto!(buffer, idx_start, c.data, 1, n_to_take)
            deleteat!(c.data, 1:n_to_take)
            _increment_n_avail(c, -n_to_take)
            # a single notify-all wakes every task waiting to put
            n_to_take > 0 && notify(c.cond_put, nothing, true, false)

            elements_taken += n_to_take
        end

        _rm_unused_slots(buffer, firstindex(buffer) + elements_taken - 1)
        return buffer
    finally
        unlock(c)
    end
end
|
||
# Take up to `n` items from the unbuffered channel `c` into `buffer`, one
# `take_unbuffered(c)` call per item, and return `buffer`.
#
# The channel lock is held for the whole operation. `buffer` is grown on demand, the
# loop stops early as soon as a take reports that the channel is closed, and slots
# that were never filled are removed at the end.
function take_unbuffered(c::Channel, n, buffer::AbstractVector)
    first_slot = firstindex(buffer)
    filled_upto = first_slot - 1
    growth_step = 1
    lock(c)
    try
        for slot in first_slot:(first_slot + n - 1)
            if slot > lastindex(buffer)
                growth_step = _resize_buffer(buffer, n, growth_step)
            end
            hit_closed = _ignore_closed_err() do
                buffer[slot] = take_unbuffered(c)
            end
            hit_closed && break
            filled_upto = slot
        end

        _rm_unused_slots(buffer, filled_upto)
        return buffer
    finally
        unlock(c)
    end
end
|
||
# Collect the items of `c` into a vector by requesting the largest possible batch
# (`typemax(UInt)` items) from the batch `take!`.
function collect(c::Channel)
    return take!(c, typemax(UInt))
end
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, However, this would change the behavior of parallel |
||
|
||
""" | ||
isready(c::Channel) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, calling
Channel(1.5)
throws an InexactError trying to convert to Int. However, the ArgumentError text already captures the real issue (size must be 0, a positive int, or Inf), so I think the additional error type is unnecessary and a bit confusing in the stacktrace (it doesn't point to the issue being with Channel). So I wrote this check and I'm using it to make sure
n
is always positive. If people agree with this design choice, it could be extended to the Channel constructor in the future.