Skip to content

Commit b5a8bb0

Browse files
authored
Backport to release-1.12 (#1420)
- [x] #1305 - [x] #1419
2 parents ad868c7 + 7f4b476 commit b5a8bb0

File tree

3 files changed

+163
-41
lines changed

3 files changed

+163
-41
lines changed

src/symmetriceigen.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
# preserve HermOrSym wrapper
44
# Call `copytrito!` instead of `copy_similar` to only copy the matching triangular half
5-
eigencopy_oftype(A::Hermitian, S) = Hermitian(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
6-
eigencopy_oftype(A::Symmetric, S) = Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
7-
eigencopy_oftype(A::Symmetric{<:Complex}, S) = copyto!(similar(parent(A), S), A)
5+
eigencopy_oftype(A::Hermitian, ::Type{S}) where S = Hermitian(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
6+
eigencopy_oftype(A::Symmetric, ::Type{S}) where S = Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
7+
eigencopy_oftype(A::Symmetric{<:Complex}, ::Type{S}) where S = copyto!(similar(parent(A), S), A)
88

99
"""
1010
default_eigen_alg(A)

src/triangular.jl

Lines changed: 83 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
230230
Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) =
231231
_shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false
232232

233-
@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} =
234-
_shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
235-
@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) =
236-
_shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j)
233+
@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T}
234+
if _shouldforwardindex(A, i, j)
235+
A.data[i,j]
236+
else
237+
@boundscheck checkbounds(A, i, j)
238+
ifelse(i == j, oneunit(T), zero(T))
239+
end
240+
end
241+
@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int)
242+
if _shouldforwardindex(A, i, j)
243+
A.data[i,j]
244+
else
245+
@boundscheck checkbounds(A, i, j)
246+
@inbounds diagzero(A,i,j)
247+
end
248+
end
237249

238250
_shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0
239251
_shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0
@@ -242,62 +254,97 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
242254

243255
# these specialized getindex methods enable constant-propagation of the band
244256
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T}
245-
_shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
257+
if _shouldforwardindex(A, b)
258+
A.data[b]
259+
else
260+
@boundscheck checkbounds(A, b)
261+
ifelse(b.band == 0, oneunit(T), zero(T))
262+
end
246263
end
247264
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex)
248-
_shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b)
265+
if _shouldforwardindex(A, b)
266+
A.data[b]
267+
else
268+
@boundscheck checkbounds(A, b)
269+
@inbounds diagzero(A, b)
270+
end
249271
end
250272

251-
_zero_triangular_half_str(T::Type) = T <: UpperOrUnitUpperTriangular ? "lower" : "upper"
252-
253-
@noinline function throw_nonzeroerror(T::DataType, @nospecialize(x), i, j)
254-
Ts = _zero_triangular_half_str(T)
255-
Tn = nameof(T)
273+
@noinline function throw_nonzeroerror(Tn::Symbol, @nospecialize(x), i, j)
274+
zero_half = Tn in (:UpperTriangular, :UnitUpperTriangular) ? "lower" : "upper"
275+
nstr = Tn === :UpperTriangular ? "n" : ""
256276
throw(ArgumentError(
257-
lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)"))
277+
LazyString(
278+
lazy"cannot set index ($i, $j) in the $zero_half triangular part ",
279+
lazy"of a$nstr $Tn matrix to a nonzero value ($x)")
280+
)
281+
)
258282
end
259-
@noinline function throw_nononeerror(T::DataType, @nospecialize(x), i, j)
260-
Tn = nameof(T)
283+
@noinline function throw_nonuniterror(Tn::Symbol, @nospecialize(x), i, j)
261284
throw(ArgumentError(
262-
lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)"))
285+
lazy"cannot set index ($i, $j) on the diagonal of a $Tn matrix to a non-unit value ($x)"))
263286
end
264287

265288
@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
266-
if i > j
267-
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
268-
else
289+
if _shouldforwardindex(A, i, j)
269290
A.data[i,j] = x
291+
else
292+
@boundscheck checkbounds(A, i, j)
293+
# the value must be convertible to the eltype for setindex! to be meaningful
294+
# however, the converted value is unused, and the compiler is free to remove
295+
# the conversion if the call is guaranteed to succeed
296+
convert(eltype(A), x)
297+
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
270298
end
271299
return A
272300
end
273301

274302
@propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer)
275-
if i > j
276-
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
277-
elseif i == j
278-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
279-
else
303+
if _shouldforwardindex(A, i, j)
280304
A.data[i,j] = x
305+
else
306+
@boundscheck checkbounds(A, i, j)
307+
# the value must be convertible to the eltype for setindex! to be meaningful
308+
# however, the converted value is unused, and the compiler is free to remove
309+
# the conversion if the call is guaranteed to succeed
310+
convert(eltype(A), x)
311+
if i == j # diagonal
312+
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
313+
else
314+
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
315+
end
281316
end
282317
return A
283318
end
284319

285320
@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
286-
if i < j
287-
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
288-
else
321+
if _shouldforwardindex(A, i, j)
289322
A.data[i,j] = x
323+
else
324+
@boundscheck checkbounds(A, i, j)
325+
# the value must be convertible to the eltype for setindex! to be meaningful
326+
# however, the converted value is unused, and the compiler is free to remove
327+
# the conversion if the call is guaranteed to succeed
328+
convert(eltype(A), x)
329+
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
290330
end
291331
return A
292332
end
293333

294334
@propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer)
295-
if i < j
296-
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
297-
elseif i == j
298-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
299-
else
335+
if _shouldforwardindex(A, i, j)
300336
A.data[i,j] = x
337+
else
338+
@boundscheck checkbounds(A, i, j)
339+
# the value must be convertible to the eltype for setindex! to be meaningful
340+
# however, the converted value is unused, and the compiler is free to remove
341+
# the conversion if the call is guaranteed to succeed
342+
convert(eltype(A), x)
343+
if i == j # diagonal
344+
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
345+
else
346+
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
347+
end
301348
end
302349
return A
303350
end
@@ -542,7 +589,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
542589
@eval @inline function _copyto!(A::$UT, B::$T)
543590
for dind in diagind(A, IndexStyle(A))
544591
if A[dind] != B[dind]
545-
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
592+
throw_nonuniterror(nameof(typeof(A)), B[dind], Tuple(dind)...)
546593
end
547594
end
548595
_copyto!($T(parent(A)), B)
@@ -696,7 +743,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
696743
checksize1(A, B)
697744
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
698745
for j in axes(B.data,2)
699-
@inbounds _modify!(_add, c, A, (j,j))
746+
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
700747
for i in firstindex(B.data,1):(j - 1)
701748
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
702749
end
@@ -707,7 +754,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
707754
checksize1(A, B)
708755
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
709756
for j in axes(B.data,2)
710-
@inbounds _modify!(_add, c, A, (j,j))
757+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
711758
for i in firstindex(B.data,1):(j - 1)
712759
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
713760
end
@@ -738,7 +785,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
738785
checksize1(A, B)
739786
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
740787
for j in axes(B.data,2)
741-
@inbounds _modify!(_add, c, A, (j,j))
788+
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
742789
for i in (j + 1):lastindex(B.data,1)
743790
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
744791
end
@@ -749,7 +796,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
749796
checksize1(A, B)
750797
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
751798
for j in axes(B.data,2)
752-
@inbounds _modify!(_add, c, A, (j,j))
799+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
753800
for i in (j + 1):lastindex(B.data,1)
754801
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
755802
end

test/triangular.jl

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,11 +644,11 @@ end
644644
@testset "error message" begin
645645
A = UpperTriangular(Ap)
646646
B = UpperTriangular(Bp)
647-
@test_throws "cannot set index in the lower triangular part" copyto!(A, B)
647+
@test_throws "cannot set index (3, 1) in the lower triangular part" copyto!(A, B)
648648

649649
A = LowerTriangular(Ap)
650650
B = LowerTriangular(Bp)
651-
@test_throws "cannot set index in the upper triangular part" copyto!(A, B)
651+
@test_throws "cannot set index (1, 2) in the upper triangular part" copyto!(A, B)
652652
end
653653
end
654654

@@ -944,6 +944,81 @@ end
944944
@test 2\U == 2\M
945945
@test U*2 == M*2
946946
@test 2*U == 2*M
947+
948+
U2 = copy(U)
949+
@test rmul!(U, 1) == U2
950+
@test lmul!(1, U) == U2
951+
end
952+
953+
@testset "indexing checks" begin
954+
P = [1 2; 3 4]
955+
@testset "getindex" begin
956+
U = UnitUpperTriangular(P)
957+
@test_throws BoundsError U[0,0]
958+
@test_throws BoundsError U[1,0]
959+
@test_throws BoundsError U[BandIndex(0,0)]
960+
@test_throws BoundsError U[BandIndex(-1,0)]
961+
962+
U = UpperTriangular(P)
963+
@test_throws BoundsError U[1,0]
964+
@test_throws BoundsError U[BandIndex(-1,0)]
965+
966+
L = UnitLowerTriangular(P)
967+
@test_throws BoundsError L[0,0]
968+
@test_throws BoundsError L[0,1]
969+
@test_throws BoundsError U[BandIndex(0,0)]
970+
@test_throws BoundsError U[BandIndex(1,0)]
971+
972+
L = LowerTriangular(P)
973+
@test_throws BoundsError L[0,1]
974+
@test_throws BoundsError L[BandIndex(1,0)]
975+
end
976+
@testset "setindex!" begin
977+
A = SizedArrays.SizedArray{(2,2)}(P)
978+
M = fill(A, 2, 2)
979+
U = UnitUpperTriangular(M)
980+
@test_throws "Cannot `convert` an object of type $Int" U[1,1] = 1
981+
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitUpperTriangular matrix to a non-unit value"
982+
@test_throws non_unit_msg U[1,1] = A
983+
L = UnitLowerTriangular(M)
984+
@test_throws "Cannot `convert` an object of type $Int" L[1,1] = 1
985+
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitLowerTriangular matrix to a non-unit value"
986+
@test_throws non_unit_msg L[1,1] = A
987+
988+
for UT in (UnitUpperTriangular, UpperTriangular)
989+
U = UT(M)
990+
@test_throws "Cannot `convert` an object of type $Int" U[2,1] = 0
991+
end
992+
for LT in (UnitLowerTriangular, LowerTriangular)
993+
L = LT(M)
994+
@test_throws "Cannot `convert` an object of type $Int" L[1,2] = 0
995+
end
996+
997+
U = UnitUpperTriangular(P)
998+
@test_throws BoundsError U[0,0] = 1
999+
@test_throws BoundsError U[1,0] = 0
1000+
1001+
U = UpperTriangular(P)
1002+
@test_throws BoundsError U[1,0] = 0
1003+
1004+
L = UnitLowerTriangular(P)
1005+
@test_throws BoundsError L[0,0] = 1
1006+
@test_throws BoundsError L[0,1] = 0
1007+
1008+
L = LowerTriangular(P)
1009+
@test_throws BoundsError L[0,1] = 0
1010+
end
1011+
end
1012+
1013+
@testset "unit triangular l/rdiv!" begin
1014+
A = rand(3,3)
1015+
@testset for (UT,T) in ((UnitUpperTriangular, UpperTriangular),
1016+
(UnitLowerTriangular, LowerTriangular))
1017+
UnitTri = UT(A)
1018+
Tri = T(LinearAlgebra.full(UnitTri))
1019+
@test 2 \ UnitTri 2 \ Tri
1020+
@test UnitTri / 2 Tri / 2
1021+
end
9471022
end
9481023

9491024
end # module TestTriangular

0 commit comments

Comments
 (0)