Skip to content

Commit 861f7a2

Browse files
mbaumanKristofferC
authored andcommitted
fix #31674, error when storing nonzeros into structural zeros with .= (#31678)
Previously, broadcasted assignment (`.=`) would happily ignore all nonstructured portions of the destination, regardless of whether the broadcasted expression would actually evaluate to zero or not. This changes these in-place methods to use the same infrastructure that out-of-place broadcast uses to determine the result type. If we are unsure of the structural properties of the output, we fall back to the generic implementation, which will attempt to store into every single location of the destination -- including those structural zeros. Thus we now error in cases where we generate nonzeros in those locations. (cherry picked from commit 6bd3967)
1 parent f254c2e commit 861f7a2

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

stdlib/LinearAlgebra/src/structuredbroadcast.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType})
102102
end
103103

104104
function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
105+
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
105106
axs = axes(dest)
106107
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
107108
for i in axs[1]
@@ -111,6 +112,7 @@ function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
111112
end
112113

113114
function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
115+
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
114116
axs = axes(dest)
115117
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
116118
for i in axs[1]
@@ -129,18 +131,22 @@ function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
129131
end
130132

131133
function copyto!(dest::SymTridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
134+
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
132135
axs = axes(dest)
133136
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
134137
for i in axs[1]
135138
dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
136139
end
137140
for i = 1:size(dest, 1)-1
138-
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
141+
v = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
142+
v == Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) || throw(ArgumentError("broadcasted assignment breaks symmetry between locations ($i, $(i+1)) and ($(i+1), $i)"))
143+
dest.ev[i] = v
139144
end
140145
return dest
141146
end
142147

143148
function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
149+
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
144150
axs = axes(dest)
145151
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
146152
for i in axs[1]
@@ -154,6 +160,7 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
154160
end
155161

156162
function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
163+
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
157164
axs = axes(dest)
158165
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
159166
for j in axs[2]
@@ -165,6 +172,7 @@ function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
165172
end
166173

167174
function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
175+
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
168176
axs = axes(dest)
169177
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
170178
for j in axs[2]

stdlib/LinearAlgebra/test/structuredbroadcast.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,37 @@ end
5151
A = rand(N, N)
5252
sA = A + copy(A')
5353
D = Diagonal(rand(N))
54-
B = Bidiagonal(rand(N), rand(N - 1), :U)
54+
Bu = Bidiagonal(rand(N), rand(N - 1), :U)
55+
Bl = Bidiagonal(rand(N), rand(N - 1), :L)
5556
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
57+
= LowerTriangular(rand(N,N))
58+
= UpperTriangular(rand(N,N))
59+
5660
@test broadcast!(sin, copy(D), D) == Diagonal(sin.(D))
57-
@test broadcast!(sin, copy(B), B) == Bidiagonal(sin.(B), :U)
61+
@test broadcast!(sin, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
62+
@test broadcast!(sin, copy(Bl), Bl) == Bidiagonal(sin.(Bl), :L)
5863
@test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T))
64+
@test broadcast!(sin, copy(◣), ◣) == LowerTriangular(sin.(◣))
65+
@test broadcast!(sin, copy(◥), ◥) == UpperTriangular(sin.(◥))
5966
@test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A))
60-
@test broadcast!(*, copy(B), B, A) == Bidiagonal(broadcast(*, B, A), :U)
67+
@test broadcast!(*, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
68+
@test broadcast!(*, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
6169
@test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
70+
@test broadcast!(*, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
71+
@test broadcast!(*, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
72+
73+
@test_throws ArgumentError broadcast!(cos, copy(D), D) == Diagonal(sin.(D))
74+
@test_throws ArgumentError broadcast!(cos, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
75+
@test_throws ArgumentError broadcast!(cos, copy(Bl), Bl) == Bidiagonal(sin.(Bl), :L)
76+
@test_throws ArgumentError broadcast!(cos, copy(T), T) == Tridiagonal(sin.(T))
77+
@test_throws ArgumentError broadcast!(cos, copy(◣), ◣) == LowerTriangular(sin.(◣))
78+
@test_throws ArgumentError broadcast!(cos, copy(◥), ◥) == UpperTriangular(sin.(◥))
79+
@test_throws ArgumentError broadcast!(+, copy(D), D, A) == Diagonal(broadcast(*, D, A))
80+
@test_throws ArgumentError broadcast!(+, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
81+
@test_throws ArgumentError broadcast!(+, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
82+
@test_throws ArgumentError broadcast!(+, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
83+
@test_throws ArgumentError broadcast!(+, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
84+
@test_throws ArgumentError broadcast!(+, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
6285
end
6386

6487
@testset "map[!] over combinations of structured matrices" begin

0 commit comments

Comments
 (0)