@@ -21,13 +21,11 @@ import java.lang.{Iterable => JIterable}
21
21
22
22
import scala .collection .JavaConverters ._
23
23
24
- import breeze .linalg .{Axis , DenseMatrix => BDM , DenseVector => BDV , argmax => brzArgmax , sum => brzSum }
25
- import breeze .numerics .{exp => brzExp , log => brzLog }
26
24
import org .json4s .JsonDSL ._
27
25
import org .json4s .jackson .JsonMethods ._
28
26
29
27
import org .apache .spark .{Logging , SparkContext , SparkException }
30
- import org .apache .spark .mllib .linalg .{BLAS , DenseVector , SparseVector , Vector }
28
+ import org .apache .spark .mllib .linalg .{BLAS , DenseMatrix , DenseVector , SparseVector , Vector , Vectors }
31
29
import org .apache .spark .mllib .regression .LabeledPoint
32
30
import org .apache .spark .mllib .util .{Loader , Saveable }
33
31
import org .apache .spark .rdd .RDD
@@ -50,6 +48,9 @@ class NaiveBayesModel private[mllib] (
50
48
val modelType : String )
51
49
extends ClassificationModel with Serializable with Saveable {
52
50
51
// Dense-vector wrapper around the `pi` array, built once per model instance so that
// predict can add it to the score vector with BLAS.axpy.
+ private val piVector = new DenseVector (pi)
52
// Dense matrix built from the flattened `theta` rows, sized labels.size x theta(0).size.
// NOTE(review): the trailing `true` is presumably DenseMatrix's isTransposed (row-major
// values) flag, matching the order produced by theta.flatten -- confirm against the constructor.
// NOTE(review): theta(0) throws if theta is empty; no guard is visible in this hunk.
+ private val thetaMatrix = new DenseMatrix (labels.size, theta(0 ).size, theta.flatten, true )
53
+
53
54
// Auxiliary constructor for callers that do not supply a model type: forwards labels, pi
// and theta unchanged to the four-argument constructor together with a fixed model-type string.
private [mllib] def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
54
55
this (labels, pi, theta, " Multinomial" )
55
56
@@ -60,17 +61,18 @@ class NaiveBayesModel private[mllib] (
60
61
theta : JIterable [JIterable [Double ]]) =
61
62
this (labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
62
63
63
- private val brzPi = new BDV [Double ](pi)
64
- private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
65
-
66
64
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
67
- // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
65
+ // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
68
66
// application of this condition (in predict function).
69
- private val (brzNegTheta, brzNegThetaSum ) = modelType match {
67
+ private val (thetaMinusNegTheta, negThetaSum ) = modelType match {
70
68
case " Multinomial" => (None , None )
71
69
case " Bernoulli" =>
72
- val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
73
- (Option (negTheta), Option (brzSum(negTheta, Axis ._1)))
70
+ val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
71
+ val ones = new DenseVector (Array .fill(thetaMatrix.numCols){1.0 })
72
+ val thetaMinusNegTheta = thetaMatrix.map { value =>
73
+ value - math.log(1.0 - math.exp(value))
74
+ }
75
+ (Option (thetaMinusNegTheta), Option (negTheta.multiply(ones)))
74
76
case _ =>
75
77
// This should never happen.
76
78
throw new UnknownError (s " NaiveBayesModel was created with an unknown ModelType: $modelType" )
@@ -85,17 +87,22 @@ class NaiveBayesModel private[mllib] (
85
87
}
86
88
87
89
override def predict (testData : Vector ): Double = {
88
- val brzData = testData.toBreeze
89
90
modelType match {
90
91
case " Multinomial" =>
91
- labels(brzArgmax(brzPi + brzTheta * brzData))
92
+ val prob = thetaMatrix.multiply(testData)
93
+ BLAS .axpy(1.0 , piVector, prob)
94
+ labels(prob.argmax)
92
95
case " Bernoulli" =>
93
- if (! brzData.forall(v => v == 0.0 || v == 1.0 )) {
94
- throw new SparkException (
95
- s " Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData. " )
96
+ testData.foreachActive { (index, value) =>
97
+ if (value != 0.0 && value != 1.0 ) {
98
+ throw new SparkException (
99
+ s " Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData. " )
100
+ }
96
101
}
97
- labels(brzArgmax(brzPi +
98
- (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
102
+ val prob = thetaMinusNegTheta.get.multiply(testData)
103
+ BLAS .axpy(1.0 , piVector, prob)
104
+ BLAS .axpy(1.0 , negThetaSum.get, prob)
105
+ labels(prob.argmax)
99
106
case _ =>
100
107
// This should never happen.
101
108
throw new UnknownError (s " NaiveBayesModel was created with an unknown ModelType: $modelType" )
0 commit comments