Skip to content

Commit 44574b8

Browse files
authored
Towards more general truncation and slicing (#158)
1 parent 21eb0e9 commit 44574b8

File tree

6 files changed

+104
-47
lines changed

6 files changed

+104
-47
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.21"
4+
version = "0.7.22"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
6868
end
6969
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)
7070

71+
# TODO: Move this to a `BlockArraysExtensions` library.
72+
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockIndices)
73+
# TODO: Is this a good definition? It ignores `indices.indices`.
74+
return a[indices.blocks]
75+
end
76+
7177
# Generalization of to `BlockArrays._blockslice`:
7278
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.6.3/src/views.jl#L13-L14
7379
# Used by `BlockArrays.unblock`, which is used in `to_indices`

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ end
314314
# `Base.getindex(a::Block, b...)`.
315315
_getindex(a::Block{N}, b::Vararg{Any,N}) where {N} = GenericBlockIndex(a, b)
316316
_getindex(a::Block{N}, b::Vararg{Integer,N}) where {N} = a[b...]
317+
_getindex(a::Block{N}, b::Vararg{AbstractUnitRange{<:Integer},N}) where {N} = a[b...]
318+
_getindex(a::Block{N}, b::Vararg{AbstractVector,N}) where {N} = BlockIndexVector(a, b)
317319
# Fix ambiguity.
318320
_getindex(a::Block{0}) = a[]
319321

@@ -366,13 +368,21 @@ function blockedunitrange_getindices(
366368
a::AbstractBlockedUnitRange,
367369
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
368370
)
369-
return mortar(map(b -> a[b], blocks(indices)))
371+
blks = map(b -> a[b], blocks(indices))
372+
# Preserve any extra structure in the axes, like a
373+
# Kronecker structure, symmetry sectors, etc.
374+
ax = mortar_axis(map(b -> axis(a[b]), blocks(indices)))
375+
return mortar(blks, (ax,))
370376
end
371377
function blockedunitrange_getindices(
372378
a::AbstractBlockedUnitRange,
373379
indices::BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
374380
)
375-
return mortar(map(b -> a[b], blocks(indices)))
381+
blks = map(b -> a[b], blocks(indices))
382+
# Preserve any extra structure in the axes, like a
383+
# Kronecker structure, symmetry sectors, etc.
384+
ax = mortar_axis(map(b -> axis(a[b]), blocks(indices)))
385+
return mortar(blks, (ax,))
376386
end
377387

378388
# This is a specialization of `BlockArrays.unblock`:

src/abstractblocksparsearray/linearalgebra.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, norm, tr
1+
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, diag, norm, tr
22

33
# Like: https://github.com/JuliaLang/julia/blob/v1.11.1/stdlib/LinearAlgebra/src/transpose.jl#L184
44
# but also takes the dual of the axes.
@@ -33,6 +33,24 @@ function LinearAlgebra.tr(a::AnyAbstractBlockSparseMatrix)
3333
return tr_a
3434
end
3535

36+
# TODO: Define in DiagonalArrays.jl.
37+
function diagaxis(a::AbstractArray)
38+
LinearAlgebra.checksquare(a)
39+
return axes(a, 1)
40+
end
41+
function LinearAlgebra.diag(a::AnyAbstractBlockSparseMatrix)
42+
# TODO: Add `checkblocksquare` to also check it is square blockwise.
43+
LinearAlgebra.checksquare(a)
44+
diagaxes = map(blockdiagindices(a)) do I
45+
return diagaxis(@view(a[I]))
46+
end
47+
r = blockrange(diagaxes)
48+
stored_blocks = Dict((
49+
Tuple(I)[1] => diag(@view!(a[I])) for I in eachstoredblockdiagindex(a)
50+
))
51+
return blocksparse(stored_blocks, (r,))
52+
end
53+
3654
# TODO: Define `SparseArraysBase.isdiag`, define as
3755
# `isdiag(blocks(a))`.
3856
function blockisdiag(a::AbstractArray)

src/factorizations/truncation.jl

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc!
2-
3-
function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T}
4-
D = BlockSparseVector{T}(undef, axes(A, 1))
5-
for I in eachblockstoredindex(A)
6-
if ==(Int.(Tuple(I))...)
7-
D[Tuple(I)[1]] = diagview(A[I])
8-
end
9-
end
10-
return D
11-
end
1+
using MatrixAlgebraKit:
2+
TruncationStrategy,
3+
diagview,
4+
eig_trunc!,
5+
eigh_trunc!,
6+
findtruncated,
7+
svd_trunc!,
8+
truncate!
129

