Skip to content

Commit c12dff9

Browse files
viiryamengxr
authored andcommitted
[SPARK-7652] [MLLIB] Update the implementation of naive Bayes prediction with BLAS
JIRA: https://issues.apache.org/jira/browse/SPARK-7652 Author: Liang-Chi Hsieh <[email protected]> Closes #6189 from viirya/naive_bayes_blas_prediction and squashes the following commits: ab611fd [Liang-Chi Hsieh] Remove unnecessary space. ddc48b9 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into naive_bayes_blas_prediction b5772b4 [Liang-Chi Hsieh] Fix binary compatibility. 2f65186 [Liang-Chi Hsieh] Remove toDense. 1b6cdfe [Liang-Chi Hsieh] Update the implementation of naive Bayes prediction with BLAS.
1 parent 68fb2a4 commit c12dff9

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,11 @@ import java.lang.{Iterable => JIterable}
2121

2222
import scala.collection.JavaConverters._
2323

24-
import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
25-
import breeze.numerics.{exp => brzExp, log => brzLog}
2624
import org.json4s.JsonDSL._
2725
import org.json4s.jackson.JsonMethods._
2826

2927
import org.apache.spark.{Logging, SparkContext, SparkException}
30-
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector}
28+
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
3129
import org.apache.spark.mllib.regression.LabeledPoint
3230
import org.apache.spark.mllib.util.{Loader, Saveable}
3331
import org.apache.spark.rdd.RDD
@@ -50,6 +48,9 @@ class NaiveBayesModel private[mllib] (
5048
val modelType: String)
5149
extends ClassificationModel with Serializable with Saveable {
5250

51+
private val piVector = new DenseVector(pi)
52+
private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
53+
5354
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
5455
this(labels, pi, theta, "Multinomial")
5556

@@ -60,17 +61,18 @@ class NaiveBayesModel private[mllib] (
6061
theta: JIterable[JIterable[Double]]) =
6162
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
6263

63-
private val brzPi = new BDV[Double](pi)
64-
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
65-
6664
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
67-
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
65+
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
6866
// application of this condition (in predict function).
69-
private val (brzNegTheta, brzNegThetaSum) = modelType match {
67+
private val (thetaMinusNegTheta, negThetaSum) = modelType match {
7068
case "Multinomial" => (None, None)
7169
case "Bernoulli" =>
72-
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
73-
(Option(negTheta), Option(brzSum(negTheta, Axis._1)))
70+
val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
71+
val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
72+
val thetaMinusNegTheta = thetaMatrix.map { value =>
73+
value - math.log(1.0 - math.exp(value))
74+
}
75+
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
7476
case _ =>
7577
// This should never happen.
7678
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
@@ -85,17 +87,22 @@ class NaiveBayesModel private[mllib] (
8587
}
8688

8789
override def predict(testData: Vector): Double = {
88-
val brzData = testData.toBreeze
8990
modelType match {
9091
case "Multinomial" =>
91-
labels(brzArgmax(brzPi + brzTheta * brzData))
92+
val prob = thetaMatrix.multiply(testData)
93+
BLAS.axpy(1.0, piVector, prob)
94+
labels(prob.argmax)
9295
case "Bernoulli" =>
93-
if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
94-
throw new SparkException(
95-
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
96+
testData.foreachActive { (index, value) =>
97+
if (value != 0.0 && value != 1.0) {
98+
throw new SparkException(
99+
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
100+
}
96101
}
97-
labels(brzArgmax(brzPi +
98-
(brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
102+
val prob = thetaMinusNegTheta.get.multiply(testData)
103+
BLAS.axpy(1.0, piVector, prob)
104+
BLAS.axpy(1.0, negThetaSum.get, prob)
105+
labels(prob.argmax)
99106
case _ =>
100107
// This should never happen.
101108
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")

0 commit comments

Comments
 (0)