@@ -25,21 +25,20 @@ import org.json4s.JsonDSL._
25
25
import org .json4s .jackson .JsonMethods ._
26
26
27
27
import org .apache .spark .{Logging , SparkContext , SparkException }
28
- import org .apache .spark .mllib .linalg .{BLAS , DenseMatrix , DenseVector , SparseVector , Vector , Vectors }
28
+ import org .apache .spark .mllib .linalg .{BLAS , DenseMatrix , DenseVector , SparseVector , Vector }
29
29
import org .apache .spark .mllib .regression .LabeledPoint
30
30
import org .apache .spark .mllib .util .{Loader , Saveable }
31
31
import org .apache .spark .rdd .RDD
32
32
import org .apache .spark .sql .{DataFrame , SQLContext }
33
33
34
-
35
34
/**
36
35
* Model for Naive Bayes Classifiers.
37
36
*
38
37
* @param labels list of labels
39
38
* @param pi log of class priors, whose dimension is C, number of labels
40
39
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
41
40
* where D is number of features
42
- * @param modelType The type of NB model to fit can be "Multinomial " or "Bernoulli "
41
+ * @param modelType The type of NB model to fit can be "multinomial " or "bernoulli "
43
42
*/
44
43
class NaiveBayesModel private [mllib] (
45
44
val labels : Array [Double ],
@@ -48,11 +47,13 @@ class NaiveBayesModel private[mllib] (
48
47
val modelType : String )
49
48
extends ClassificationModel with Serializable with Saveable {
50
49
50
+ import NaiveBayes .{Bernoulli , Multinomial , supportedModelTypes }
51
+
51
52
private val piVector = new DenseVector (pi)
52
- private val thetaMatrix = new DenseMatrix (labels.size , theta(0 ).size , theta.flatten, true )
53
+ private val thetaMatrix = new DenseMatrix (labels.length , theta(0 ).length , theta.flatten, true )
53
54
54
55
private [mllib] def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
55
- this (labels, pi, theta, " Multinomial" )
56
+ this (labels, pi, theta, NaiveBayes . Multinomial )
56
57
57
58
/** A Java-friendly constructor that takes three Iterable parameters. */
58
59
private [mllib] def this (
@@ -61,12 +62,15 @@ class NaiveBayesModel private[mllib] (
61
62
theta : JIterable [JIterable [Double ]]) =
62
63
this (labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
63
64
65
+ require(supportedModelTypes.contains(modelType),
66
+ s " Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes. " )
67
+
64
68
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
65
69
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
66
70
// application of this condition (in predict function).
67
71
private val (thetaMinusNegTheta, negThetaSum) = modelType match {
68
- case " Multinomial" => (None , None )
69
- case " Bernoulli" =>
72
+ case Multinomial => (None , None )
73
+ case Bernoulli =>
70
74
val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
71
75
val ones = new DenseVector (Array .fill(thetaMatrix.numCols){1.0 })
72
76
val thetaMinusNegTheta = thetaMatrix.map { value =>
@@ -75,7 +79,7 @@ class NaiveBayesModel private[mllib] (
75
79
(Option (thetaMinusNegTheta), Option (negTheta.multiply(ones)))
76
80
case _ =>
77
81
// This should never happen.
78
- throw new UnknownError (s " NaiveBayesModel was created with an unknown ModelType : $modelType" )
82
+ throw new UnknownError (s " Invalid modelType : $modelType. " )
79
83
}
80
84
81
85
override def predict (testData : RDD [Vector ]): RDD [Double ] = {
@@ -88,15 +92,15 @@ class NaiveBayesModel private[mllib] (
88
92
89
93
override def predict (testData : Vector ): Double = {
90
94
modelType match {
91
- case " Multinomial" =>
95
+ case Multinomial =>
92
96
val prob = thetaMatrix.multiply(testData)
93
97
BLAS .axpy(1.0 , piVector, prob)
94
98
labels(prob.argmax)
95
- case " Bernoulli" =>
99
+ case Bernoulli =>
96
100
testData.foreachActive { (index, value) =>
97
101
if (value != 0.0 && value != 1.0 ) {
98
102
throw new SparkException (
99
- s " Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData. " )
103
+ s " Bernoulli naive Bayes requires 0 or 1 feature values but found $testData. " )
100
104
}
101
105
}
102
106
val prob = thetaMinusNegTheta.get.multiply(testData)
@@ -105,7 +109,7 @@ class NaiveBayesModel private[mllib] (
105
109
labels(prob.argmax)
106
110
case _ =>
107
111
// This should never happen.
108
- throw new UnknownError (s " NaiveBayesModel was created with an unknown ModelType : $modelType" )
112
+ throw new UnknownError (s " Invalid modelType : $modelType. " )
109
113
}
110
114
}
111
115
@@ -230,16 +234,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
230
234
s " ( $loadedClassName, $version). Supported: \n " +
231
235
s " ( $classNameV1_0, 1.0) " )
232
236
}
233
- assert(model.pi.size == numClasses,
237
+ assert(model.pi.length == numClasses,
234
238
s " NaiveBayesModel.load expected $numClasses classes, " +
235
- s " but class priors vector pi had ${model.pi.size } elements " )
236
- assert(model.theta.size == numClasses,
239
+ s " but class priors vector pi had ${model.pi.length } elements " )
240
+ assert(model.theta.length == numClasses,
237
241
s " NaiveBayesModel.load expected $numClasses classes, " +
238
- s " but class conditionals array theta had ${model.theta.size } elements " )
239
- assert(model.theta.forall(_.size == numFeatures),
242
+ s " but class conditionals array theta had ${model.theta.length } elements " )
243
+ assert(model.theta.forall(_.length == numFeatures),
240
244
s " NaiveBayesModel.load expected $numFeatures features, " +
241
245
s " but class conditionals array theta had elements of size: " +
242
- s " ${model.theta.map(_.size ).mkString(" ," )}" )
246
+ s " ${model.theta.map(_.length ).mkString(" ," )}" )
243
247
model
244
248
}
245
249
}
@@ -257,9 +261,11 @@ class NaiveBayes private (
257
261
private var lambda : Double ,
258
262
private var modelType : String ) extends Serializable with Logging {
259
263
260
- def this ( lambda : Double ) = this (lambda, " Multinomial" )
264
+ import NaiveBayes .{ Bernoulli , Multinomial }
261
265
262
- def this () = this (1.0 , " Multinomial" )
266
+ def this (lambda : Double ) = this (lambda, NaiveBayes .Multinomial )
267
+
268
+ def this () = this (1.0 , NaiveBayes .Multinomial )
263
269
264
270
/** Set the smoothing parameter. Default: 1.0. */
265
271
def setLambda (lambda : Double ): NaiveBayes = {
@@ -272,12 +278,11 @@ class NaiveBayes private (
272
278
273
279
/**
274
280
* Set the model type using a string (case-sensitive).
275
- * Supported options: "Multinomial" and "Bernoulli".
276
- * (default: Multinomial)
281
+ * Supported options: "multinomial" (default) and "bernoulli".
277
282
*/
278
- def setModelType (modelType: String ): NaiveBayes = {
283
+ def setModelType (modelType : String ): NaiveBayes = {
279
284
require(NaiveBayes .supportedModelTypes.contains(modelType),
280
- s " NaiveBayes was created with an unknown ModelType : $modelType" )
285
+ s " NaiveBayes was created with an unknown modelType : $modelType. " )
281
286
this .modelType = modelType
282
287
this
283
288
}
@@ -308,7 +313,7 @@ class NaiveBayes private (
308
313
}
309
314
if (! values.forall(v => v == 0.0 || v == 1.0 )) {
310
315
throw new SparkException (
311
- s " Bernoulli Naive Bayes requires 0 or 1 feature values but found $v. " )
316
+ s " Bernoulli naive Bayes requires 0 or 1 feature values but found $v. " )
312
317
}
313
318
}
314
319
@@ -317,7 +322,7 @@ class NaiveBayes private (
317
322
// TODO: similar to reduceByKeyLocally to save one stage.
318
323
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long , DenseVector )](
319
324
createCombiner = (v : Vector ) => {
320
- if (modelType == " Bernoulli" ) {
325
+ if (modelType == Bernoulli ) {
321
326
requireZeroOneBernoulliValues(v)
322
327
} else {
323
328
requireNonnegativeValues(v)
@@ -352,11 +357,11 @@ class NaiveBayes private (
352
357
labels(i) = label
353
358
pi(i) = math.log(n + lambda) - piLogDenom
354
359
val thetaLogDenom = modelType match {
355
- case " Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
356
- case " Bernoulli" => math.log(n + 2.0 * lambda)
360
+ case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
361
+ case Bernoulli => math.log(n + 2.0 * lambda)
357
362
case _ =>
358
363
// This should never happen.
359
- throw new UnknownError (s " NaiveBayes was created with an unknown ModelType : $modelType" )
364
+ throw new UnknownError (s " Invalid modelType : $modelType. " )
360
365
}
361
366
var j = 0
362
367
while (j < numFeatures) {
@@ -375,8 +380,14 @@ class NaiveBayes private (
375
380
*/
376
381
object NaiveBayes {
377
382
383
+ /** String name for multinomial model type. */
384
+ private [classification] val Multinomial : String = " multinomial"
385
+
386
+ /** String name for Bernoulli model type. */
387
+ private [classification] val Bernoulli : String = " bernoulli"
388
+
378
389
/* Set of modelTypes that NaiveBayes supports */
379
- private [mllib ] val supportedModelTypes = Set (" Multinomial" , " Bernoulli" )
390
+ private [classification ] val supportedModelTypes = Set (Multinomial , Bernoulli )
380
391
381
392
/**
382
393
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
@@ -406,7 +417,7 @@ object NaiveBayes {
406
417
* @param lambda The smoothing parameter
407
418
*/
408
419
def train (input : RDD [LabeledPoint ], lambda : Double ): NaiveBayesModel = {
409
- new NaiveBayes (lambda, " Multinomial" ).run(input)
420
+ new NaiveBayes (lambda, Multinomial ).run(input)
410
421
}
411
422
412
423
/**
@@ -429,7 +440,7 @@ object NaiveBayes {
429
440
*/
430
441
def train (input : RDD [LabeledPoint ], lambda : Double , modelType : String ): NaiveBayesModel = {
431
442
require(supportedModelTypes.contains(modelType),
432
- s " NaiveBayes was created with an unknown ModelType : $modelType" )
443
+ s " NaiveBayes was created with an unknown modelType : $modelType. " )
433
444
new NaiveBayes (lambda, modelType).run(input)
434
445
}
435
446
0 commit comments