Skip to content

Commit ac1d45e

Browse files
authored
Fix bug (#12)
* simpler ArrayStyle * hoist buffer out of benchmark * rm debug code, restore codepath
1 parent ba04628 commit ac1d45e

File tree

6 files changed

+74
-53
lines changed

6 files changed

+74
-53
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,18 @@ For a less-toy example, in `test/flux.jl` we test inference over a Flux model:
6060

6161
```julia
6262
# Baseline: Array
63-
infer!(predictions, model, data): 0.457247 seconds (8.02 k allocations: 370.796 MiB, 6.10% gc time)
63+
infer!(b, predictions, model, data): 0.499735 seconds (8.05 k allocations: 306.796 MiB, 6.47% gc time)
6464
# Baseline: StrideArray
6565
stride_data = StrideArray.(data)
66-
infer!(predictions, model, stride_data): 0.336535 seconds (8.05 k allocations: 370.796 MiB, 6.20% gc time)
66+
infer!(b, predictions, model, stride_data): 0.364180 seconds (8.05 k allocations: 306.796 MiB, 8.32% gc time)
6767
# Using AllocArray:
6868
alloc_data = AllocArray.(data)
69-
infer!(predictions, model, alloc_data): 0.318736 seconds (13.35 k allocations: 67.225 MiB)
69+
infer!(b, predictions, model, alloc_data): 0.351953 seconds (13.60 k allocations: 3.221 MiB)
7070
checked_alloc_data = CheckedAllocArray.(data)
71-
infer!(predictions, model, checked_alloc_data): 23.673344 seconds (26.15 k allocations: 67.773 MiB)
71+
infer!(b, predictions, model, checked_alloc_data): 15.522897 seconds (25.54 k allocations: 3.742 MiB)
7272
```
7373

74-
We can see in this example, we got much less allocation (and no GC time), and similar runtime. By running larger examples, the gap in allocations can be much larger; here we use a 64 MiB buffer that we allocate each `infer!` call, which accounts for most of the memory usage.
74+
In this example, `AllocArray`s give roughly 100x less allocation (and no GC time) at a similar runtime. `CheckedAllocArray`s, however, are far slower here.
7575

7676
## Design notes
7777

src/AllocArray.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ Base.size(a::AllocArray) = size(getfield(a, :arr))
4444
Base.IndexStyle(::Type{<:AllocArray{T,N,Arr}}) where {T,N,Arr} = Base.IndexStyle(Arr)
4545

4646
# used only by broadcasting?
47-
function Base.similar(::Type{AllocArray{T,N,Arr}}, dims::Dims) where {T,N,Arr}
48-
return alloc_similar(CURRENT_ALLOCATOR[], AllocArray{T,N,Arr}, dims)
47+
function Base.similar(::Type{<:AllocArray{T}}, dims::Dims) where {T}
48+
return alloc_similar(CURRENT_ALLOCATOR[], AllocArray{T}, dims)
4949
end
5050

5151
function Base.similar(a::AllocArray, ::Type{T}, dims::Dims) where {T}
@@ -56,13 +56,13 @@ end
5656
##### Broadcasting
5757
#####
5858

59-
function Base.BroadcastStyle(::Type{AllocArray{T,N,Arr}}) where {T,N,Arr}
60-
return Broadcast.ArrayStyle{AllocArray{T,N,Arr}}()
59+
function Base.BroadcastStyle(::Type{<:AllocArray})
60+
return ArrayStyle{AllocArray}()
6161
end
6262

63-
function Base.similar(bc::Broadcasted{ArrayStyle{AllocArray{T,N,Arr}}},
64-
::Type{ElType}) where {T,N,Arr,ElType}
65-
return similar(AllocArray{T,N,Arr}, axes(bc))
63+
function Base.similar(bc::Broadcasted{ArrayStyle{AllocArray}},
64+
::Type{T}) where {T}
65+
return similar(AllocArray{T}, axes(bc))
6666
end
6767

6868
#####

src/CheckedAllocArray.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ end
121121
Base.IndexStyle(::Type{<:CheckedAllocArray{T,N,Arr}}) where {T,N,Arr} = Base.IndexStyle(Arr)
122122

123123
# used only by broadcasting?
124-
function Base.similar(::Type{CheckedAllocArray{T,N,Arr}}, dims::Dims) where {T,N,Arr}
125-
return alloc_similar(CURRENT_ALLOCATOR[], CheckedAllocArray{T,N,Arr}, dims)
124+
function Base.similar(::Type{<:CheckedAllocArray{T}}, dims::Dims) where {T}
125+
return alloc_similar(CURRENT_ALLOCATOR[], CheckedAllocArray{T}, dims)
126126
end
127127

128128
function Base.similar(a::CheckedAllocArray, ::Type{T}, dims::Dims) where {T}
@@ -133,13 +133,13 @@ end
133133
##### Broadcasting
134134
#####
135135

136-
function Base.BroadcastStyle(::Type{CheckedAllocArray{T,N,Arr}}) where {T,N,Arr}
137-
return Broadcast.ArrayStyle{CheckedAllocArray{T,N,Arr}}()
136+
function Base.BroadcastStyle(::Type{<:CheckedAllocArray})
137+
return ArrayStyle{CheckedAllocArray}()
138138
end
139139

140-
function Base.similar(bc::Broadcasted{ArrayStyle{CheckedAllocArray{T,N,Arr}}},
141-
::Type{ElType}) where {T,N,Arr,ElType}
142-
return similar(CheckedAllocArray{T,N,Arr}, axes(bc))::CheckedAllocArray
140+
function Base.similar(bc::Broadcasted{ArrayStyle{CheckedAllocArray}},
141+
::Type{T}) where {T}
142+
return similar(CheckedAllocArray{T}, axes(bc))::CheckedAllocArray
143143
end
144144

145145
#####

src/alloc_interface.jl

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55
66
Allocators need to subtype `Allocator` and implement two methods of `alloc_similar`:
77
8-
- `AllocArrays.alloc_similar(::Allocator, arr, ::Type{T}, dims::Dims)`
9-
- `AllocArrays.alloc_similar(::Allocator, ::Type{Arr}, dims::Dims) where {Arr<:AbstractArray}`
8+
- `AllocArrays.alloc_similar(::MyAllocator, a::AllocArray, ::Type{T}, dims::Dims)`
9+
- `AllocArrays.alloc_similar(::MyAllocator, ::Type{<:AllocArray{T}}, dims::Dims) where {T}`
1010
11-
where the latter is used by broadcasting.
11+
to support `AllocArray`s (each method should return an `AllocArray`), and likewise
12+
13+
- `AllocArrays.alloc_similar(::MyAllocator, a::CheckedAllocArray, ::Type{T}, dims::Dims)`
14+
- `AllocArrays.alloc_similar(::MyAllocator, ::Type{<:CheckedAllocArray{T}}, dims::Dims) where {T}`
15+
16+
which should each return a `CheckedAllocArray`.
1217
"""
1318
abstract type Allocator end
1419

@@ -38,20 +43,22 @@ function alloc_similar(::DefaultAllocator, ::AllocArray, ::Type{T}, dims::Dims)
3843
return AllocArray(similar(Array{T}, dims))
3944
end
4045

41-
function alloc_similar(::DefaultAllocator, ::Type{AllocArray{T,N,Arr}},
42-
dims::Dims) where {T, N, Arr}
43-
return AllocArray(similar(Arr, dims))
44-
end
45-
46-
function alloc_similar(::DefaultAllocator, ::CheckedAllocArray, ::Type{T}, dims::Dims) where {T}
47-
return CheckedAllocArray(similar(Array{T}, dims))
46+
function alloc_similar(::DefaultAllocator, ::Type{<:AllocArray{T}},
47+
dims::Dims) where {T}
48+
return AllocArray(similar(Array{T}, dims))
4849
end
4950

50-
function alloc_similar(::DefaultAllocator, ::Type{CheckedAllocArray{T,N,Arr}},
51-
dims::Dims) where {T, N, Arr}
51+
function alloc_similar(D::DefaultAllocator, c::CheckedAllocArray, ::Type{T},
52+
dims::Dims) where {T}
5253
# We know the memory is valid since it was allocated with the
5354
# default allocator
54-
return CheckedAllocArray(similar(Arr, dims), MemValid(true))
55+
a = @lock(c, alloc_similar(D, _get_inner(c), T, dims))
56+
return CheckedAllocArray(a, MemValid(true))
57+
end
58+
59+
function alloc_similar(D::DefaultAllocator, ::Type{<:CheckedAllocArray{T}},
60+
dims::Dims) where {T}
61+
return CheckedAllocArray(alloc_similar(D, AllocArray{T}, dims), MemValid(true))
5562
end
5663

5764
#####
@@ -125,12 +132,14 @@ function reset!(B::UncheckedBumperAllocator)
125132
return nothing
126133
end
127134

128-
function alloc_similar(B::UncheckedBumperAllocator, ::AllocArray, ::Type{T}, dims::Dims) where {T}
135+
function alloc_similar(B::UncheckedBumperAllocator, ::AllocArray, ::Type{T},
136+
dims::Dims) where {T}
129137
inner = Bumper.alloc(T, B.buf, dims...)
130138
return AllocArray(inner)
131139
end
132140

133-
function alloc_similar(B::UncheckedBumperAllocator, ::Type{AllocArray{T,N,Arr}}, dims::Dims) where {T, N, Arr}
141+
function alloc_similar(B::UncheckedBumperAllocator, ::Type{<:AllocArray{T}},
142+
dims::Dims) where {T}
134143
inner = Bumper.alloc(T, B.buf, dims...)
135144
return AllocArray(inner)
136145
end
@@ -224,10 +233,10 @@ function alloc_similar(B::BumperAllocator, c::CheckedAllocArray, ::Type{T},
224233
end
225234
end
226235

227-
function alloc_similar(B::BumperAllocator, ::Type{CheckedAllocArray{T,N,Arr}},
228-
dims::Dims) where {T,N,Arr}
236+
function alloc_similar(B::BumperAllocator, ::Type{<:CheckedAllocArray{T}},
237+
dims::Dims) where {T}
229238
@lock B begin
230-
inner = alloc_similar(B.bumper, Arr, dims)
239+
inner = alloc_similar(B.bumper, AllocArray{T}, dims)
231240
valid = MemValid(true)
232241
push!(B.mems, valid)
233242
return CheckedAllocArray(inner, valid)
@@ -239,9 +248,9 @@ end
239248
# If we have a `BumperAllocator` and are asked to allocate an unchecked array
240249
# then we can do that by dispatching to the inner bumper. We will still
241250
# get the lock for concurrency-safety.
242-
function alloc_similar(B::BumperAllocator, ::Type{AllocArray{T,N,Arr}},
243-
dims::Dims) where {T,N,Arr}
244-
return @lock(B, alloc_similar(B.bumper, AllocArray{T,N,Arr}, dims))
251+
function alloc_similar(B::BumperAllocator, ::Type{<:AllocArray{T}},
252+
dims::Dims) where {T}
253+
return @lock(B, alloc_similar(B.bumper, AllocArray{T}, dims))
245254
end
246255

247256
function alloc_similar(B::BumperAllocator, a::AllocArray, ::Type{T},

test/flux.jl

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ end
8080
# Our model acts on input just by applying the chain.
8181
(m::DigitsModel)(x) = m.chain(x)
8282

83-
function infer!(predictions, model, data)
84-
b = BumperAllocator(2^26) # 64 MiB
83+
function infer!(b, predictions, model, data)
8584
# Here we use a locked bumper for thread-safety, since NNlib multithreads
8685
# some of its functions. However, we are sure to only deallocate outside of the threaded region. (All concurrency occurs within the `model` call itself).
8786
with_allocator(b) do
@@ -96,6 +95,8 @@ end
9695
@testset "More complicated model" begin
9796
model = DigitsModel()
9897

98+
b = BumperAllocator(2^26) # 64 MiB
99+
99100
# Setup some fake data
100101
N = 1_000
101102
data_arr = rand(Float32, 28, 28, N)
@@ -112,32 +113,35 @@ end
112113
checked_alloc_data = CheckedAllocArray.(data)
113114

114115
preds_data = fresh_predictions()
115-
infer!(preds_data, model, data)
116+
infer!(b, preds_data, model, data)
116117

117118
preds_alloc = fresh_predictions()
118-
infer!(preds_alloc, model, alloc_data)
119+
infer!(b, preds_alloc, model, alloc_data)
119120

120121
preds_checked_alloc = fresh_predictions()
121-
infer!(preds_checked_alloc, model, checked_alloc_data)
122+
infer!(b, preds_checked_alloc, model, checked_alloc_data)
122123

123124
preds_stride = fresh_predictions()
124125
stride_data = StrideArray.(data)
125-
infer!(preds_stride, model, stride_data)
126+
infer!(b, preds_stride, model, stride_data)
126127

127128
@test preds_data ≈ preds_alloc
128129
@test preds_data ≈ preds_stride
129130
@test preds_data ≈ preds_checked_alloc
130131

131132
predictions = fresh_predictions()
132-
@showtime infer!(predictions, model, data)
133-
@showtime infer!(predictions, model, stride_data)
134-
@showtime infer!(predictions, model, alloc_data)
135-
@showtime infer!(predictions, model, checked_alloc_data)
133+
@showtime infer!(b, predictions, model, data)
134+
@showtime infer!(b, predictions, model, stride_data)
135+
@showtime infer!(b, predictions, model, alloc_data)
136+
@showtime infer!(b, predictions, model, checked_alloc_data)
136137

137138
# Note: for max perf, consider
138-
# (using Functors)
139+
# using Functors
139140
# model = fmap(AllocArray ∘ PtrArray, model; exclude = x -> x isa AbstractArray)
140-
# and `alloc_data = AllocArray.(PtrArray.(data))`
141+
# alloc_data = AllocArray.(PtrArray.(data))
142+
# @showtime infer!(b, predictions, model, alloc_data)
143+
# @showtime infer!(b, predictions, model, alloc_data)
141144
# Together, that ensures everything is an `AllocArray(PtrArray(...))`
142-
# This seems to help although not a huge amount.
145+
# This seems to help with runtime although not a huge amount,
146+
# and doesn't really help with allocations.
143147
end

test/runtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,12 @@ end
3939

4040
include("flux.jl")
4141
include("checked.jl")
42+
43+
# Bug reported here:
44+
# https://julialang.zulipchat.com/#narrow/stream/137791-general/topic/AllocArrays.2Ejl/near/398698500
45+
a = AllocArray(1:4)
46+
@test a[1:2] .+ a[3:4]' isa AllocArray
47+
48+
a = CheckedAllocArray(1:4)
49+
@test a[1:2] .+ a[3:4]' isa CheckedAllocArray
4250
end

0 commit comments

Comments
 (0)