Skip to content

Commit 900b586

Browse files
committed
fixed model call so that uses type argument
1 parent ea09b28 commit 900b586

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,21 +310,21 @@ object NaiveBayes {
310310
*
311311
* The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]])
312312
* or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle
313-
* discrete count data and can be called by setting the model type to "Multinomial".
313+
* discrete count data and can be called by setting the model type to "multinomial".
314314
* For example, it can be used with word counts or TF_IDF vectors of documents.
315315
* The Bernoulli model fits presence or absence (0-1) counts. By making every vector a
316-
* 0-1 vector and setting the model type to "Bernoulli", the fits and predicts as
316+
* 0-1 vector and setting the model type to "bernoulli", the fits and predicts as
317317
* Bernoulli NB.
318318
*
319319
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
320320
* vector or a count vector.
321321
* @param lambda The smoothing parameter
322322
*
323323
* @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
324-
* Multinomial or Bernoulli
324+
* multinomial or bernoulli
325325
*/
326326
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
327-
new NaiveBayes(lambda, Multinomial).run(input)
327+
new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input)
328328
}
329329

330330

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
124124
val testRDD = sc.parallelize(testData, 2)
125125
testRDD.cache()
126126

127-
val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
127+
val model = NaiveBayes.train(testRDD, 1.0, "multinomial")
128128
validateModelFit(pi, theta, model)
129129

130130
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
@@ -161,7 +161,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
161161
val testRDD = sc.parallelize(testData, 2)
162162
testRDD.cache()
163163

164-
val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
164+
val model = NaiveBayes.train(testRDD, 1.0, "bernoulli")
165165
validateModelFit(pi, theta, model)
166166

167167
val validationData = NaiveBayesSuite.generateNaiveBayesInput(

0 commit comments

Comments
 (0)