@@ -27,24 +27,11 @@ import org.json4s.{DefaultFormats, JValue}
27
27
import org .apache .spark .{Logging , SparkContext , SparkException }
28
28
import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vector }
29
29
import org .apache .spark .mllib .regression .LabeledPoint
30
- import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
31
30
import org .apache .spark .mllib .util .{Loader , Saveable }
32
31
import org .apache .spark .rdd .RDD
33
32
import org .apache .spark .sql .{DataFrame , SQLContext }
34
33
35
34
36
- /**
37
- *
38
- */
39
- object NaiveBayesModels extends Enumeration {
40
- type NaiveBayesModels = Value
41
- val Multinomial, Bernoulli = Value
42
-
43
- implicit def toString (model : NaiveBayesModels ): String = {
44
- model.toString
45
- }
46
- }
47
-
48
35
/**
49
36
* Model for Naive Bayes Classifiers.
50
37
*
@@ -60,17 +47,18 @@ class NaiveBayesModel private[mllib] (
60
47
val labels : Array [Double ],
61
48
val pi : Array [Double ],
62
49
val theta : Array [Array [Double ]],
63
- val modelType : NaiveBayesModels ) extends ClassificationModel with Serializable with Saveable {
50
+ val modelType : NaiveBayes .ModelType )
51
+ extends ClassificationModel with Serializable with Saveable {
64
52
65
53
def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
66
- this (labels, pi, theta, NaiveBayesModels .Multinomial )
54
+ this (labels, pi, theta, NaiveBayes .Multinomial )
67
55
68
56
private val brzPi = new BDV [Double ](pi)
69
57
private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
70
58
71
59
private val brzNegTheta : Option [BDM [Double ]] = modelType match {
72
- case NaiveBayesModels .Multinomial => None
73
- case NaiveBayesModels .Bernoulli =>
60
+ case NaiveBayes .Multinomial => None
61
+ case NaiveBayes .Bernoulli =>
74
62
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
75
63
Option (negTheta)
76
64
}
@@ -85,17 +73,17 @@ class NaiveBayesModel private[mllib] (
85
73
86
74
override def predict (testData : Vector ): Double = {
87
75
modelType match {
88
- case NaiveBayesModels .Multinomial =>
76
+ case NaiveBayes .Multinomial =>
89
77
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
90
- case NaiveBayesModels .Bernoulli =>
78
+ case NaiveBayes .Bernoulli =>
91
79
labels (brzArgmax (brzPi +
92
80
(brzTheta - brzNegTheta.get) * testData.toBreeze +
93
81
brzSum(brzNegTheta.get, Axis ._1)))
94
82
}
95
83
}
96
84
97
85
override def save (sc : SparkContext , path : String ): Unit = {
98
- val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType)
86
+ val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType.toString )
99
87
NaiveBayesModel .SaveLoadV1_0 .save(sc, path, data)
100
88
}
101
89
@@ -147,15 +135,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
147
135
val labels = data.getAs[Seq [Double ]](0 ).toArray
148
136
val pi = data.getAs[Seq [Double ]](1 ).toArray
149
137
val theta = data.getAs[Seq [Seq [Double ]]](2 ).map(_.toArray).toArray
150
- val modelType : NaiveBayesModels = NaiveBayesModels .withName (data.getAs[ String ] (3 ))
138
+ val modelType = NaiveBayes . ModelType .fromString (data.getString (3 ))
151
139
new NaiveBayesModel (labels, pi, theta, modelType)
152
140
}
153
141
}
154
142
155
143
override def load (sc : SparkContext , path : String ): NaiveBayesModel = {
156
- def getModelType (metadata : JValue ): NaiveBayesModels = {
144
+ def getModelType (metadata : JValue ): NaiveBayes . ModelType = {
157
145
implicit val formats = DefaultFormats
158
- NaiveBayesModels .withName ((metadata \ " modelType" ).extract[String ])
146
+ NaiveBayes . ModelType .fromString ((metadata \ " modelType" ).extract[String ])
159
147
}
160
148
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
161
149
val classNameV1_0 = SaveLoadV1_0 .thisClassName
@@ -191,12 +179,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
191
179
* document classification. By making every vector a 0-1 vector, it can also be used as
192
180
* Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The input feature values must be nonnegative.
193
181
*/
194
- class NaiveBayes private (private var lambda : Double ,
195
- var modelType : NaiveBayesModels ) extends Serializable with Logging {
182
+ class NaiveBayes private (
183
+ private var lambda : Double ,
184
+ var modelType : NaiveBayes .ModelType ) extends Serializable with Logging {
196
185
197
- def this (lambda : Double ) = this (lambda, NaiveBayesModels .Multinomial )
186
+ def this (lambda : Double ) = this (lambda, NaiveBayes .Multinomial )
198
187
199
- def this () = this (1.0 , NaiveBayesModels .Multinomial )
188
+ def this () = this (1.0 , NaiveBayes .Multinomial )
200
189
201
190
/** Set the smoothing parameter. Default: 1.0. */
202
191
def setLambda (lambda : Double ): NaiveBayes = {
@@ -205,7 +194,7 @@ class NaiveBayes private (private var lambda: Double,
205
194
}
206
195
207
196
/** Set the model type. Default: Multinomial. */
208
- def setModelType (model : NaiveBayesModels ): NaiveBayes = {
197
+ def setModelType (model : NaiveBayes . ModelType ): NaiveBayes = {
209
198
this .modelType = model
210
199
this
211
200
}
@@ -262,8 +251,8 @@ class NaiveBayes private (private var lambda: Double,
262
251
labels(i) = label
263
252
pi(i) = math.log(n + lambda) - piLogDenom
264
253
val thetaLogDenom = modelType match {
265
- case NaiveBayesModels .Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
266
- case NaiveBayesModels .Bernoulli => math.log(n + 2.0 * lambda)
254
+ case NaiveBayes .Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
255
+ case NaiveBayes .Bernoulli => math.log(n + 2.0 * lambda)
267
256
}
268
257
var j = 0
269
258
while (j < numFeatures) {
@@ -330,6 +319,32 @@ object NaiveBayes {
330
319
* Multinomial or Bernoulli
331
320
*/
332
321
def train (input : RDD [LabeledPoint ], lambda : Double , modelType : String ): NaiveBayesModel = {
333
- new NaiveBayes (lambda, NaiveBayesModels .withName(modelType) ).run(input)
322
+ new NaiveBayes (lambda, Multinomial ).run(input)
334
323
}
324
+
325
+ sealed abstract class ModelType
326
+
327
+ object MODELTYPE {
328
+ final val MULTINOMIAL_STRING = " multinomial"
329
+ final val BERNOULLI_STRING = " bernoulli"
330
+
331
+ def fromString (modelType : String ): ModelType = modelType match {
332
+ case MULTINOMIAL_STRING => Multinomial
333
+ case BERNOULLI_STRING => Bernoulli
334
+ case _ =>
335
+ throw new IllegalArgumentException (s " Cannot recognize NaiveBayes ModelType: $modelType" )
336
+ }
337
+ }
338
+
339
+ final val ModelType = MODELTYPE
340
+
341
+ final val Multinomial : ModelType = new ModelType {
342
+ override def toString : String = ModelType .MULTINOMIAL_STRING
343
+ }
344
+
345
+ final val Bernoulli : ModelType = new ModelType {
346
+ override def toString : String = ModelType .BERNOULLI_STRING
347
+ }
348
+
335
349
}
350
+
0 commit comments