Skip to content

Commit fae262c

Browse files
KlausCViralBShah
authored andcommitted
spmatmul sparse matrix multiplication - performance improvements (#30372)
* General performance improvements for sparse matmul Details for the polyalgorithm are in: #30372
1 parent b451001 commit fae262c

File tree

2 files changed

+70
-30
lines changed

2 files changed

+70
-30
lines changed

stdlib/SparseArrays/src/linalg.jl

Lines changed: 69 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -147,63 +147,104 @@ end
147147
# Adjoint-times-adjoint product: both wrapped operands are copied and the copies
# are handed to spmatmul.
function *(A::Adjoint{<:Any,<:SparseMatrixCSC{Tv,Ti}},
           B::Adjoint{<:Any,<:SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti}
    return spmatmul(copy(A), copy(B))
end
148148
# Transpose-times-transpose product: both wrapped operands are copied and the copies
# are handed to spmatmul.
function *(A::Transpose{<:Any,<:SparseMatrixCSC{Tv,Ti}},
           B::Transpose{<:Any,<:SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti}
    return spmatmul(copy(A), copy(B))
end
149149

150-
function spmatmul(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti};
151-
sortindices::Symbol = :sortcols) where {Tv,Ti}
150+
# Gustavson's matrix multiplication algorithm revisited.
151+
# The result rowval vector is already sorted by construction.
152+
# The auxiliary Vector{Ti} xb is replaced by a Vector{Bool} of same length.
153+
# The optional argument controlling a sorting algorithm is obsolete.
154+
# Depending on the expected execution speed, the sorting of the result column is
155+
# done by a quicksort of the row indices or by a full scan of the dense result vector.
156+
# The latter is faster if more than ≈ 1/32 of the result column is nonzero.
157+
# TODO: extend to SparseMatrixCSCUnion to allow for SubArrays (view(X, :, r)).
158+
function spmatmul(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
152159
mA, nA = size(A)
153-
mB, nB = size(B)
154-
nA==mB || throw(DimensionMismatch())
160+
nB = size(B, 2)
161+
nA == size(B, 1) || throw(DimensionMismatch())
155162

156-
colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval
157-
colptrB = B.colptr; rowvalB = B.rowval; nzvalB = B.nzval
158-
# TODO: Need better estimation of result space
159-
nnzC = min(mA*nB, length(nzvalA) + length(nzvalB))
163+
rowvalA = rowvals(A); nzvalA = nonzeros(A)
164+
rowvalB = rowvals(B); nzvalB = nonzeros(B)
165+
nnzC = max(estimate_mulsize(mA, nnz(A), nA, nnz(B), nB) * 11 ÷ 10, mA)
160166
colptrC = Vector{Ti}(undef, nB+1)
161167
rowvalC = Vector{Ti}(undef, nnzC)
162168
nzvalC = Vector{Tv}(undef, nnzC)
169+
nzpercol = nnzC ÷ max(nB, 1)
163170

164171
@inbounds begin
165172
ip = 1
166-
xb = zeros(Ti, mA)
167-
x = zeros(Tv, mA)
173+
xb = fill(false, mA)
168174
for i in 1:nB
169175
if ip + mA - 1 > nnzC
170-
resize!(rowvalC, nnzC + max(nnzC,mA))
171-
resize!(nzvalC, nnzC + max(nnzC,mA))
172-
nnzC = length(nzvalC)
176+
nnzC += max(mA, nnzC>>2)
177+
resize!(rowvalC, nnzC)
178+
resize!(nzvalC, nnzC)
173179
end
174-
colptrC[i] = ip
175-
for jp in colptrB[i]:(colptrB[i+1] - 1)
180+
colptrC[i] = ip0 = ip
181+
k0 = ip - 1
182+
for jp in nzrange(B, i)
176183
nzB = nzvalB[jp]
177184
j = rowvalB[jp]
178-
for kp in colptrA[j]:(colptrA[j+1] - 1)
185+
for kp in nzrange(A, j)
179186
nzC = nzvalA[kp] * nzB
180187
k = rowvalA[kp]
181-
if xb[k] != i
188+
if xb[k]
189+
nzvalC[k+k0] += nzC
190+
else
191+
nzvalC[k+k0] = nzC
192+
xb[k] = true
182193
rowvalC[ip] = k
183194
ip += 1
184-
xb[k] = i
185-
x[k] = nzC
186-
else
187-
x[k] += nzC
188195
end
189196
end
190197
end
191-
for vp in colptrC[i]:(ip - 1)
192-
nzvalC[vp] = x[rowvalC[vp]]
198+
if ip > ip0
199+
if prefer_sort(ip-k0, mA)
200+
# in-place sort of indices. Effort: O(nnz*ln(nnz)).
201+
sort!(rowvalC, ip0, ip-1, QuickSort, Base.Order.Forward)
202+
for vp = ip0:ip-1
203+
k = rowvalC[vp]
204+
xb[k] = false
205+
nzvalC[vp] = nzvalC[k+k0]
206+
end
207+
else
208+
# scan result vector (effort O(mA))
209+
for k = 1:mA
210+
if xb[k]
211+
xb[k] = false
212+
rowvalC[ip0] = k
213+
nzvalC[ip0] = nzvalC[k+k0]
214+
ip0 += 1
215+
end
216+
end
217+
end
193218
end
194219
end
195220
colptrC[nB+1] = ip
196221
end
197222

198-
deleteat!(rowvalC, colptrC[end]:length(rowvalC))
199-
deleteat!(nzvalC, colptrC[end]:length(nzvalC))
223+
resize!(rowvalC, ip - 1)
224+
resize!(nzvalC, ip - 1)
200225

201-
# The Gustavson algorithm does not guarantee the product to have sorted row indices.
202-
Cunsorted = SparseMatrixCSC(mA, nB, colptrC, rowvalC, nzvalC)
203-
C = SparseArrays.sortSparseMatrixCSC!(Cunsorted, sortindices=sortindices)
226+
# This modification of Gustavson's algorithm yields sorted row indices.
227+
C = SparseMatrixCSC(mA, nB, colptrC, rowvalC, nzvalC)
204228
return C
205229
end
206230

231+
# Estimated number of non-zeros in the product of an m×n matrix holding nnzA stored
# entries with an n×k matrix holding nnzB stored entries.
# It is assumed that the non-zero indices are distributed independently and uniformly
# in both matrices. Over-estimation is possible if that is not the case.
# With p the product of the two densities, the result is m*k when p >= 1, 0 when p is
# zero or NaN (e.g. an empty dimension), and ceil((1 - (1 - p)^n) * m * k) otherwise.
function estimate_mulsize(m::Integer, nnzA::Integer, n::Integer, nnzB::Integer, k::Integer)
    # Form the dimension products in floating point: the integer products m*n and n*k
    # wrap around silently for very large dimensions and would corrupt the densities.
    p = (nnzA / (float(m) * n)) * (nnzB / (float(n) * k))
    if p >= 1
        return m * k
    elseif p > 0
        # (1 - (1 - p)^n) * m * k, written with expm1/log1p to stay accurate for tiny p
        return ceil(Int, -expm1(log1p(-p) * n) * m * k)
    else
        return 0
    end
end
238+
239+
# determine if sort! shall be used or the whole column be scanned
240+
# based on empirical data on i7-3610QM CPU
241+
# measuring runtimes of the scanning and sorting loops of the algorithm.
242+
# The parameters 6 and 3 might be modified for different architectures.
243+
# Columns with m <= 6 are never sorted; otherwise sorting is chosen when the cost
# estimate 3 * ilog2(nz) * nz is below the column length m.
function prefer_sort(nz::Integer, m::Integer)
    m > 6 || return false
    return 3 * ilog2(nz) * nz < m
end
244+
245+
# minimal number of bits required to represent integer; ilog2(n) >= log2(n)
246+
# Size of n in bits (8 * sizeof(n)) minus its number of leading zero bits.
ilog2(n::Integer) = 8 * sizeof(n) - leading_zeros(n)
247+
207248
# Frobenius dot/inner product: trace(A'B)
208249
function dot(A::SparseMatrixCSC{T1,S1},B::SparseMatrixCSC{T2,S2}) where {T1,T2,S1,S2}
209250
m, n = size(A)

stdlib/SparseArrays/test/sparse.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,7 @@ end
322322
a = sprand(10, 5, 0.7)
323323
b = sprand(5, 15, 0.3)
324324
@test maximum(abs.(a*b - Array(a)*Array(b))) < 100*eps()
325-
@test maximum(abs.(SparseArrays.spmatmul(a,b,sortindices=:sortcols) - Array(a)*Array(b))) < 100*eps()
326-
@test maximum(abs.(SparseArrays.spmatmul(a,b,sortindices=:doubletranspose) - Array(a)*Array(b))) < 100*eps()
325+
@test maximum(abs.(SparseArrays.spmatmul(a,b) - Array(a)*Array(b))) < 100*eps()
327326
f = Diagonal(rand(5))
328327
@test Array(a*f) == Array(a)*f
329328
@test Array(f*b) == f*Array(b)

0 commit comments

Comments
 (0)