1310
"""
1411
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)
@@ -27,7 +24,7 @@ function MatrixAlgebraKit.truncate!(
2724
strategy::TruncationStrategy,
2825
)
2926
# TODO assert blockdiagonal
30-
return MatrixAlgebraKit.truncate!(
27+
return truncate!(
3128
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
3229
)
3330
end
@@ -38,9 +35,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
3835
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
3936
strategy::TruncationStrategy,
4037
)
41-
return MatrixAlgebraKit.truncate!(
42-
$f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy)
43-
)
38+
return truncate!($f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy))
4439
end
4540
end
4641
end
@@ -50,18 +45,30 @@ end
5045
function MatrixAlgebraKit.findtruncated(
5146
values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy
5247
)
53-
ind = MatrixAlgebraKit.findtruncated(values, strategy.strategy)
48+
ind = findtruncated(Vector(values), strategy.strategy)
5449
indexmask = falses(length(values))
5550
indexmask[ind] .= true
56-
return indexmask
51+
return to_truncated_indices(values, indexmask)
52+
end
53+
54+
# Allow customizing the indices output by `findtruncated`
55+
# based on the type of `values`, for example to preserve
56+
# a block or Kronecker structure.
57+
to_truncated_indices(values, I) = I
58+
function to_truncated_indices(values::AbstractBlockVector, I::AbstractVector{Bool})
59+
I′ = BlockedVector(I, blocklengths(axis(values)))
60+
blocks = map(BlockRange(values)) do b
61+
return _getindex(b, to_truncated_indices(values[b], I′[b]))
62+
end
63+
return blocks
5764
end
5865

5966
function MatrixAlgebraKit.truncate!(
6067
::typeof(svd_trunc!),
6168
(U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix},
6269
strategy::BlockPermutedDiagonalTruncationStrategy,
6370
)
64-
I = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
71+
I = findtruncated(diag(S), strategy)
6572
return (U[:, I], S[I, I], Vᴴ[I, :])
6673
end
6774
for f in [:eig_trunc!, :eigh_trunc!]
@@ -71,7 +78,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
7178
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
7279
strategy::BlockPermutedDiagonalTruncationStrategy,
7380
)
74-
I = MatrixAlgebraKit.findtruncated(diagview(D), strategy)
81+
I = findtruncated(diag(D), strategy)
7582
return (D[I, I], V[:, I])
7683
end
7784
end

test/test_factorizations.jl

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -146,20 +146,23 @@ test_params = Iterators.product(blockszs, eltypes)
146146
@test test_svd(a, usv_empty)
147147

148148
# test blockdiagonal
149+
rng = StableRNG(123)
149150
for i in LinearAlgebra.diagind(blocks(a))
150151
I = CartesianIndices(blocks(a))[i]
151-
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
152+
a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i]))
152153
end
153154
usv = svd_compact(a)
154155
@test test_svd(a, usv)
155156

156-
perm = Random.randperm(length(m))
157+
rng = StableRNG(123)
158+
perm = Random.randperm(rng, length(m))
157159
b = a[Block.(perm), Block.(1:length(n))]
158160
usv = svd_compact(b)
159161
@test test_svd(b, usv)
160162

161163
# test permuted blockdiagonal with missing row/col
162-
I_removed = rand(eachblockstoredindex(b))
164+
rng = StableRNG(123)
165+
I_removed = rand(rng, eachblockstoredindex(b))
163166
c = copy(b)
164167
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
165168
usv = svd_compact(c)
@@ -176,20 +179,23 @@ end
176179
@test test_svd(a, usv_empty; full=true)
177180

178181
# test blockdiagonal
182+
rng = StableRNG(123)
179183
for i in LinearAlgebra.diagind(blocks(a))
180184
I = CartesianIndices(blocks(a))[i]
181-
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
185+
a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i]))
182186
end
183187
usv = svd_full(a)
184188
@test test_svd(a, usv; full=true)
185189

186-
perm = Random.randperm(length(m))
190+
rng = StableRNG(123)
191+
perm = Random.randperm(rng, length(m))
187192
b = a[Block.(perm), Block.(1:length(n))]
188193
usv = svd_full(b)
189194
@test test_svd(b, usv; full=true)
190195

191196
# test permuted blockdiagonal with missing row/col
192-
I_removed = rand(eachblockstoredindex(b))
197+
rng = StableRNG(123)
198+
I_removed = rand(rng, eachblockstoredindex(b))
193199
c = copy(b)
194200
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
195201
usv = svd_full(c)
@@ -203,9 +209,10 @@ end
203209
a = BlockSparseArray{T}(undef, m, n)
204210

205211
# test blockdiagonal
212+
rng = StableRNG(123)
206213
for i in LinearAlgebra.diagind(blocks(a))
207214
I = CartesianIndices(blocks(a))[i]
208-
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
215+
a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i]))
209216
end
210217

211218
minmn = min(size(a)...)
@@ -236,7 +243,8 @@ end
236243
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
237244

238245
# test permuted blockdiagonal
239-
perm = Random.randperm(length(m))
246+
rng = StableRNG(123)
247+
perm = Random.randperm(rng, length(m))
240248
b = a[Block.(perm), Block.(1:length(n))]
241249
for trunc in (truncrank(r), trunctol(atol))
242250
U1, S1, V1ᴴ = svd_trunc(b; trunc)
@@ -270,8 +278,9 @@ end
270278
@testset "qr_compact (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
271279
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
272280
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
273-
A[Block(1, 1)] = randn(T, i, k)
274-
A[Block(2, 2)] = randn(T, j, l)
281+
rng = StableRNG(123)
282+
A[Block(1, 1)] = randn(rng, T, i, k)
283+
A[Block(2, 2)] = randn(rng, T, j, l)
275284
Q, R = qr_compact(A)
276285
@test Matrix(Q'Q) LinearAlgebra.I
277286
@test A Q * R
@@ -281,8 +290,9 @@ end
281290
@testset "qr_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
282291
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
283292
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
284-
A[Block(1, 1)] = randn(T, i, k)
285-
A[Block(2, 2)] = randn(T, j, l)
293+
rng = StableRNG(123)
294+
A[Block(1, 1)] = randn(rng, T, i, k)
295+
A[Block(2, 2)] = randn(rng, T, j, l)
286296
Q, R = qr_full(A)
287297
Q′, R′ = qr_full(Matrix(A))
288298
@test size(Q) == size(Q′)
@@ -296,8 +306,9 @@ end
296306
@testset "lq_compact" for T in (Float32, Float64, ComplexF32, ComplexF64)
297307
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
298308
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
299-
A[Block(1, 1)] = randn(T, i, k)
300-
A[Block(2, 2)] = randn(T, j, l)
309+
rng = StableRNG(123)
310+
A[Block(1, 1)] = randn(rng, T, i, k)
311+
A[Block(2, 2)] = randn(rng, T, j, l)
301312
L, Q = lq_compact(A)
302313
@test Matrix(Q * Q') LinearAlgebra.I
303314
@test A L * Q
@@ -307,8 +318,9 @@ end
307318
@testset "lq_full" for T in (Float32, Float64, ComplexF32, ComplexF64)
308319
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
309320
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
310-
A[Block(1, 1)] = randn(T, i, k)
311-
A[Block(2, 2)] = randn(T, j, l)
321+
rng = StableRNG(123)
322+
A[Block(1, 1)] = randn(rng, T, i, k)
323+
A[Block(2, 2)] = randn(rng, T, j, l)
312324
L, Q = lq_full(A)
313325
L′, Q′ = lq_full(Matrix(A))
314326
@test size(L) == size(L′)
@@ -321,8 +333,9 @@ end
321333

322334
@testset "left_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
323335
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
324-
A[Block(1, 1)] = randn(T, 3, 2)
325-
A[Block(2, 2)] = randn(T, 4, 3)
336+
rng = StableRNG(123)
337+
A[Block(1, 1)] = randn(rng, T, 3, 2)
338+
A[Block(2, 2)] = randn(rng, T, 4, 3)
326339

327340
U, C = left_polar(A)
328341
@test U * C A
@@ -331,8 +344,9 @@ end
331344

332345
@testset "right_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
333346
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
334-
A[Block(1, 1)] = randn(T, 2, 3)
335-
A[Block(2, 2)] = randn(T, 3, 4)
347+
rng = StableRNG(123)
348+
A[Block(1, 1)] = randn(rng, T, 2, 3)
349+
A[Block(2, 2)] = randn(rng, T, 3, 4)
336350

337351
C, U = right_polar(A)
338352
@test C * U A
@@ -341,8 +355,9 @@ end
341355

342356
@testset "left_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
343357
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
344-
A[Block(1, 1)] = randn(T, 3, 2)
345-
A[Block(2, 2)] = randn(T, 4, 3)
358+
rng = StableRNG(123)
359+
A[Block(1, 1)] = randn(rng, T, 3, 2)
360+
A[Block(2, 2)] = randn(rng, T, 4, 3)
346361

347362
for kind in (:polar, :qr, :svd)
348363
U, C = left_orth(A; kind)
@@ -358,8 +373,9 @@ end
358373

359374
@testset "right_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
360375
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
361-
A[Block(1, 1)] = randn(T, 2, 3)
362-
A[Block(2, 2)] = randn(T, 3, 4)
376+
rng = StableRNG(123)
377+
A[Block(1, 1)] = randn(rng, T, 2, 3)
378+
A[Block(2, 2)] = randn(rng, T, 3, 4)
363379

364380
for kind in (:lq, :polar, :svd)
365381
C, U = right_orth(A; kind)

0 commit comments

Comments
 (0)