@@ -25,16 +25,17 @@ import org.json4s.jackson.JsonMethods._
25
25
import org .json4s .{DefaultFormats , JValue }
26
26
27
27
import org .apache .spark .{Logging , SparkContext , SparkException }
28
+ import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
28
29
import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vector }
29
30
import org .apache .spark .mllib .regression .LabeledPoint
30
- import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
31
31
import org .apache .spark .mllib .util .{Loader , Saveable }
32
32
import org .apache .spark .rdd .RDD
33
33
import org .apache .spark .sql .{DataFrame , SQLContext }
34
34
35
35
36
36
/**
37
- *
37
+ * Model types supported in Naive Bayes:
38
+ * multinomial and Bernoulli currently supported
38
39
*/
39
40
object NaiveBayesModels extends Enumeration {
40
41
type NaiveBayesModels = Value
@@ -45,6 +46,8 @@ object NaiveBayesModels extends Enumeration {
45
46
}
46
47
}
47
48
49
+
50
+
48
51
/**
49
52
* Model for Naive Bayes Classifiers.
50
53
*
@@ -55,7 +58,6 @@ object NaiveBayesModels extends Enumeration {
55
58
* @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
56
59
* Multinomial or Bernoulli
57
60
*/
58
-
59
61
class NaiveBayesModel private [mllib] (
60
62
val labels : Array [Double ],
61
63
val pi : Array [Double ],
@@ -68,11 +70,14 @@ class NaiveBayesModel private[mllib] (
68
70
private val brzPi = new BDV [Double ](pi)
69
71
private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
70
72
71
- private val brzNegTheta : Option [BDM [Double ]] = modelType match {
72
- case NaiveBayesModels .Multinomial => None
73
+ // Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
74
+ // precomputing log(1.0 - exp(theta)) and its sum for linear algebra application
75
+ // of this condition in predict function
76
+ private val (brzNegTheta, brzNegThetaSum) = modelType match {
77
+ case NaiveBayesModels .Multinomial => (None , None )
73
78
case NaiveBayesModels .Bernoulli =>
74
79
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
75
- Option (negTheta)
80
+ ( Option (negTheta), Option (brzSum(brzNegTheta, Axis ._1)) )
76
81
}
77
82
78
83
override def predict (testData : RDD [Vector ]): RDD [Double ] = {
@@ -89,8 +94,7 @@ class NaiveBayesModel private[mllib] (
89
94
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
90
95
case NaiveBayesModels .Bernoulli =>
91
96
labels (brzArgmax (brzPi +
92
- (brzTheta - brzNegTheta.get) * testData.toBreeze +
93
- brzSum(brzNegTheta.get, Axis ._1)))
97
+ (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
94
98
}
95
99
}
96
100
@@ -114,10 +118,11 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
114
118
def thisClassName = " org.apache.spark.mllib.classification.NaiveBayesModel"
115
119
116
120
/** Model data for model import/export */
117
- case class Data (labels : Array [Double ],
118
- pi : Array [Double ],
119
- theta : Array [Array [Double ]],
120
- modelType : String )
121
+ case class Data (
122
+ labels : Array [Double ],
123
+ pi : Array [Double ],
124
+ theta : Array [Array [Double ]],
125
+ modelType : String )
121
126
122
127
def save (sc : SparkContext , path : String , data : Data ): Unit = {
123
128
val sqlContext = new SQLContext (sc)
@@ -192,7 +197,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
192
197
* Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The input feature values must be nonnegative.
193
198
*/
194
199
class NaiveBayes private (private var lambda : Double ,
195
- var modelType : NaiveBayesModels ) extends Serializable with Logging {
200
+ private var modelType : NaiveBayesModels ) extends Serializable with Logging {
196
201
197
202
def this (lambda : Double ) = this (lambda, NaiveBayesModels .Multinomial )
198
203
@@ -284,7 +289,7 @@ object NaiveBayes {
284
289
/**
285
290
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
286
291
*
287
- * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
292
+ * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
288
293
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
289
294
* document classification.
290
295
*
@@ -300,7 +305,7 @@ object NaiveBayes {
300
305
/**
301
306
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
302
307
*
303
- * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
308
+ * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
304
309
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
305
310
* document classification.
306
311
*
@@ -316,11 +321,13 @@ object NaiveBayes {
316
321
/**
317
322
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
318
323
*
319
- * This is by default the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle
320
- * all kinds of discrete data. For example, by converting documents into TF-IDF vectors,
321
- * it can be used for document classification. By making every vector a 0-1 vector and
322
- * setting the model type to NaiveBayesModels.Bernoulli, it fits and predicts as
323
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]).
324
+ * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p ]])
325
+ * or Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The Multinomial NB can handle
326
+ * discrete count data and can be called by setting the model type to "Multinomial".
327
+ * For example, it can be used with word counts or TF_IDF vectors of documents.
328
+ * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a
329
+ * 0-1 vector and setting the model type to "Bernoulli", the fits and predicts as
330
+ * Bernoulli NB.
324
331
*
325
332
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
326
333
* vector or a count vector.
0 commit comments