@@ -27,7 +27,7 @@ import org.json4s.JsonDSL._
27
27
import org .json4s .jackson .JsonMethods ._
28
28
29
29
import org .apache .spark .{Logging , SparkContext , SparkException }
30
- import org .apache .spark .mllib .linalg .{BLAS , DenseVector , SparseVector , Vector }
30
+ import org .apache .spark .mllib .linalg .{BLAS , DenseMatrix , DenseVector , SparseVector , Vector , Vectors }
31
31
import org .apache .spark .mllib .regression .LabeledPoint
32
32
import org .apache .spark .mllib .util .{Loader , Saveable }
33
33
import org .apache .spark .rdd .RDD
@@ -50,6 +50,9 @@ class NaiveBayesModel private[mllib] (
50
50
val modelType : String )
51
51
extends ClassificationModel with Serializable with Saveable {
52
52
53
+ val piVector = Vectors .dense(pi).asInstanceOf [DenseVector ]
54
+ val thetaMatrix = new DenseMatrix (labels.size, theta(0 ).size, theta.flatten, true )
55
+
53
56
private [mllib] def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
54
57
this (labels, pi, theta, " Multinomial" )
55
58
@@ -60,17 +63,18 @@ class NaiveBayesModel private[mllib] (
60
63
theta : JIterable [JIterable [Double ]]) =
61
64
this (labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
62
65
63
- private val brzPi = new BDV [Double ](pi)
64
- private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
65
-
66
66
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
67
- // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
67
+ // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
68
68
// application of this condition (in predict function).
69
- private val (brzNegTheta , brzNegThetaSum) = modelType match {
69
+ private val (thetaMinusnegTheta , brzNegThetaSum) = modelType match {
70
70
case " Multinomial" => (None , None )
71
71
case " Bernoulli" =>
72
- val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
73
- (Option (negTheta), Option (brzSum(negTheta, Axis ._1)))
72
+ val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
73
+ val ones = Vectors .dense(Array .fill(thetaMatrix.numCols){1.0 }).asInstanceOf [DenseVector ]
74
+ val thetaMinusnegTheta = thetaMatrix.map { value =>
75
+ value - math.log(1.0 - math.exp(value))
76
+ }
77
+ (Option (thetaMinusnegTheta), Option (negTheta.multiply(ones)))
74
78
case _ =>
75
79
// This should never happen.
76
80
throw new UnknownError (s " NaiveBayesModel was created with an unknown ModelType: $modelType" )
@@ -85,17 +89,22 @@ class NaiveBayesModel private[mllib] (
85
89
}
86
90
87
91
override def predict (testData : Vector ): Double = {
88
- val brzData = testData.toBreeze
89
92
modelType match {
90
93
case " Multinomial" =>
91
- labels(brzArgmax(brzPi + brzTheta * brzData))
94
+ val prob = thetaMatrix.multiply(testData.toDense)
95
+ BLAS .axpy(1.0 , piVector, prob)
96
+ labels(prob.argmax)
92
97
case " Bernoulli" =>
93
- if (! brzData.forall(v => v == 0.0 || v == 1.0 )) {
94
- throw new SparkException (
95
- s " Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData. " )
98
+ testData.foreachActive { (index, value) =>
99
+ if (value != 0.0 && value != 1.0 ) {
100
+ throw new SparkException (
101
+ s " Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData. " )
102
+ }
96
103
}
97
- labels(brzArgmax(brzPi +
98
- (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
104
+ val prob = thetaMinusnegTheta.get.multiply(testData.toDense)
105
+ BLAS .axpy(1.0 , piVector, prob)
106
+ BLAS .axpy(1.0 , brzNegThetaSum.get, prob)
107
+ labels(prob.argmax)
99
108
case _ =>
100
109
// This should never happen.
101
110
throw new UnknownError (s " NaiveBayesModel was created with an unknown ModelType: $modelType" )
0 commit comments