Skip to content

Commit c069507

Browse files
committed
Add SparseVector support for gemv with DenseMatrix.
1 parent d3db2fd commit c069507

File tree

2 files changed

+79
-17
lines changed

2 files changed

+79
-17
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging {
463463
def gemv(
464464
alpha: Double,
465465
A: Matrix,
466-
x: DenseVector,
466+
x: Vector,
467467
beta: Double,
468468
y: DenseVector): Unit = {
469469
require(A.numCols == x.size,
@@ -473,13 +473,16 @@ private[spark] object BLAS extends Serializable with Logging {
473473
if (alpha == 0.0) {
474474
logDebug("gemv: alpha is equal to 0. Returning y.")
475475
} else {
476-
A match {
477-
case sparse: SparseMatrix =>
478-
gemv(alpha, sparse, x, beta, y)
479-
case dense: DenseMatrix =>
480-
gemv(alpha, dense, x, beta, y)
476+
(A, x) match {
477+
case (sparse: SparseMatrix, dx: DenseVector) =>
478+
gemv(alpha, sparse, dx, beta, y)
479+
case (dense: DenseMatrix, dx: DenseVector) =>
480+
gemv(alpha, dense, dx, beta, y)
481+
case (dense: DenseMatrix, sx: SparseVector) =>
482+
gemv(alpha, dense, sx, beta, y)
481483
case _ =>
482-
throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
484+
throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " +
485+
s"${A.getClass} and vector type ${x.getClass}.")
483486
}
484487
}
485488
}
@@ -500,6 +503,55 @@ private[spark] object BLAS extends Serializable with Logging {
500503
nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
501504
y.values, 1)
502505
}
506+
507+
/**
508+
* y := alpha * A * x + beta * y
509+
* For `DenseMatrix` A and SparseVector x.
510+
*/
511+
private def gemv(
512+
alpha: Double,
513+
A: DenseMatrix,
514+
x: SparseVector,
515+
beta: Double,
516+
y: DenseVector): Unit = {
517+
val mA: Int = A.numRows
518+
val nA: Int = A.numCols
519+
520+
val Avals = A.values
521+
var colCounterForA = 0
522+
523+
var xIndices = x.indices
524+
var xNnz = xIndices.size
525+
var xValues = x.values
526+
527+
scal(beta, y)
528+
529+
if (!A.isTransposed) {
530+
var rowCounterForA = 0
531+
while (rowCounterForA < mA) {
532+
var sum = 0.0
533+
var k = 0
534+
while (k < xNnz) {
535+
sum += xValues(k) * Avals(xIndices(k) * mA + rowCounterForA)
536+
k += 1
537+
}
538+
y.values(rowCounterForA) += sum * alpha
539+
rowCounterForA += 1
540+
}
541+
} else {
542+
var rowCounterForA = 0
543+
while (rowCounterForA < mA) {
544+
var sum = 0.0
545+
var k = 0
546+
while (k < xNnz) {
547+
sum += xValues(k) * Avals(xIndices(k) + rowCounterForA * nA)
548+
k += 1
549+
}
550+
y.values(rowCounterForA) += sum * alpha
551+
rowCounterForA += 1
552+
}
553+
}
554+
}
503555

504556
/**
505557
* y := alpha * A * x + beta * y

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,12 @@ class BLASSuite extends FunSuite {
257257
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
258258
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
259259

260-
val x = new DenseVector(Array(1.0, 2.0, 3.0))
260+
val dx = new DenseVector(Array(1.0, 2.0, 3.0))
261+
val sx = dx.toSparse
261262
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
262263

263-
assert(dA.multiply(x) ~== expected absTol 1e-15)
264-
assert(sA.multiply(x) ~== expected absTol 1e-15)
264+
assert(dA.multiply(dx) ~== expected absTol 1e-15)
265+
assert(sA.multiply(dx) ~== expected absTol 1e-15)
265266

266267
val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
267268
val y2 = y1.copy
@@ -270,17 +271,26 @@ class BLASSuite extends FunSuite {
270271
val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
271272
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
272273

273-
gemv(1.0, dA, x, 2.0, y1)
274-
gemv(1.0, sA, x, 2.0, y2)
275-
gemv(2.0, dA, x, 2.0, y3)
276-
gemv(2.0, sA, x, 2.0, y4)
274+
gemv(1.0, dA, dx, 2.0, y1)
275+
gemv(1.0, sA, dx, 2.0, y2)
276+
gemv(2.0, dA, dx, 2.0, y3)
277+
gemv(2.0, sA, dx, 2.0, y4)
277278
assert(y1 ~== expected2 absTol 1e-15)
278279
assert(y2 ~== expected2 absTol 1e-15)
279280
assert(y3 ~== expected3 absTol 1e-15)
280281
assert(y4 ~== expected3 absTol 1e-15)
282+
283+
val y1_copy = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
284+
val y3_copy = y1_copy.copy
285+
286+
gemv(1.0, dA, sx, 2.0, y1_copy)
287+
gemv(2.0, dA, sx, 2.0, y3_copy)
288+
assert(y1_copy ~== expected2 absTol 1e-15)
289+
assert(y3_copy ~== expected3 absTol 1e-15)
290+
281291
withClue("columns of A don't match the rows of B") {
282292
intercept[Exception] {
283-
gemv(1.0, dA.transpose, x, 2.0, y1)
293+
gemv(1.0, dA.transpose, dx, 2.0, y1)
284294
}
285295
}
286296
val dAT =
@@ -291,7 +301,7 @@ class BLASSuite extends FunSuite {
291301
val dATT = dAT.transpose
292302
val sATT = sAT.transpose
293303

294-
assert(dATT.multiply(x) ~== expected absTol 1e-15)
295-
assert(sATT.multiply(x) ~== expected absTol 1e-15)
304+
assert(dATT.multiply(dx) ~== expected absTol 1e-15)
305+
assert(sATT.multiply(dx) ~== expected absTol 1e-15)
296306
}
297307
}

0 commit comments

Comments
 (0)