Skip to content

Commit 2771f0a

Browse files
committed
Revert "Provide 5-arg mul! JuliaGPU#634"
This reverts commit 5542a9e, reversing changes made to 52c4664.
1 parent fea5135 commit 2771f0a

File tree

2 files changed

+32
-99
lines changed

2 files changed

+32
-99
lines changed

src/blas/linalg.jl

Lines changed: 32 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -72,31 +72,10 @@ function gemv_wrapper!(y::CuVector{T}, tA::Char, A::CuMatrix{T}, x::CuVector{T},
7272
gemv!(tA, alpha, A, x, beta, y)
7373
end
7474

75-
function promote_alpha_beta(a, b, ::Type{T}) where {T}
76-
a_prom, b_prom = promote(a, b, zero(T))
77-
a_prom, b_prom
78-
end
79-
80-
LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
81-
gemv_wrapper!(Y, 'N', A, B, promote_alpha_beta(a, b, T)...)
82-
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
83-
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
84-
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasReal =
85-
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
86-
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasComplex =
87-
gemv_wrapper!(Y, 'C', A.parent, B, promote_alpha_beta(a, b, T)...)
88-
89-
# Fix Julia 1.3.0 ambiguities... they're fixed in 1.3.1 thanks to https://github.com/JuliaLang/julia/pull/33743
90-
@static if VERSION === v"1.3.0"
91-
LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
92-
gemv_wrapper!(Y, 'N', A, B, promote_alpha_beta(a, b, T)...)
93-
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
94-
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
95-
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
96-
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
97-
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
98-
gemv_wrapper!(Y, 'C', A.parent, B, promote_alpha_beta(a, b, T)...)
99-
end
75+
LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'N', A, B)
76+
LinearAlgebra.lmul!(Y::CuVector{T}, A::Transpose{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'T', A.parent, B)
77+
LinearAlgebra.lmul!(Y::CuVector{T}, A::Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'T', A.parent, B)
78+
LinearAlgebra.lmul!(Y::CuVector{T}, A::Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasComplex = gemv_wrapper!(Y, 'C', A.parent, B)
10079

10180
# TRSV
10281

@@ -177,66 +156,34 @@ function gemm_wrapper!(C::CuVecOrMat{T}, tA::Char, tB::Char,
177156
end
178157

179158
# Mutating
180-
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}, a::Number, b::Number) where T<:CublasFloat =
181-
gemm_wrapper!(C, 'N', 'N', A, B, promote_alpha_beta(a, b, T)...)
182-
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Number, b::Number) where T<:CublasFloat =
183-
gemm_wrapper!(C, 'T', 'N', parent(trA), B, promote_alpha_beta(a, b, T)...)
184-
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasFloat =
185-
gemm_wrapper!(C, 'N', 'T', A, parent(trB), promote_alpha_beta(a, b, T)...)
186-
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasFloat =
187-
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB), promote_alpha_beta(a, b, T)...)
188-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Number, b::Number) where T<:CublasReal =
189-
gemm_wrapper!(C, 'T', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
190-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Number, b::Number) where T<:CublasComplex =
191-
gemm_wrapper!(C, 'C', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
192-
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
193-
gemm_wrapper!(C, 'N', 'T', A, parent(adjB), promote_alpha_beta(a, b, T)...)
194-
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasComplex =
195-
gemm_wrapper!(C, 'N', 'C', A, parent(adjB), promote_alpha_beta(a, b, T)...)
196-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
197-
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
198-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasComplex =
199-
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
200-
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{T, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
201-
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
202-
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasComplex =
203-
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
204-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{T, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
205-
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
206-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T <: CublasComplex =
207-
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
208-
209-
# Fix Julia 1.3.0 ambiguities... they're fixed in 1.3.1 thanks to https://github.com/JuliaLang/julia/pull/33743
210-
@static if VERSION === v"1.3.0"
211-
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
212-
gemm_wrapper!(C, 'N', 'N', A, B, promote_alpha_beta(a, b, T)...)
213-
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
214-
gemm_wrapper!(C, 'T', 'N', parent(trA), B, promote_alpha_beta(a, b, T)...)
215-
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
216-
gemm_wrapper!(C, 'N', 'T', A, parent(trB), promote_alpha_beta(a, b, T)...)
217-
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
218-
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB), promote_alpha_beta(a, b, T)...)
219-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
220-
gemm_wrapper!(C, 'T', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
221-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
222-
gemm_wrapper!(C, 'C', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
223-
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
224-
gemm_wrapper!(C, 'N', 'T', A, parent(adjB), promote_alpha_beta(a, b, T)...)
225-
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
226-
gemm_wrapper!(C, 'N', 'C', A, parent(adjB), promote_alpha_beta(a, b, T)...)
227-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
228-
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
229-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
230-
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
231-
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{T, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
232-
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
233-
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
234-
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
235-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{T, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
236-
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
237-
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T <: CublasComplex =
238-
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
239-
end
159+
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}) where T<:CublasFloat = gemm_wrapper!(C, 'N', 'N', A, B)
160+
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}) where T<:CublasFloat =
161+
gemm_wrapper!(C, 'T', 'N', parent(trA), B)
162+
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
163+
gemm_wrapper!(C, 'N', 'T', A, parent(trB))
164+
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
165+
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB))
166+
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}) where T<:CublasReal =
167+
gemm_wrapper!(C, 'T', 'N', parent(adjA), B)
168+
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}) where T<:CublasFloat =
169+
gemm_wrapper!(C, 'C', 'N', parent(adjA), B)
170+
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasReal =
171+
gemm_wrapper!(C, 'N', 'T', A, parent(adjB))
172+
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
173+
gemm_wrapper!(C, 'N', 'C', A, parent(adjB))
174+
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, CuMatrix{T}}) where T<:CublasReal =
175+
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB))
176+
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
177+
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB))
178+
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{T, <:CuMatrix{T}}) where T<:CublasReal =
179+
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB))
180+
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
181+
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB))
182+
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{T, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T<:CublasReal =
183+
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB))
184+
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T <: CublasFloat =
185+
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB))
186+
240187

