Skip to content

Commit 1b6cdfe

Browse files
committed
Update the implementation of naive Bayes prediction with BLAS.
1 parent d3db2fd commit 1b6cdfe

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.json4s.JsonDSL._
2727
import org.json4s.jackson.JsonMethods._
2828

2929
import org.apache.spark.{Logging, SparkContext, SparkException}
30-
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector}
30+
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
3131
import org.apache.spark.mllib.regression.LabeledPoint
3232
import org.apache.spark.mllib.util.{Loader, Saveable}
3333
import org.apache.spark.rdd.RDD
@@ -50,6 +50,9 @@ class NaiveBayesModel private[mllib] (
5050
val modelType: String)
5151
extends ClassificationModel with Serializable with Saveable {
5252

53+
val piVector = Vectors.dense(pi).asInstanceOf[DenseVector]
54+
val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
55+
5356
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
5457
this(labels, pi, theta, "Multinomial")
5558

@@ -60,17 +63,18 @@ class NaiveBayesModel private[mllib] (
6063
theta: JIterable[JIterable[Double]]) =
6164
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
6265

63-
private val brzPi = new BDV[Double](pi)
64-
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
65-
6666
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
67-
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
67+
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
6868
// application of this condition (in predict function).
69-
private val (brzNegTheta, brzNegThetaSum) = modelType match {
69+
private val (thetaMinusnegTheta, brzNegThetaSum) = modelType match {
7070
case "Multinomial" => (None, None)
7171
case "Bernoulli" =>
72-
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
73-
(Option(negTheta), Option(brzSum(negTheta, Axis._1)))
72+
val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
73+
val ones = Vectors.dense(Array.fill(thetaMatrix.numCols){1.0}).asInstanceOf[DenseVector]
74+
val thetaMinusnegTheta = thetaMatrix.map { value =>
75+
value - math.log(1.0 - math.exp(value))
76+
}
77+
(Option(thetaMinusnegTheta), Option(negTheta.multiply(ones)))
7478
case _ =>
7579
// This should never happen.
7680
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
@@ -85,17 +89,22 @@ class NaiveBayesModel private[mllib] (
8589
}
8690

8791
override def predict(testData: Vector): Double = {
88-
val brzData = testData.toBreeze
8992
modelType match {
9093
case "Multinomial" =>
91-
labels(brzArgmax(brzPi + brzTheta * brzData))
94+
val prob = thetaMatrix.multiply(testData.toDense)
95+
BLAS.axpy(1.0, piVector, prob)
96+
labels(prob.argmax)
9297
case "Bernoulli" =>
93-
if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
94-
throw new SparkException(
95-
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
98+
testData.foreachActive { (index, value) =>
99+
if (value != 0.0 && value != 1.0) {
100+
throw new SparkException(
101+
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
102+
}
96103
}
97-
labels(brzArgmax(brzPi +
98-
(brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
104+
val prob = thetaMinusnegTheta.get.multiply(testData.toDense)
105+
BLAS.axpy(1.0, piVector, prob)
106+
BLAS.axpy(1.0, brzNegThetaSum.get, prob)
107+
labels(prob.argmax)
99108
case _ =>
100109
// This should never happen.
101110
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ class DenseMatrix(
273273

274274
override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
275275

276-
private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f))
276+
private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
277+
isTransposed)
277278

278279
private[mllib] def update(f: Double => Double): DenseMatrix = {
279280
val len = values.length
@@ -535,7 +536,7 @@ class SparseMatrix(
535536
}
536537

537538
private[mllib] def map(f: Double => Double) =
538-
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f))
539+
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed)
539540

540541
private[mllib] def update(f: Double => Double): SparseMatrix = {
541542
val len = values.length

0 commit comments

Comments
 (0)