package org.apache.spark.mllib.classification

- import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+ import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels

import org.apache.spark.{SparkException, Logging}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

+
+ /**
+  * Enumeration of the model types supported by Naive Bayes: Multinomial and Bernoulli.
+  */
+ object NaiveBayesModels extends Enumeration {
+   type NaiveBayesModels = Value
+   val Multinomial, Bernoulli = Value
+ }
+
/**
 * Model for Naive Bayes Classifiers.
 *
 * @param labels list of labels
 * @param pi log of class priors, whose dimension is C, number of labels
 * @param theta log of class conditional probabilities, whose dimension is C-by-D,
 *              where D is number of features
+ * @param model the type of NB model to fit, from the enumeration NaiveBayesModels:
+ *              either Multinomial or Bernoulli
 */
+
class NaiveBayesModel private[mllib] (
    val labels: Array[Double],
    val pi: Array[Double],
-     val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
-
-   private val brzPi = new BDV[Double](pi)
-   private val brzTheta = new BDM[Double](theta.length, theta(0).length)
+     val theta: Array[Array[Double]],
+     val model: NaiveBayesModels) extends ClassificationModel with Serializable {

-   {
-     // Need to put an extra pair of braces to prevent Scala treating `i` as a member.
+   def populateMatrix(arrayIn: Array[Array[Double]],
+                      matrixIn: BDM[Double],
+                      transformation: (Double) => Double = (x) => x) = {
    var i = 0
-     while (i < theta.length) {
+     while (i < arrayIn.length) {
      var j = 0
-       while (j < theta(i).length) {
-         brzTheta(i, j) = theta(i)(j)
+       while (j < arrayIn(i).length) {
+         matrixIn(i, j) = transformation(arrayIn(i)(j))
        j += 1
      }
      i += 1
    }
  }

+   private val brzPi = new BDV[Double](pi)
+   private val brzTheta = new BDM[Double](theta.length, theta(0).length)
+   populateMatrix(theta, brzTheta)
+
+   private val brzNegTheta: Option[BDM[Double]] = model match {
+     case NaiveBayesModels.Multinomial => None
+     case NaiveBayesModels.Bernoulli =>
+       val negTheta = new BDM[Double](theta.length, theta(0).length)
+       populateMatrix(theta, negTheta, (x) => math.log(1.0 - math.exp(x)))
+       Option(negTheta)
+   }
+
  override def predict(testData: RDD[Vector]): RDD[Double] = {
    val bcModel = testData.context.broadcast(this)
    testData.mapPartitions { iter =>
@@ -63,7 +87,14 @@ class NaiveBayesModel private[mllib] (
  }

  override def predict(testData: Vector): Double = {
-     labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
+     model match {
+       case NaiveBayesModels.Multinomial =>
+         labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
+       case NaiveBayesModels.Bernoulli =>
+         labels(brzArgmax(brzPi +
+           (brzTheta - brzNegTheta.get) * testData.toBreeze +
+           brzSum(brzNegTheta.get, Axis._1)))
+     }
  }
}

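(Aside: the Bernoulli branch above is a vectorized form of the per-class log-posterior. A minimal element-wise sketch of the same computation, with hypothetical helper and argument names that are not part of this patch:)

  // score(c) = pi(c) + sum_j [ x(j) * theta(c)(j) + (1 - x(j)) * log(1 - exp(theta(c)(j))) ]
  // which matches brzPi + (brzTheta - brzNegTheta) * x + rowSums(brzNegTheta) above.
  def bernoulliLogPosterior(pi: Array[Double], theta: Array[Array[Double]], x: Array[Double]): Array[Double] =
    Array.tabulate(pi.length) { c =>
      pi(c) + x.indices.map { j =>
        x(j) * theta(c)(j) + (1.0 - x(j)) * math.log(1.0 - math.exp(theta(c)(j)))
      }.sum
    }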
@@ -75,16 +106,26 @@ class NaiveBayesModel private[mllib] (
 * document classification. By making every vector a 0-1 vector, it can also be used as
 * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
 */
- class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {
+ class NaiveBayes private (private var lambda: Double,
+                           var model: NaiveBayesModels) extends Serializable with Logging {

-   def this() = this(1.0)
+   def this(lambda: Double) = this(lambda, NaiveBayesModels.Multinomial)
+
+   def this() = this(1.0, NaiveBayesModels.Multinomial)

  /** Set the smoothing parameter. Default: 1.0. */
  def setLambda(lambda: Double): NaiveBayes = {
    this.lambda = lambda
    this
  }

+   /** Set the model type. Default: Multinomial. */
+   def setModelType(model: NaiveBayesModels): NaiveBayes = {
+     this.model = model
+     this
+   }
+
+
  /**
   * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
   *
@@ -118,21 +159,27 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
      mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
        (c1._1 + c2._1, c1._2 += c2._2)
    ).collect()
+
    val numLabels = aggregated.length
    var numDocuments = 0L
    aggregated.foreach { case (_, (n, _)) =>
      numDocuments += n
    }
    val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }
+
    val labels = new Array[Double](numLabels)
    val pi = new Array[Double](numLabels)
    val theta = Array.fill(numLabels)(new Array[Double](numFeatures))
+
    val piLogDenom = math.log(numDocuments + numLabels * lambda)
    var i = 0
    aggregated.foreach { case (label, (n, sumTermFreqs)) =>
      labels(i) = label
-       val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
      pi(i) = math.log(n + lambda) - piLogDenom
+       val thetaLogDenom = model match {
+         case NaiveBayesModels.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
+         case NaiveBayesModels.Bernoulli => math.log(n + 2.0 * lambda)
+       }
      var j = 0
      while (j < numFeatures) {
        theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
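(Note on the smoothing above, spelled out: the Multinomial branch gives
theta(i)(j) = log((sumTermFreqs(j) + lambda) / (brzSum(sumTermFreqs) + numFeatures * lambda)),
i.e. smoothed term frequencies normalized within class i, while the Bernoulli branch gives
theta(i)(j) = log((sumTermFreqs(j) + lambda) / (n + 2.0 * lambda)),
i.e. the smoothed fraction of the n documents of class i containing feature j, which assumes
0/1 feature vectors.)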
@@ -141,7 +188,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
      i += 1
    }

-     new NaiveBayesModel(labels, pi, theta)
+     new NaiveBayesModel(labels, pi, theta, model)
  }
}

@@ -154,8 +201,7 @@ object NaiveBayes {
   *
   * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
   * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
-    * document classification. By making every vector a 0-1 vector, it can also be used as
-    * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+    * document classification.
   *
   * This version of the method uses a default smoothing parameter of 1.0.
   *
@@ -171,8 +217,7 @@ object NaiveBayes {
   *
   * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
   * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
-    * document classification. By making every vector a 0-1 vector, it can also be used as
-    * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+    * document classification.
   *
   * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
   *              vector or a count vector.
@@ -181,4 +226,25 @@ object NaiveBayes {
  def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
    new NaiveBayes(lambda).run(input)
  }
+
+
+   /**
+    * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+    *
+    * This is by default the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle
+    * all kinds of discrete data. For example, by converting documents into TF-IDF vectors,
+    * it can be used for document classification. By making every vector a 0-1 vector and
+    * setting the model type to NaiveBayesModels.Bernoulli, it fits and predicts as
+    * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+    *
+    * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
+    *              vector or a count vector.
+    * @param lambda the smoothing parameter
+    * @param model the type of NB model to fit, from the enumeration NaiveBayesModels:
+    *              either Multinomial or Bernoulli
+    */
+   def train(input: RDD[LabeledPoint], lambda: Double, model: NaiveBayesModels): NaiveBayesModel = {
+     new NaiveBayes(lambda, model).run(input)
+   }
}
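(For reference, a minimal usage sketch of the API added in this patch, assuming only a live SparkContext `sc`; the data and variable names are illustrative:)

  import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModels}
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.regression.LabeledPoint
  import org.apache.spark.rdd.RDD

  // Tiny 0/1-valued training set, as required for the Bernoulli model.
  val training: RDD[LabeledPoint] = sc.parallelize(Seq(
    LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
    LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))))

  val nbModel = NaiveBayes.train(training, lambda = 1.0, model = NaiveBayesModels.Bernoulli)
  val prediction = nbModel.predict(Vectors.dense(0.0, 1.0, 0.0))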