Skip to content

Commit 2f65186

Browse files
committed
Remove toDense.
1 parent 1b6cdfe commit 2f65186

File tree

4 files changed

+89
-29
lines changed

4 files changed

+89
-29
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ import java.lang.{Iterable => JIterable}
2121

2222
import scala.collection.JavaConverters._
2323

24-
import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
25-
import breeze.numerics.{exp => brzExp, log => brzLog}
2624
import org.json4s.JsonDSL._
2725
import org.json4s.jackson.JsonMethods._
2826

@@ -50,8 +48,8 @@ class NaiveBayesModel private[mllib] (
5048
val modelType: String)
5149
extends ClassificationModel with Serializable with Saveable {
5250

53-
val piVector = Vectors.dense(pi).asInstanceOf[DenseVector]
54-
val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
51+
private val piVector = new DenseVector(pi)
52+
private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
5553

5654
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
5755
this(labels, pi, theta, "Multinomial")
@@ -66,15 +64,15 @@ class NaiveBayesModel private[mllib] (
6664
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
6765
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
6866
// application of this condition (in predict function).
69-
private val (thetaMinusnegTheta, brzNegThetaSum) = modelType match {
67+
private val (thetaMinusNegTheta, negThetaSum) = modelType match {
7068
case "Multinomial" => (None, None)
7169
case "Bernoulli" =>
7270
val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
73-
val ones = Vectors.dense(Array.fill(thetaMatrix.numCols){1.0}).asInstanceOf[DenseVector]
74-
val thetaMinusnegTheta = thetaMatrix.map { value =>
71+
val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
72+
val thetaMinusNegTheta = thetaMatrix.map { value =>
7573
value - math.log(1.0 - math.exp(value))
7674
}
77-
(Option(thetaMinusnegTheta), Option(negTheta.multiply(ones)))
75+
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
7876
case _ =>
7977
// This should never happen.
8078
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
@@ -91,7 +89,7 @@ class NaiveBayesModel private[mllib] (
9189
override def predict(testData: Vector): Double = {
9290
modelType match {
9391
case "Multinomial" =>
94-
val prob = thetaMatrix.multiply(testData.toDense)
92+
val prob = thetaMatrix.multiply(testData)
9593
BLAS.axpy(1.0, piVector, prob)
9694
labels(prob.argmax)
9795
case "Bernoulli" =>
@@ -101,9 +99,9 @@ class NaiveBayesModel private[mllib] (
10199
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
102100
}
103101
}
104-
val prob = thetaMinusnegTheta.get.multiply(testData.toDense)
102+
val prob = thetaMinusNegTheta.get.multiply(testData)
105103
BLAS.axpy(1.0, piVector, prob)
106-
BLAS.axpy(1.0, brzNegThetaSum.get, prob)
104+
BLAS.axpy(1.0, negThetaSum.get, prob)
107105
labels(prob.argmax)
108106
case _ =>
109107
// This should never happen.

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging {
463463
def gemv(
464464
alpha: Double,
465465
A: Matrix,
466-
x: DenseVector,
466+
x: Vector,
467467
beta: Double,
468468
y: DenseVector): Unit = {
469469
require(A.numCols == x.size,
@@ -473,13 +473,16 @@ private[spark] object BLAS extends Serializable with Logging {
473473
if (alpha == 0.0) {
474474
logDebug("gemv: alpha is equal to 0. Returning y.")
475475
} else {
476-
A match {
477-
case sparse: SparseMatrix =>
478-
gemv(alpha, sparse, x, beta, y)
479-
case dense: DenseMatrix =>
480-
gemv(alpha, dense, x, beta, y)
476+
(A, x) match {
477+
case (sparse: SparseMatrix, dx: DenseVector) =>
478+
gemv(alpha, sparse, dx, beta, y)
479+
case (dense: DenseMatrix, dx: DenseVector) =>
480+
gemv(alpha, dense, dx, beta, y)
481+
case (dense: DenseMatrix, sx: SparseVector) =>
482+
gemv(alpha, dense, sx, beta, y)
481483
case _ =>
482-
throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
484+
throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " +
485+
s"${A.getClass} and vector type ${x.getClass}.")
483486
}
484487
}
485488
}
@@ -500,6 +503,55 @@ private[spark] object BLAS extends Serializable with Logging {
500503
nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
501504
y.values, 1)
502505
}
506+
507+
/**
 * y := alpha * A * x + beta * y
 * For `DenseMatrix` A and `SparseVector` x.
 *
 * Only the entries stored in `x` are visited, so the work per output row is proportional to
 * the number of stored entries of `x` rather than to the number of columns of `A`.
 *
 * @param alpha scalar multiplier applied to the product A * x.
 * @param A dense matrix; its `values` array is indexed according to `A.isTransposed`.
 * @param x sparse vector supplying `indices` and the matching `values`.
 * @param beta scalar by which `y` is scaled (through `scal`) before the product is added.
 * @param y dense vector that is updated in place with the result.
 */
private def gemv(
    alpha: Double,
    A: DenseMatrix,
    x: SparseVector,
    beta: Double,
    y: DenseVector): Unit = {
  val mA: Int = A.numRows
  val nA: Int = A.numCols

  val Avals = A.values

  val xIndices = x.indices
  val xNnz = xIndices.length
  val xValues = x.values

  // Apply the beta factor up front; the loops below only accumulate alpha * A * x.
  scal(beta, y)
  val yValues = y.values

  if (!A.isTransposed) {
    // Entry (row, col) of A is read at offset col * mA + row.
    var rowCounterForA = 0
    while (rowCounterForA < mA) {
      var sum = 0.0
      var k = 0
      while (k < xNnz) {
        sum += xValues(k) * Avals(xIndices(k) * mA + rowCounterForA)
        k += 1
      }
      yValues(rowCounterForA) += sum * alpha
      rowCounterForA += 1
    }
  } else {
    // Entry (row, col) of A is read at offset row * nA + col.
    var rowCounterForA = 0
    while (rowCounterForA < mA) {
      var sum = 0.0
      var k = 0
      while (k < xNnz) {
        sum += xValues(k) * Avals(xIndices(k) + rowCounterForA * nA)
        k += 1
      }
      yValues(rowCounterForA) += sum * alpha
      rowCounterForA += 1
    }
  }
}
503555

504556
/**
505557
* y := alpha * A * x + beta * y

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ sealed trait Matrix extends Serializable {
7878
}
7979

8080
/** Convenience method for `Matrix`-`Vector` multiplication. */
81-
def multiply(y: DenseVector): DenseVector = {
81+
def multiply(y: Vector): DenseVector = {
8282
val output = new DenseVector(new Array[Double](numRows))
8383
BLAS.gemv(1.0, this, y, 0.0, output)
8484
output

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,12 @@ class BLASSuite extends FunSuite {
257257
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
258258
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
259259

260-
val x = new DenseVector(Array(1.0, 2.0, 3.0))
260+
val dx = new DenseVector(Array(1.0, 2.0, 3.0))
261+
val sx = dx.toSparse
261262
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
262263

263-
assert(dA.multiply(x) ~== expected absTol 1e-15)
264-
assert(sA.multiply(x) ~== expected absTol 1e-15)
264+
assert(dA.multiply(dx) ~== expected absTol 1e-15)
265+
assert(sA.multiply(dx) ~== expected absTol 1e-15)
265266

266267
val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
267268
val y2 = y1.copy
@@ -270,17 +271,26 @@ class BLASSuite extends FunSuite {
270271
val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
271272
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
272273

273-
gemv(1.0, dA, x, 2.0, y1)
274-
gemv(1.0, sA, x, 2.0, y2)
275-
gemv(2.0, dA, x, 2.0, y3)
276-
gemv(2.0, sA, x, 2.0, y4)
274+
gemv(1.0, dA, dx, 2.0, y1)
275+
gemv(1.0, sA, dx, 2.0, y2)
276+
gemv(2.0, dA, dx, 2.0, y3)
277+
gemv(2.0, sA, dx, 2.0, y4)
277278
assert(y1 ~== expected2 absTol 1e-15)
278279
assert(y2 ~== expected2 absTol 1e-15)
279280
assert(y3 ~== expected3 absTol 1e-15)
280281
assert(y4 ~== expected3 absTol 1e-15)
282+
283+
val y1_copy = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
284+
val y3_copy = y1_copy.copy
285+
286+
gemv(1.0, dA, sx, 2.0, y1_copy)
287+
gemv(2.0, dA, sx, 2.0, y3_copy)
288+
assert(y1_copy ~== expected2 absTol 1e-15)
289+
assert(y3_copy ~== expected3 absTol 1e-15)
290+
281291
withClue("columns of A don't match the rows of B") {
282292
intercept[Exception] {
283-
gemv(1.0, dA.transpose, x, 2.0, y1)
293+
gemv(1.0, dA.transpose, dx, 2.0, y1)
284294
}
285295
}
286296
val dAT =
@@ -291,7 +301,7 @@ class BLASSuite extends FunSuite {
291301
val dATT = dAT.transpose
292302
val sATT = sAT.transpose
293303

294-
assert(dATT.multiply(x) ~== expected absTol 1e-15)
295-
assert(sATT.multiply(x) ~== expected absTol 1e-15)
304+
assert(dATT.multiply(dx) ~== expected absTol 1e-15)
305+
assert(sATT.multiply(dx) ~== expected absTol 1e-15)
296306
}
297307
}

0 commit comments

Comments
 (0)