Commit e76ef5c

brkyvz authored and mengxr committed
[SPARK-3418] Sparse Matrix support (CCS) and additional native BLAS operations added
Local `SparseMatrix` support added in Compressed Column Storage (CCS) format, in addition to Level-2 and Level-3 BLAS operations such as dgemv and dgemm, respectively. BLAS doesn't support sparse matrix operations, therefore support for `SparseMatrix`-`DenseMatrix` and `SparseMatrix`-`DenseVector` multiplication has been added. I will post performance comparisons in the comments momentarily.

Author: Burak <[email protected]>

Closes apache#2294 from brkyvz/SPARK-3418 and squashes the following commits:

88814ed [Burak] Hopefully fixed MiMa this time
47e49d5 [Burak] really fixed MiMa issue
f0bae57 [Burak] [SPARK-3418] Fixed MiMa compatibility issues (excluded from check)
4b7dbec [Burak] 9/17 comments addressed
7af2f83 [Burak] sealed traits Vector and Matrix
d3a8a16 [Burak] [SPARK-3418] Squashed missing alpha bug.
421045f [Burak] [SPARK-3418] New code review comments addressed
f35a161 [Burak] [SPARK-3418] Code review comments addressed and multiplication further optimized
2508577 [Burak] [SPARK-3418] Fixed one more style issue
d16e8a0 [Burak] [SPARK-3418] Fixed style issues and added documentation for methods
204a3f7 [Burak] [SPARK-3418] Fixed failing Matrix unit test
6025297 [Burak] [SPARK-3418] Fixed Scala-style errors
dc7be71 [Burak] [SPARK-3418][MLlib] Matrix unit tests expanded with indexing and updating
d2d5851 [Burak] [SPARK-3418][MLlib] Sparse Matrix support and additional native BLAS operations added
1 parent e77fa81 commit e76ef5c
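
To make the CCS layout concrete, here is a small sketch (not part of the patch) using the public `SparseMatrix` constructor that this change adds in Matrices.scala; the matrix and its numbers are purely illustrative:

    import org.apache.spark.mllib.linalg.SparseMatrix

    // The 3 x 3 matrix
    //   | 1.0  0.0  4.0 |
    //   | 0.0  3.0  5.0 |
    //   | 2.0  0.0  6.0 |
    // in CCS: values holds the non-zeros column by column, rowIndices gives each
    // value's row, and colPtrs(j) until colPtrs(j + 1) delimits column j in both arrays.
    val ccs = new SparseMatrix(3, 3,
      Array(0, 2, 3, 6),                     // colPtrs
      Array(0, 2, 1, 0, 1, 2),               // rowIndices
      Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))   // values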

File tree: 8 files changed (+834, -10 lines)


mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 329 additions & 1 deletion
@@ -18,13 +18,17 @@
 package org.apache.spark.mllib.linalg

 import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
+import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
+
+import org.apache.spark.Logging

 /**
  * BLAS routines for MLlib's vectors and matrices.
  */
-private[mllib] object BLAS extends Serializable {
+private[mllib] object BLAS extends Serializable with Logging {

   @transient private var _f2jBLAS: NetlibBLAS = _
+  @transient private var _nativeBLAS: NetlibBLAS = _

   // For level-1 routines, we use Java implementation.
   private def f2jBLAS: NetlibBLAS = {
@@ -197,4 +201,328 @@ private[mllib] object BLAS extends Serializable {
         throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.")
     }
   }
+
+  // For level-3 routines, we use the native BLAS.
+  private def nativeBLAS: NetlibBLAS = {
+    if (_nativeBLAS == null) {
+      _nativeBLAS = NativeBLAS
+    }
+    _nativeBLAS
+  }
+
+  /**
+   * C := alpha * A * B + beta * C
+   * @param transA whether to use the transpose of matrix A (true), or A itself (false).
+   * @param transB whether to use the transpose of matrix B (true), or B itself (false).
+   * @param alpha a scalar to scale the multiplication A * B.
+   * @param A the matrix A that will be left multiplied to B. Size of m x k.
+   * @param B the matrix B that will be left multiplied by A. Size of k x n.
+   * @param beta a scalar that can be used to scale matrix C.
+   * @param C the resulting matrix C. Size of m x n.
+   */
+  def gemm(
+      transA: Boolean,
+      transB: Boolean,
+      alpha: Double,
+      A: Matrix,
+      B: DenseMatrix,
+      beta: Double,
+      C: DenseMatrix): Unit = {
+    if (alpha == 0.0) {
+      logDebug("gemm: alpha is equal to 0. Returning C.")
+    } else {
+      A match {
+        case sparse: SparseMatrix =>
+          gemm(transA, transB, alpha, sparse, B, beta, C)
+        case dense: DenseMatrix =>
+          gemm(transA, transB, alpha, dense, B, beta, C)
+        case _ =>
+          throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.")
+      }
+    }
+  }
+
+  /**
+   * C := alpha * A * B + beta * C
+   *
+   * @param alpha a scalar to scale the multiplication A * B.
+   * @param A the matrix A that will be left multiplied to B. Size of m x k.
+   * @param B the matrix B that will be left multiplied by A. Size of k x n.
+   * @param beta a scalar that can be used to scale matrix C.
+   * @param C the resulting matrix C. Size of m x n.
+   */
+  def gemm(
+      alpha: Double,
+      A: Matrix,
+      B: DenseMatrix,
+      beta: Double,
+      C: DenseMatrix): Unit = {
+    gemm(false, false, alpha, A, B, beta, C)
+  }
+
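As a quick usage sketch of the public gemm overloads above (not part of the patch): `BLAS` is `private[mllib]`, so the caller has to sit inside the `org.apache.spark.mllib` package, for example a test suite. The object name `GemmSketch` and all data below are made up for illustration.

    package org.apache.spark.mllib.linalg

    object GemmSketch {
      def main(args: Array[String]): Unit = {
        // A is 2 x 3 in CCS form, B is 3 x 2 column-major dense, C starts as zeros.
        val A = new SparseMatrix(2, 3, Array(0, 1, 2, 3), Array(0, 1, 0), Array(1.0, 2.0, 3.0))
        val B = new DenseMatrix(3, 2, Array(1.0, 1.0, 1.0, 2.0, 2.0, 2.0))
        val C = new DenseMatrix(2, 2, new Array[Double](4))

        // C := 1.0 * A * B + 0.0 * C, written into C.values in place.
        BLAS.gemm(1.0, A, B, 0.0, C)
        println(C.values.mkString(", "))  // expected: 4.0, 2.0, 8.0, 4.0 (column-major)
      }
    }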
+  /**
+   * C := alpha * A * B + beta * C
+   * For `DenseMatrix` A.
+   */
+  private def gemm(
+      transA: Boolean,
+      transB: Boolean,
+      alpha: Double,
+      A: DenseMatrix,
+      B: DenseMatrix,
+      beta: Double,
+      C: DenseMatrix): Unit = {
+    val mA: Int = if (!transA) A.numRows else A.numCols
+    val nB: Int = if (!transB) B.numCols else B.numRows
+    val kA: Int = if (!transA) A.numCols else A.numRows
+    val kB: Int = if (!transB) B.numRows else B.numCols
+    val tAstr = if (!transA) "N" else "T"
+    val tBstr = if (!transB) "N" else "T"
+
+    require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
+    require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
+    require(nB == C.numCols,
+      s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")
+
+    nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows,
+      beta, C.values, C.numRows)
+  }
+
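The dgemm call above hands A.values and B.values to netlib directly with the leading dimensions set to A.numRows and B.numRows; this works because DenseMatrix stores its entries in column-major order. A small sketch of that layout (not from the patch, values illustrative):

    import org.apache.spark.mllib.linalg.DenseMatrix

    // Column-major storage: the first numRows entries form column 0, the next
    // numRows entries form column 1, and so on.
    //   | 1.0  3.0 |
    //   | 2.0  4.0 |
    val m = new DenseMatrix(2, 2, Array(1.0, 2.0, 3.0, 4.0))
    // Element (i, j) lives at m.values(j * m.numRows + i); here (0, 1) -> 3.0.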
+  /**
+   * C := alpha * A * B + beta * C
+   * For `SparseMatrix` A.
+   */
+  private def gemm(
+      transA: Boolean,
+      transB: Boolean,
+      alpha: Double,
+      A: SparseMatrix,
+      B: DenseMatrix,
+      beta: Double,
+      C: DenseMatrix): Unit = {
+    val mA: Int = if (!transA) A.numRows else A.numCols
+    val nB: Int = if (!transB) B.numCols else B.numRows
+    val kA: Int = if (!transA) A.numCols else A.numRows
+    val kB: Int = if (!transB) B.numRows else B.numCols
+
+    require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
+    require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
+    require(nB == C.numCols,
+      s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")
+
+    val Avals = A.values
+    val Arows = if (!transA) A.rowIndices else A.colPtrs
+    val Acols = if (!transA) A.colPtrs else A.rowIndices
+
+    // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
+    if (transA){
+      var colCounterForB = 0
+      if (!transB) { // Expensive to put the check inside the loop
+        while (colCounterForB < nB) {
+          var rowCounterForA = 0
+          val Cstart = colCounterForB * mA
+          val Bstart = colCounterForB * kA
+          while (rowCounterForA < mA) {
+            var i = Arows(rowCounterForA)
+            val indEnd = Arows(rowCounterForA + 1)
+            var sum = 0.0
+            while (i < indEnd) {
+              sum += Avals(i) * B.values(Bstart + Acols(i))
+              i += 1
+            }
+            val Cindex = Cstart + rowCounterForA
+            C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
+            rowCounterForA += 1
+          }
+          colCounterForB += 1
+        }
+      } else {
+        while (colCounterForB < nB) {
+          var rowCounter = 0
+          val Cstart = colCounterForB * mA
+          while (rowCounter < mA) {
+            var i = Arows(rowCounter)
+            val indEnd = Arows(rowCounter + 1)
+            var sum = 0.0
+            while (i < indEnd) {
+              sum += Avals(i) * B(colCounterForB, Acols(i))
+              i += 1
+            }
+            val Cindex = Cstart + rowCounter
+            C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
+            rowCounter += 1
+          }
+          colCounterForB += 1
+        }
+      }
+    } else {
+      // Scale matrix first if `beta` is not equal to 0.0
+      if (beta != 0.0){
+        f2jBLAS.dscal(C.values.length, beta, C.values, 1)
+      }
+      // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
+      // B, and added to C.
+      var colCounterForB = 0 // the column to be updated in C
+      if (!transB) { // Expensive to put the check inside the loop
+        while (colCounterForB < nB) {
+          var colCounterForA = 0 // The column of A to multiply with the row of B
+          val Bstart = colCounterForB * kB
+          val Cstart = colCounterForB * mA
+          while (colCounterForA < kA) {
+            var i = Acols(colCounterForA)
+            val indEnd = Acols(colCounterForA + 1)
+            val Bval = B.values(Bstart + colCounterForA) * alpha
+            while (i < indEnd){
+              C.values(Cstart + Arows(i)) += Avals(i) * Bval
+              i += 1
+            }
+            colCounterForA += 1
+          }
+          colCounterForB += 1
+        }
+      } else {
+        while (colCounterForB < nB) {
+          var colCounterForA = 0 // The column of A to multiply with the row of B
+          val Cstart = colCounterForB * mA
+          while (colCounterForA < kA){
+            var i = Acols(colCounterForA)
+            val indEnd = Acols(colCounterForA + 1)
+            val Bval = B(colCounterForB, colCounterForA) * alpha
+            while (i < indEnd){
+              C.values(Cstart + Arows(i)) += Avals(i) * Bval
+              i += 1
+            }
+            colCounterForA += 1
+          }
+          colCounterForB += 1
+        }
+      }
+    }
+  }
+
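The sparse loops above never materialize a dense column of A; they rely on the CCS invariant that the non-zeros of column j occupy positions colPtrs(j) until colPtrs(j + 1) of values and rowIndices. A standalone sketch of that traversal, in the same while-loop style (not from the patch, arrays illustrative):

    // Walk a CCS matrix column by column and print its non-zero entries.
    val colPtrs = Array(0, 2, 3, 6)
    val rowIndices = Array(0, 2, 1, 0, 1, 2)
    val values = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)

    var j = 0
    while (j < colPtrs.length - 1) {
      var k = colPtrs(j)
      val indEnd = colPtrs(j + 1)
      while (k < indEnd) {
        println(s"A(${rowIndices(k)}, $j) = ${values(k)}")
        k += 1
      }
      j += 1
    }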
+  /**
+   * y := alpha * A * x + beta * y
+   * @param trans whether to use the transpose of matrix A (true), or A itself (false).
+   * @param alpha a scalar to scale the multiplication A * x.
+   * @param A the matrix A that will be left multiplied to x. Size of m x n.
+   * @param x the vector x that will be left multiplied by A. Size of n x 1.
+   * @param beta a scalar that can be used to scale vector y.
+   * @param y the resulting vector y. Size of m x 1.
+   */
+  def gemv(
+      trans: Boolean,
+      alpha: Double,
+      A: Matrix,
+      x: DenseVector,
+      beta: Double,
+      y: DenseVector): Unit = {
+
+    val mA: Int = if (!trans) A.numRows else A.numCols
+    val nx: Int = x.size
+    val nA: Int = if (!trans) A.numCols else A.numRows
+
+    require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx")
+    require(mA == y.size,
+      s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}")
+    if (alpha == 0.0) {
+      logDebug("gemv: alpha is equal to 0. Returning y.")
+    } else {
+      A match {
+        case sparse: SparseMatrix =>
+          gemv(trans, alpha, sparse, x, beta, y)
+        case dense: DenseMatrix =>
+          gemv(trans, alpha, dense, x, beta, y)
+        case _ =>
+          throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
+      }
+    }
+  }
+
+  /**
+   * y := alpha * A * x + beta * y
+   *
+   * @param alpha a scalar to scale the multiplication A * x.
+   * @param A the matrix A that will be left multiplied to x. Size of m x n.
+   * @param x the vector x that will be left multiplied by A. Size of n x 1.
+   * @param beta a scalar that can be used to scale vector y.
+   * @param y the resulting vector y. Size of m x 1.
+   */
+  def gemv(
+      alpha: Double,
+      A: Matrix,
+      x: DenseVector,
+      beta: Double,
+      y: DenseVector): Unit = {
+    gemv(false, alpha, A, x, beta, y)
+  }
+
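A usage sketch for the public gemv overloads above (again outside the patch, inside the mllib package; the object name `GemvSketch` and the data are illustrative):

    package org.apache.spark.mllib.linalg

    object GemvSketch {
      def main(args: Array[String]): Unit = {
        // Same 3 x 3 CCS matrix as in the earlier sketch; x and y are dense vectors.
        val A = new SparseMatrix(3, 3, Array(0, 2, 3, 6), Array(0, 2, 1, 0, 1, 2),
          Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
        val x = new DenseVector(Array(1.0, 1.0, 1.0))
        val y = new DenseVector(Array(0.0, 0.0, 0.0))

        // y := 2.0 * A * x + 1.0 * y, computed in place in y.values.
        BLAS.gemv(2.0, A, x, 1.0, y)
        println(y.values.mkString(", "))  // expected: 10.0, 16.0, 16.0
      }
    }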
+  /**
+   * y := alpha * A * x + beta * y
+   * For `DenseMatrix` A.
+   */
+  private def gemv(
+      trans: Boolean,
+      alpha: Double,
+      A: DenseMatrix,
+      x: DenseVector,
+      beta: Double,
+      y: DenseVector): Unit = {
+    val tStrA = if (!trans) "N" else "T"
+    nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta,
+      y.values, 1)
+  }
+
+  /**
+   * y := alpha * A * x + beta * y
+   * For `SparseMatrix` A.
+   */
+  private def gemv(
+      trans: Boolean,
+      alpha: Double,
+      A: SparseMatrix,
+      x: DenseVector,
+      beta: Double,
+      y: DenseVector): Unit = {
+
+    val mA: Int = if(!trans) A.numRows else A.numCols
+    val nA: Int = if(!trans) A.numCols else A.numRows
+
+    val Avals = A.values
+    val Arows = if (!trans) A.rowIndices else A.colPtrs
+    val Acols = if (!trans) A.colPtrs else A.rowIndices
+
+    // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
+    if (trans){
+      var rowCounter = 0
+      while (rowCounter < mA){
+        var i = Arows(rowCounter)
+        val indEnd = Arows(rowCounter + 1)
+        var sum = 0.0
+        while(i < indEnd){
+          sum += Avals(i) * x.values(Acols(i))
+          i += 1
+        }
+        y.values(rowCounter) = beta * y.values(rowCounter) + sum * alpha
+        rowCounter += 1
+      }
+    } else {
+      // Scale vector first if `beta` is not equal to 0.0
+      if (beta != 0.0){
+        scal(beta, y)
+      }
+      // Perform matrix-vector multiplication and add to y
+      var colCounterForA = 0
+      while (colCounterForA < nA){
+        var i = Acols(colCounterForA)
+        val indEnd = Acols(colCounterForA + 1)
+        val xVal = x.values(colCounterForA) * alpha
+        while (i < indEnd){
+          val rowIndex = Arows(i)
+          y.values(rowIndex) += Avals(i) * xVal
+          i += 1
+        }
+        colCounterForA += 1
+      }
+    }
+  }
 }
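
One of the squashed commits expands the Matrix unit tests; in that spirit, a minimal sanity check (a sketch, not the actual test code, with a hypothetical `SparseDenseParity` object and made-up data) is to compare the sparse gemv path against the dense gemv path on the same matrix:

    package org.apache.spark.mllib.linalg

    object SparseDenseParity {
      def main(args: Array[String]): Unit = {
        // The same 2 x 2 matrix in sparse (CCS) and dense (column-major) form.
        val sparse = new SparseMatrix(2, 2, Array(0, 1, 2), Array(1, 0), Array(2.0, 4.0))
        val dense = new DenseMatrix(2, 2, Array(0.0, 2.0, 4.0, 0.0))
        val x = new DenseVector(Array(1.0, 3.0))

        val ySparse = new DenseVector(new Array[Double](2))
        val yDense = new DenseVector(new Array[Double](2))
        BLAS.gemv(1.0, sparse, x, 0.0, ySparse)
        BLAS.gemv(1.0, dense, x, 0.0, yDense)
        assert(ySparse.values.sameElements(yDense.values))  // both should be (12.0, 2.0)
      }
    }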
