Skip to content

Commit 13348e2

Browse files
mengxrjkbradley
authored andcommitted
[SPARK-7752] [MLLIB] Use lowercase letters for NaiveBayes.modelType
to be consistent with other string names in MLlib. This PR also updates the implementation to use vals instead of hardcoded strings. jkbradley leahmcguire Author: Xiangrui Meng <[email protected]> Closes #6277 from mengxr/SPARK-7752 and squashes the following commits: f38b662 [Xiangrui Meng] add another case _ back in test ae5c66a [Xiangrui Meng] model type -> modelType 711d1c6 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7752 40ae53e [Xiangrui Meng] fix Java test suite 264a814 [Xiangrui Meng] add case _ back 3c456a8 [Xiangrui Meng] update NB user guide 17bba53 [Xiangrui Meng] update naive Bayes to use lowercase model type strings
1 parent a25c1ab commit 13348e2

File tree

4 files changed

+75
-59
lines changed

4 files changed

+75
-59
lines changed

docs/mllib-naive-bayes.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Within that context, each observation is a document and each
2121
feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or
2222
a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes).
2323
Feature values must be nonnegative. The model type is selected with an optional parameter
24-
"Multinomial" or "Bernoulli" with "Multinomial" as the default.
24+
"multinomial" or "bernoulli" with "multinomial" as the default.
2525
[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
2626
setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature
2727
vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of
@@ -35,7 +35,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach
3535
[NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements
3636
multinomial naive Bayes. It takes an RDD of
3737
[LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional
38-
smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a
38+
smoothing parameter `lambda` as input, an optional model type parameter (default is "multinomial"), and outputs a
3939
[NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which
4040
can be used for evaluation and prediction.
4141

@@ -54,7 +54,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
5454
val training = splits(0)
5555
val test = splits(1)
5656

57-
val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial")
57+
val model = NaiveBayes.train(training, lambda = 1.0, model = "multinomial")
5858

5959
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
6060
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
@@ -75,14 +75,15 @@ optionally smoothing parameter `lambda` as input, and output a
7575
can be used for evaluation and prediction.
7676

7777
{% highlight java %}
78+
import scala.Tuple2;
79+
7880
import org.apache.spark.api.java.JavaPairRDD;
7981
import org.apache.spark.api.java.JavaRDD;
8082
import org.apache.spark.api.java.function.Function;
8183
import org.apache.spark.api.java.function.PairFunction;
8284
import org.apache.spark.mllib.classification.NaiveBayes;
8385
import org.apache.spark.mllib.classification.NaiveBayesModel;
8486
import org.apache.spark.mllib.regression.LabeledPoint;
85-
import scala.Tuple2;
8687

8788
JavaRDD<LabeledPoint> training = ... // training set
8889
JavaRDD<LabeledPoint> test = ... // test set

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,20 @@ import org.json4s.JsonDSL._
2525
import org.json4s.jackson.JsonMethods._
2626

2727
import org.apache.spark.{Logging, SparkContext, SparkException}
28-
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
28+
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector}
2929
import org.apache.spark.mllib.regression.LabeledPoint
3030
import org.apache.spark.mllib.util.{Loader, Saveable}
3131
import org.apache.spark.rdd.RDD
3232
import org.apache.spark.sql.{DataFrame, SQLContext}
3333

34-
3534
/**
3635
* Model for Naive Bayes Classifiers.
3736
*
3837
* @param labels list of labels
3938
* @param pi log of class priors, whose dimension is C, number of labels
4039
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
4140
* where D is number of features
42-
* @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
41+
* @param modelType The type of NB model to fit can be "multinomial" or "bernoulli"
4342
*/
4443
class NaiveBayesModel private[mllib] (
4544
val labels: Array[Double],
@@ -48,11 +47,13 @@ class NaiveBayesModel private[mllib] (
4847
val modelType: String)
4948
extends ClassificationModel with Serializable with Saveable {
5049

50+
import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes}
51+
5152
private val piVector = new DenseVector(pi)
52-
private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
53+
private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true)
5354

5455
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
55-
this(labels, pi, theta, "Multinomial")
56+
this(labels, pi, theta, NaiveBayes.Multinomial)
5657

5758
/** A Java-friendly constructor that takes three Iterable parameters. */
5859
private[mllib] def this(
@@ -61,12 +62,15 @@ class NaiveBayesModel private[mllib] (
6162
theta: JIterable[JIterable[Double]]) =
6263
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
6364

65+
require(supportedModelTypes.contains(modelType),
66+
s"Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes.")
67+
6468
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
6569
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
6670
// application of this condition (in predict function).
6771
private val (thetaMinusNegTheta, negThetaSum) = modelType match {
68-
case "Multinomial" => (None, None)
69-
case "Bernoulli" =>
72+
case Multinomial => (None, None)
73+
case Bernoulli =>
7074
val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
7175
val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
7276
val thetaMinusNegTheta = thetaMatrix.map { value =>
@@ -75,7 +79,7 @@ class NaiveBayesModel private[mllib] (
7579
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
7680
case _ =>
7781
// This should never happen.
78-
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
82+
throw new UnknownError(s"Invalid modelType: $modelType.")
7983
}
8084

8185
override def predict(testData: RDD[Vector]): RDD[Double] = {
@@ -88,15 +92,15 @@ class NaiveBayesModel private[mllib] (
8892

8993
override def predict(testData: Vector): Double = {
9094
modelType match {
91-
case "Multinomial" =>
95+
case Multinomial =>
9296
val prob = thetaMatrix.multiply(testData)
9397
BLAS.axpy(1.0, piVector, prob)
9498
labels(prob.argmax)
95-
case "Bernoulli" =>
99+
case Bernoulli =>
96100
testData.foreachActive { (index, value) =>
97101
if (value != 0.0 && value != 1.0) {
98102
throw new SparkException(
99-
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
103+
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
100104
}
101105
}
102106
val prob = thetaMinusNegTheta.get.multiply(testData)
@@ -105,7 +109,7 @@ class NaiveBayesModel private[mllib] (
105109
labels(prob.argmax)
106110
case _ =>
107111
// This should never happen.
108-
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
112+
throw new UnknownError(s"Invalid modelType: $modelType.")
109113
}
110114
}
111115

@@ -230,16 +234,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
230234
s"($loadedClassName, $version). Supported:\n" +
231235
s" ($classNameV1_0, 1.0)")
232236
}
233-
assert(model.pi.size == numClasses,
237+
assert(model.pi.length == numClasses,
234238
s"NaiveBayesModel.load expected $numClasses classes," +
235-
s" but class priors vector pi had ${model.pi.size} elements")
236-
assert(model.theta.size == numClasses,
239+
s" but class priors vector pi had ${model.pi.length} elements")
240+
assert(model.theta.length == numClasses,
237241
s"NaiveBayesModel.load expected $numClasses classes," +
238-
s" but class conditionals array theta had ${model.theta.size} elements")
239-
assert(model.theta.forall(_.size == numFeatures),
242+
s" but class conditionals array theta had ${model.theta.length} elements")
243+
assert(model.theta.forall(_.length == numFeatures),
240244
s"NaiveBayesModel.load expected $numFeatures features," +
241245
s" but class conditionals array theta had elements of size:" +
242-
s" ${model.theta.map(_.size).mkString(",")}")
246+
s" ${model.theta.map(_.length).mkString(",")}")
243247
model
244248
}
245249
}
@@ -257,9 +261,11 @@ class NaiveBayes private (
257261
private var lambda: Double,
258262
private var modelType: String) extends Serializable with Logging {
259263

260-
def this(lambda: Double) = this(lambda, "Multinomial")
264+
import NaiveBayes.{Bernoulli, Multinomial}
261265

262-
def this() = this(1.0, "Multinomial")
266+
def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
267+
268+
def this() = this(1.0, NaiveBayes.Multinomial)
263269

264270
/** Set the smoothing parameter. Default: 1.0. */
265271
def setLambda(lambda: Double): NaiveBayes = {
@@ -272,12 +278,11 @@ class NaiveBayes private (
272278

273279
/**
274280
* Set the model type using a string (case-sensitive).
275-
* Supported options: "Multinomial" and "Bernoulli".
276-
* (default: Multinomial)
281+
* Supported options: "multinomial" (default) and "bernoulli".
277282
*/
278-
def setModelType(modelType:String): NaiveBayes = {
283+
def setModelType(modelType: String): NaiveBayes = {
279284
require(NaiveBayes.supportedModelTypes.contains(modelType),
280-
s"NaiveBayes was created with an unknown ModelType: $modelType")
285+
s"NaiveBayes was created with an unknown modelType: $modelType.")
281286
this.modelType = modelType
282287
this
283288
}
@@ -308,7 +313,7 @@ class NaiveBayes private (
308313
}
309314
if (!values.forall(v => v == 0.0 || v == 1.0)) {
310315
throw new SparkException(
311-
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
316+
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
312317
}
313318
}
314319

@@ -317,7 +322,7 @@ class NaiveBayes private (
317322
// TODO: similar to reduceByKeyLocally to save one stage.
318323
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
319324
createCombiner = (v: Vector) => {
320-
if (modelType == "Bernoulli") {
325+
if (modelType == Bernoulli) {
321326
requireZeroOneBernoulliValues(v)
322327
} else {
323328
requireNonnegativeValues(v)
@@ -352,11 +357,11 @@ class NaiveBayes private (
352357
labels(i) = label
353358
pi(i) = math.log(n + lambda) - piLogDenom
354359
val thetaLogDenom = modelType match {
355-
case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
356-
case "Bernoulli" => math.log(n + 2.0 * lambda)
360+
case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
361+
case Bernoulli => math.log(n + 2.0 * lambda)
357362
case _ =>
358363
// This should never happen.
359-
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
364+
throw new UnknownError(s"Invalid modelType: $modelType.")
360365
}
361366
var j = 0
362367
while (j < numFeatures) {
@@ -375,8 +380,14 @@ class NaiveBayes private (
375380
*/
376381
object NaiveBayes {
377382

383+
/** String name for multinomial model type. */
384+
private[classification] val Multinomial: String = "multinomial"
385+
386+
/** String name for Bernoulli model type. */
387+
private[classification] val Bernoulli: String = "bernoulli"
388+
378389
/* Set of modelTypes that NaiveBayes supports */
379-
private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")
390+
private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
380391

381392
/**
382393
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
@@ -406,7 +417,7 @@ object NaiveBayes {
406417
* @param lambda The smoothing parameter
407418
*/
408419
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
409-
new NaiveBayes(lambda, "Multinomial").run(input)
420+
new NaiveBayes(lambda, Multinomial).run(input)
410421
}
411422

412423
/**
@@ -429,7 +440,7 @@ object NaiveBayes {
429440
*/
430441
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
431442
require(supportedModelTypes.contains(modelType),
432-
s"NaiveBayes was created with an unknown ModelType: $modelType")
443+
s"NaiveBayes was created with an unknown modelType: $modelType.")
433444
new NaiveBayes(lambda, modelType).run(input)
434445
}
435446

mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public Vector call(LabeledPoint v) throws Exception {
108108
@Test
109109
public void testModelTypeSetters() {
110110
NaiveBayes nb = new NaiveBayes()
111-
.setModelType("Bernoulli")
112-
.setModelType("Multinomial");
111+
.setModelType("bernoulli")
112+
.setModelType("multinomial");
113113
}
114114
}

0 commit comments

Comments
 (0)