Skip to content

Commit df94a65

Browse files
authored
Compare Cholesky and BunchKaufman by properties (#54509)
1 parent c28a9de commit df94a65

File tree

6 files changed

+53
-14
lines changed

6 files changed

+53
-14
lines changed

stdlib/LinearAlgebra/src/bunchkaufman.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,17 @@ end
297297
Base.propertynames(B::BunchKaufman, private::Bool=false) =
298298
(:p, :P, :L, :U, :D, (private ? fieldnames(typeof(B)) : ())...)
299299

300+
function Base.:(==)(B1::BunchKaufman, B2::BunchKaufman)
301+
# check for the equality between properties instead of fields
302+
B1.p == B2.p || return false
303+
if B1.uplo == 'L'
304+
B1.L == B2.L || return false
305+
else
306+
B1.U == B2.U || return false
307+
end
308+
return (B1.D == B2.D)
309+
end
310+
300311
function getproperties!(B::BunchKaufman{T,<:StridedMatrix}) where {T<:BlasFloat}
301312
# NOTE: Unlike in the 'getproperty' function, in this function L/U and D are computed in place.
302313
if B.rook

stdlib/LinearAlgebra/src/cholesky.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,11 @@ end
546546
Base.propertynames(F::Cholesky, private::Bool=false) =
547547
(:U, :L, :UL, (private ? fieldnames(typeof(F)) : ())...)
548548

549+
function Base.:(==)(C1::Cholesky, C2::Cholesky)
550+
C1.uplo == C2.uplo || return false
551+
C1.uplo == 'L' ? (C1.L == C2.L) : (C1.U == C2.U)
552+
end
553+
549554
function getproperty(C::CholeskyPivoted{T}, d::Symbol) where {T}
550555
Cfactors = getfield(C, :factors)
551556
Cuplo = getfield(C, :uplo)
@@ -569,6 +574,11 @@ end
569574
Base.propertynames(F::CholeskyPivoted, private::Bool=false) =
570575
(:U, :L, :p, :P, (private ? fieldnames(typeof(F)) : ())...)
571576

577+
function Base.:(==)(C1::CholeskyPivoted, C2::CholeskyPivoted)
578+
(C1.uplo == C2.uplo && C1.p == C2.p) || return false
579+
C1.uplo == 'L' ? (C1.L == C2.L) : (C1.U == C2.U)
580+
end
581+
572582
issuccess(C::Union{Cholesky,CholeskyPivoted}) = C.info == 0
573583

574584
adjoint(C::Union{Cholesky,CholeskyPivoted}) = C

stdlib/LinearAlgebra/src/eigen.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -658,14 +658,19 @@ function show(io::IO, mime::MIME{Symbol("text/plain")}, F::Union{Eigen,Generaliz
658658
show(io, mime, F.vectors)
659659
end
660660

661-
function Base.hash(F::Eigen, h::UInt)
662-
return hash(F.values, hash(F.vectors, hash(Eigen, h)))
663-
end
664-
function Base.:(==)(A::Eigen, B::Eigen)
665-
return A.values == B.values && A.vectors == B.vectors
666-
end
667-
function Base.isequal(A::Eigen, B::Eigen)
668-
return isequal(A.values, B.values) && isequal(A.vectors, B.vectors)
661+
_equalcheck(f, Avalues, Avectors, Bvalues, Bvectors) = f(Avalues, Bvalues) && f(Avectors, Bvectors)
662+
for T in (Eigen, GeneralizedEigen)
663+
@eval begin
664+
function Base.hash(F::$T, h::UInt)
665+
return hash(F.values, hash(F.vectors, hash($T, h)))
666+
end
667+
function Base.:(==)(A::$T, B::$T)
668+
return _equalcheck(==, A..., B...)
669+
end
670+
function Base.isequal(A::$T, B::$T)
671+
return _equalcheck(isequal, A..., B...)
672+
end
673+
end
669674
end
670675

671676
# Conversion methods

stdlib/LinearAlgebra/test/cholesky.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ end
553553
M = Matrix{BigFloat}(undef, 2, 2)
554554
M[1,1] = M[2,2] = M[1+(uplo=='L'), 1+(uplo=='U')] = 3
555555
C = Cholesky(M, uplo, 0)
556+
@test C == C
556557
@test C.L == C.U'
557558
# parameters are arbitrary
558559
C = CholeskyPivoted(M, uplo, [1,2], 2, 0.0, 0)

stdlib/LinearAlgebra/test/eigen.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,22 @@ end
212212
end
213213

214214
@testset "equality of eigen factorizations" begin
215-
A = randn(3, 3)
216-
@test eigen(A) == eigen(A)
217-
@test hash(eigen(A)) == hash(eigen(A))
218-
@test isequal(eigen(A), eigen(A))
215+
A1 = Float32[1 0; 0 2]
216+
A2 = Float64[1 0; 0 2]
217+
EA1 = eigen(A1)
218+
EA2 = eigen(A2)
219+
@test EA1 == EA2
220+
@test hash(EA1) == hash(EA2)
221+
@test isequal(EA1, EA2)
222+
223+
# trivial RHS to ensure that values match exactly
224+
B1 = Float32[1 0; 0 1]
225+
B2 = Float64[1 0; 0 1]
226+
EA1B1 = eigen(A1, B1)
227+
EA2B2 = eigen(A2, B2)
228+
@test EA1B1 == EA2B2
229+
@test hash(EA1B1) == hash(EA2B2)
230+
@test isequal(EA1B1, EA2B2)
219231
end
220232

221233
@testset "Float16" begin

stdlib/LinearAlgebra/test/factorization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ using Test, LinearAlgebra
3737
return x isa AbstractArray{Float64} ? Float64.(Float32.(x)) : x
3838
end...)
3939

40-
@test F == G broken=!(f === eigen || f === qr)
41-
@test isequal(F, G) broken=!(f === eigen || f === qr)
40+
@test F == G broken=!(f === eigen || f === qr || f == bunchkaufman || f == cholesky || F isa CholeskyPivoted)
41+
@test isequal(F, G) broken=!(f === eigen || f === qr || f == bunchkaufman || f == cholesky || F isa CholeskyPivoted)
4242
@test hash(F) == hash(G)
4343
end
4444

0 commit comments

Comments
 (0)