Skip to content

Commit dffe119

Browse files
jmertandreasnoack
authored andcommitted
Sparsity-preserving outer products (#24980)
* Add indtype and nnz definitions for SparseColumnView * Handle sparse outer products specially in broadcast * Add specialized kron for sparse outer products * Add tests * Support unitful types * Address review comments. * Change is_specialcase_sparse_broadcast -> can_skip_sparsification. * Lift parent(y) to one function earlier for clarify * Simply call _copy instead of passing through the broadcast machinery again
1 parent f8f2045 commit dffe119

File tree

5 files changed

+81
-3
lines changed

5 files changed

+81
-3
lines changed

stdlib/SparseArrays/src/higherorderfns.jl

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ import Base: map, map!, broadcast, copy, copyto!
88

99
using Base: front, tail, to_shape
1010
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
11-
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
11+
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange,
12+
SparseVectorUnion, AdjOrTransSparseVectorUnion, nonzeroinds, nonzeros
1213
using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
1314
using LinearAlgebra
1415

@@ -92,6 +93,9 @@ is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_suppor
9293
is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
9394
is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)
9495

96+
can_skip_sparsification(f, rest...) = false
97+
can_skip_sparsification(::typeof(*), ::SparseVectorUnion, ::AdjOrTransSparseVectorUnion) = true
98+
9599
# Dispatch on broadcast operations by number of arguments
96100
const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} =
97101
Broadcasted{Style,Axes,F,Tuple{}}
@@ -810,6 +814,48 @@ end
810814
_finishempty!(C::SparseVector) = C
811815
_finishempty!(C::SparseMatrixCSC) = (fill!(C.colptr, 1); C)
812816

817+
# special case - vector outer product
818+
_copy(f::typeof(*), x::SparseVectorUnion, y::AdjOrTransSparseVectorUnion) = _outer(x, y)
819+
@inline _outer(x::SparseVectorUnion, y::Adjoint) = return _outer(conj, x, parent(y))
820+
@inline _outer(x::SparseVectorUnion, y::Transpose) = return _outer(identity, x, parent(y))
821+
function _outer(trans::Tf, x, y) where Tf
822+
nx = length(x)
823+
ny = length(y)
824+
rowvalx = nonzeroinds(x)
825+
rowvaly = nonzeroinds(y)
826+
nzvalsx = nonzeros(x)
827+
nzvalsy = nonzeros(y)
828+
nnzx = length(nzvalsx)
829+
nnzy = length(nzvalsy)
830+
831+
nnzC = nnzx * nnzy
832+
Tv = typeof(oneunit(eltype(x)) * oneunit(eltype(y)))
833+
Ti = promote_type(indtype(x), indtype(y))
834+
colptrC = zeros(Ti, ny + 1)
835+
rowvalC = Vector{Ti}(undef, nnzC)
836+
nzvalsC = Vector{Tv}(undef, nnzC)
837+
838+
idx = 0
839+
@inbounds colptrC[1] = 1
840+
@inbounds for jj = 1:nnzy
841+
yval = nzvalsy[jj]
842+
iszero(yval) && continue
843+
col = rowvaly[jj]
844+
yval = trans(yval)
845+
846+
for ii = 1:nnzx
847+
xval = nzvalsx[ii]
848+
iszero(xval) && continue
849+
idx += 1
850+
colptrC[col+1] += 1
851+
rowvalC[idx] = rowvalx[ii]
852+
nzvalsC[idx] = xval * yval
853+
end
854+
end
855+
cumsum!(colptrC, colptrC)
856+
857+
return SparseMatrixCSC(nx, ny, colptrC, rowvalC, nzvalsC)
858+
end
813859

814860
# (9) _broadcast_zeropres!/_broadcast_notzeropres! for more than two (input) sparse vectors/matrices
815861
function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMat,N}) where {Tf,N}
@@ -1079,8 +1125,10 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
10791125

10801126
function copy(bc::Broadcasted{PromoteToSparse})
10811127
bcf = flatten(bc)
1082-
if is_supported_sparse_broadcast(bcf.args...)
1083-
broadcast(bcf.f, map(_sparsifystructured, bcf.args)...)
1128+
if can_skip_sparsification(bcf.f, bcf.args...)
1129+
return _copy(bcf.f, bcf.args...)
1130+
elseif is_supported_sparse_broadcast(bcf.args...)
1131+
return _copy(bcf.f, map(_sparsifystructured, bcf.args)...)
10841132
else
10851133
return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{length(axes(bc))}}, bc))
10861134
end