241188
# TRSM
242189

test/blas.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,6 @@ end # level 1 testset
6565
dA = CuArray(A)
6666
@test_throws DimensionMismatch mul!(dy, dA, dx)
6767
end
68-
@testset "mul! y = $f(A) * x * $Ts(a) + y * $Ts(b)" for f in (identity, transpose, adjoint), Ts in (Int, elty)
69-
y, A, x = rand(elty, 5), rand(elty, 5, 5), rand(elty, 5)
70-
dy, dA, dx = CuArray(y), CuArray(A), CuArray(x)
71-
mul!(dy, f(dA), dx, Ts(1), Ts(1))
72-
mul!(y, f(A), x, elty(1), elty(2)) # elty can be replaced with `Ts` on Julia 1.4
73-
@test Array(dy) y
74-
end
7568
@testset "banded methods" begin
7669
# bands
7770
ku = 2
@@ -406,13 +399,6 @@ end # level 1 testset
406399
end
407400
end
408401
@testset "Level 3" begin
409-
@testset "mul! C = $f(A) * $g(B) * $Ts(a) + C * $Ts(b)" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), Ts in (Int, elty)
410-
C, A, B = rand(elty, 5, 5), rand(elty, 5, 5), rand(elty, 5, 5)
411-
dC, dA, dB = CuArray(C), CuArray(A), CuArray(B)
412-
mul!(dC, f(dA), g(dB), Ts(1), Ts(2))
413-
mul!(C, f(A), g(B), elty(1), elty(2)) # elty can be replaced with `Ts` on Julia 1.4
414-
@test Array(dC) C
415-
end
416402
A = rand(elty,m,k)
417403
B = rand(elty,k,n)
418404
Bbad = rand(elty,k+1,n+1)

0 commit comments

Comments
 (0)