Skip to content

Commit 07725da

Browse files
authored
Branch on Bool alpha in scaling mul! (#1286)
By dealing with the `alpha == 0` case separately, we ensure that if `alpha::Bool`, it must be `true`. This reduces the branches in `_lscale_add` from 4 to 2 in the common case of 3-argument `mul!`. This leads to a latency reduction, as each branch has to compile a different broadcast expression, and we currently compile four but use only one. Primarily, this PR leads to a reduction in allocations. ```julia julia> using LinearAlgebra julia> v = 1:4; w = similar(v); julia> @time mul!(w, 1, v); 0.171120 seconds (1.04 M allocations: 52.799 MiB, 99.98% compilation time) # nightly 0.163178 seconds (702.63 k allocations: 35.533 MiB, 99.98% compilation time) # this PR ``` Something similar usually doesn't lead to a big gain in the `_rscale_add` method, as `s * alpha` often has the same type as `s`, and therefore the branches on `alpha` compile the same code.
1 parent 61e444d commit 07725da

File tree

2 files changed

+58
-23
lines changed

2 files changed

+58
-23
lines changed

src/generic.jl

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -202,24 +202,35 @@ _lscale_add!(C::StridedArray, s::Number, X::StridedArray, alpha::Number, beta::N
202202
generic_mul!(C, s, X, alpha, beta)
203203
@inline function _lscale_add!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number)
204204
if axes(C) == axes(X)
205-
if isone(alpha)
206-
if iszero(beta)
207-
@. C = s * X
208-
else
209-
@. C = s * X + C * beta
210-
end
211-
else
212-
if iszero(beta)
213-
@. C = s * X * alpha
214-
else
215-
@. C = s * X * alpha + C * beta
216-
end
217-
end
205+
iszero(alpha) && return _rmul_or_fill!(C, beta)
206+
_lscale_add_nonzeroalpha!(C, s, X, alpha, beta)
218207
else
219208
generic_mul!(C, s, X, alpha, beta)
220209
end
221210
return C
222211
end
212+
function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number)
213+
if isone(alpha)
214+
# since alpha is unused, we might as well set to `true` to avoid recompiling
215+
# the branch if an `alpha` of a different type is used
216+
_lscale_add_nonzeroalpha!(C, s, X, true, beta)
217+
else
218+
if iszero(beta)
219+
@. C = s * X * alpha
220+
else
221+
@. C = s * X * alpha + C * beta
222+
end
223+
end
224+
C
225+
end
226+
function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Bool, beta::Number)
227+
if iszero(beta)
228+
@. C = s * X
229+
else
230+
@. C = s * X + C * beta
231+
end
232+
C
233+
end
223234
@inline mul!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number) =
224235
_rscale_add!(C, X, s, alpha, beta)
225236

@@ -228,24 +239,26 @@ _rscale_add!(C::StridedArray, X::StridedArray, s::Number, alpha::Number, beta::N
228239
@inline function _rscale_add!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number)
229240
if axes(C) == axes(X)
230241
if isone(alpha)
231-
if iszero(beta)
232-
@. C = X * s
233-
else
234-
@. C = X * s + C * beta
235-
end
242+
# since alpha is unused, we might as well ignore it in this branch.
243+
# This avoids recompiling the branch if an `alpha` of a different type is used
244+
_rscale_add_alphaisone!(C, X, s, beta)
236245
else
237246
s_alpha = s * alpha
238-
if iszero(beta)
239-
@. C = X * s_alpha
240-
else
241-
@. C = X * s_alpha + C * beta
242-
end
247+
_rscale_add_alphaisone!(C, X, s_alpha, beta)
243248
end
244249
else
245250
generic_mul!(C, X, s, alpha, beta)
246251
end
247252
return C
248253
end
254+
function _rscale_add_alphaisone!(C::AbstractArray, X::AbstractArray, s::Number, beta::Number)
255+
if iszero(beta)
256+
@. C = X * s
257+
else
258+
@. C = X * s + C * beta
259+
end
260+
C
261+
end
249262

250263
# For better performance when input and output are the same array
251264
# See https://github.com/JuliaLang/julia/issues/8415#issuecomment-56608729

test/generic.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,4 +857,26 @@ end
857857
end
858858
end
859859

860+
@testset "scaling mul" begin
861+
v = 1:4
862+
w = similar(v)
863+
@test mul!(w, 2, v) == 2v
864+
@test mul!(w, v, 2) == 2v
865+
# 5-arg equivalent to the 3-arg method, but with non-Bool alpha
866+
@test mul!(copy!(similar(v), v), 2, v, 1, 0) == 2v
867+
@test mul!(copy!(similar(v), v), v, 2, 1, 0) == 2v
868+
# 5-arg tests with alpha::Bool
869+
@test mul!(copy!(similar(v), v), 2, v, true, 1) == 3v
870+
@test mul!(copy!(similar(v), v), v, 2, true, 1) == 3v
871+
@test mul!(copy!(similar(v), v), 2, v, false, 2) == 2v
872+
@test mul!(copy!(similar(v), v), v, 2, false, 2) == 2v
873+
# 5-arg tests
874+
@test mul!(copy!(similar(v), v), 2, v, 1, 3) == 5v
875+
@test mul!(copy!(similar(v), v), v, 2, 1, 3) == 5v
876+
@test mul!(copy!(similar(v), v), 2, v, 2, 3) == 7v
877+
@test mul!(copy!(similar(v), v), v, 2, 2, 3) == 7v
878+
@test mul!(copy!(similar(v), v), 2, v, 2, 0) == 4v
879+
@test mul!(copy!(similar(v), v), v, 2, 2, 0) == 4v
880+
end
881+
860882
end # module TestGeneric

0 commit comments

Comments
 (0)