stdlib/SparseArrays/src/linalg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,9 @@ kron(x::SparseVector, A::SparseMatrixCSC) = kron(SparseMatrixCSC(x), A)
11981198
kron(A::Union{SparseVector,SparseMatrixCSC}, B::VecOrMat) = kron(A, sparse(B))
11991199
kron(A::VecOrMat, B::Union{SparseVector,SparseMatrixCSC}) = kron(sparse(A), B)
12001200

1201+
# sparse outer product
1202+
kron(A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = A .* B
1203+
12011204
## det, inv, cond
12021205

12031206
inv(A::SparseMatrixCSC) = error("The inverse of a sparse matrix can often be dense and can cause the computer to run out of memory. If you are sure you have enough memory, please convert your matrix to a dense matrix.")

stdlib/SparseArrays/src/sparsevector.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ SparseVector(n::Integer, nzind::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti} =
3434
# union of such a view and a SparseVector so we define an alias for such a union as well
3535
const SparseColumnView{T} = SubArray{T,1,<:SparseMatrixCSC,Tuple{Base.Slice{Base.OneTo{Int}},Int},false}
3636
const SparseVectorUnion{T} = Union{SparseVector{T}, SparseColumnView{T}}
37+
const AdjOrTransSparseVectorUnion{T} = LinearAlgebra.AdjOrTrans{T, <:SparseVectorUnion{T}}
3738

3839
### Basic properties
3940

@@ -58,6 +59,11 @@ function nonzeroinds(x::SparseColumnView)
5859
return y
5960
end
6061

62+
indtype(x::SparseColumnView) = indtype(parent(x))
63+
function nnz(x::SparseColumnView)
64+
rowidx, colidx = parentindices(x)
65+
return length(nzrange(parent(x), colidx))
66+
end
6167

6268
## similar
6369
#

stdlib/SparseArrays/test/higherorderfns.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,4 +656,19 @@ using SparseArrays.HigherOrderFns: SparseVecStyle
656656
@test occursin("no method matching _copy(::typeof(rand))", sprint(showerror, err))
657657
end
658658

659+
@testset "Sparse outer product, for type $T and vector $op" for
660+
op in (transpose, adjoint),
661+
T in (Float64, ComplexF64)
662+
m, n, p = 100, 250, 0.1
663+
A = sprand(T, m, n, p)
664+
a, b = view(A, :, 1), sprand(T, m, p)
665+
av, bv = Vector(a), Vector(b)
666+
v = @inferred a .* op(b)
667+
w = @inferred b .* op(a)
668+
@test issparse(v)
669+
@test issparse(w)
670+
@test v == av .* op(bv)
671+
@test w == bv .* op(av)
672+
end
673+
659674
end # module

stdlib/SparseArrays/test/sparse.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ end
352352
for (m,n) in ((5,10), (13,8), (14,10))
353353
a = sprand(m, 5, 0.4); a_d = Matrix(a)
354354
b = sprand(n, 6, 0.3); b_d = Matrix(b)
355+
v = view(a, :, 1); v_d = Vector(v)
355356
x = sprand(m, 0.4); x_d = Vector(x)
356357
y = sprand(n, 0.3); y_d = Vector(y)
357358
# mat ⊗ mat
@@ -370,6 +371,11 @@ end
370371
@test Array(kron(x, b)) == kron(x_d, b_d)
371372
@test Array(kron(x_d, b)) == kron(x_d, b_d)
372373
@test Array(kron(x, b_d)) == kron(x_d, b_d)
374+
# vec ⊗ vec'
375+
@test issparse(kron(v, y'))
376+
@test issparse(kron(x, y'))
377+
@test Array(kron(v, y')) == kron(v_d, y_d')
378+
@test Array(kron(x, y')) == kron(x_d, y_d')
373379
# test different types
374380
z = convert(SparseVector{Float16, Int8}, y); z_d = Vector(z)
375381
@test Vector(kron(x, z)) == kron(x_d, z_d)

0 commit comments

Comments
 (0)