Skip to content

Commit 5ea3ed5

Browse files
authored
Generalize 3-arg dot to HermOrSym (#1410)
1 parent 5d05175 commit 5ea3ed5

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

src/symmetric.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,9 @@ const SelfAdjoint = Union{SymTridiagonal{<:Real}, Symmetric{<:Real}, Hermitian}
230230
wrappertype(::Union{Symmetric, SymTridiagonal}) = Symmetric
231231
wrappertype(::Hermitian) = Hermitian
232232

233+
hswrapperop(::Symmetric) = symmetric
234+
hswrapperop(::Hermitian) = hermitian
235+
233236
nonhermitianwrappertype(::SymSymTri{<:Real}) = Symmetric
234237
nonhermitianwrappertype(::Hermitian{<:Real}) = Symmetric
235238
nonhermitianwrappertype(::Hermitian) = identity
@@ -738,27 +741,28 @@ function mul(A::AdjOrTrans{<:BlasFloat,<:StridedMatrix}, B::HermOrSym{<:BlasFloa
738741
convert(AbstractMatrix{T}, B))
739742
end
740743

741-
function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
744+
function dot(x::AbstractVector, A::HermOrSym, y::AbstractVector)
742745
require_one_based_indexing(x, y)
743746
n = length(x)
744747
(n == length(y) == size(A, 1)) || throw(DimensionMismatch())
745748
data = A.data
746-
r = dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
749+
s = dot(first(x), first(A), first(y))
750+
r = zero(s+s)
747751
iszero(n) && return r
748752
if A.uplo == 'U'
749753
@inbounds for j = 1:length(y)
750-
r += dot(x[j], real(data[j,j]), y[j])
754+
r += dot(x[j], hswrapperop(A)(data[j,j], :U), y[j])
751755
@simd for i = 1:j-1
752756
Aij = data[i,j]
753-
r += dot(x[i], Aij, y[j]) + dot(x[j], adjoint(Aij), y[i])
757+
r += dot(x[i], Aij, y[j]) + dot(x[j], _conjugation(A)(Aij), y[i])
754758
end
755759
end
756760
else # A.uplo == 'L'
757761
@inbounds for j = 1:length(y)
758-
r += dot(x[j], real(data[j,j]), y[j])
762+
r += dot(x[j], hswrapperop(A)(data[j,j], :L), y[j])
759763
@simd for i = j+1:length(y)
760764
Aij = data[i,j]
761-
r += dot(x[i], Aij, y[j]) + dot(x[j], adjoint(Aij), y[i])
765+
r += dot(x[i], Aij, y[j]) + dot(x[j], _conjugation(A)(Aij), y[i])
762766
end
763767
end
764768
end

test/symmetric.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,13 +582,21 @@ end
582582

583583
# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,
584584
# or any number type where conj(a)*conj(b) ≠ conj(a*b):
585-
@testset "dot Hermitian quaternion #52318" begin
586-
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + t' for i in 1:2]
585+
@testset "dot Hermitian quaternion" begin
586+
A, B = [(randn(Quaternion{Float64},4,4)) |> t -> t + t' for i in 1:2]
587587
@test A == Hermitian(A) && B == Hermitian(B)
588588
@test dot(A, B) dot(Hermitian(A), Hermitian(B))
589-
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + transpose(t) for i in 1:2]
589+
A, B = [(randn(Quaternion{Float64},4,4)) |> t -> t + transpose(t) for i in 1:2]
590590
@test A == Symmetric(A) && B == Symmetric(B)
591591
@test dot(A, B) dot(Symmetric(A), Symmetric(B))
592+
A = randn(Quaternion{Float64}, 4, 4)
593+
x = randn(Quaternion{Float64}, 4)
594+
y = randn(Quaternion{Float64}, 4)
595+
for t in (Symmetric, Hermitian), uplo in (:U, :L)
596+
M = t(A, uplo)
597+
N = Matrix(M)
598+
@test dot(x, M, y) dot(x, M*y) dot(x, N, y)
599+
end
592600
end
593601

594602
# let's make sure the analogous bug will not show up with kronecker products
@@ -610,6 +618,18 @@ end
610618
end
611619
end
612620

621+
@testset "3-arg dot with Symmetric/Hermitian matrix of matrices" begin
622+
for m in (Symmetric([randn(ComplexF64, 2, 2) for i in 1:2, j in 1:2]),
623+
Symmetric([randn(ComplexF64, 2, 2) for i in 1:2, j in 1:2], :L),
624+
Hermitian([randn(ComplexF64, 2, 2) for i in 1:2, j in 1:2]),
625+
Hermitian([randn(ComplexF64, 2, 2) for i in 1:2, j in 1:2], :L)
626+
)
627+
x = [randn(ComplexF64, 2) for i in 1:2]
628+
y = [randn(ComplexF64, 2) for i in 1:2]
629+
@test dot(x, m, y) dot(x, m*y) dot(x, Matrix(m), y)
630+
end
631+
end
632+
613633
#Issue #7647: test xsyevr, xheevr, xstevr drivers.
614634
@testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in
615635
(Symmetric(diagm(0 => 1.0:3.0)),

0 commit comments

Comments
 (0)