Skip to content

Commit 5e51422

Browse files
committed
support adjoint(::Diagonal{<:CuVector{<:Complex}}) (e.g. mul! resulted in scalar-indexing) + test
1 parent 58123fa commit 5e51422

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

lib/cublas/linalg.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,8 @@ function LinearAlgebra.inv(D::Diagonal{T, <:CuArray{T}}) where {T}
339339
Diagonal(Di)
340340
end
341341

342+
LinearAlgebra.adjoint(D::Diagonal{T, <:CuVector{T}}) where T <: Complex = Diagonal(map(adjoint, D.diag))
343+
342344
LinearAlgebra.rdiv!(A::CuArray, D::Diagonal) = _rdiv!(A, A, D)
343345

344346
Base.:/(A::CuArray, D::Diagonal) = _rdiv!(similar(A, typeof(oneunit(eltype(A)) / oneunit(eltype(D)))), A, D)

test/libraries/cublas/extensions.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,11 +538,23 @@ k = 13
538538
mul!(d_XA, d_X, d_A)
539539
Array(d_XA) Diagonal(x) * A
540540

541+
XA = rand(elty,m,n)
542+
d_XA = CuArray(XA)
543+
d_X = Diagonal(d_x)
544+
mul!(d_XA, d_X', d_A)
545+
Array(d_XA) Diagonal(x)' * A
546+
541547
AY = rand(elty,m,n)
542548
d_AY = CuArray(AY)
543549
d_Y = Diagonal(d_y)
544550
mul!(d_AY, d_A, d_Y)
545551
Array(d_AY) A * Diagonal(y)
552+
553+
AY = rand(elty,m,n)
554+
d_AY = CuArray(AY)
555+
d_Y = Diagonal(d_y)
556+
mul!(d_AY, d_A, d_Y')
557+
Array(d_AY) A * Diagonal(y)'
546558

547559
YA = rand(elty,n,m)
548560
d_YA = CuArray(YA)

0 commit comments

Comments
 (0)