Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit e016569

Browse files
committed
updated test suite with model type fix
1 parent 85f298f commit e016569

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import scala.util.Random
2525
import org.scalatest.FunSuite
2626

2727
import org.apache.spark.SparkException
28-
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
2928
import org.apache.spark.mllib.linalg.Vectors
3029
import org.apache.spark.mllib.regression.LabeledPoint
3130
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -49,7 +48,7 @@ object NaiveBayesSuite {
4948
theta: Array[Array[Double]], // CXD
5049
nPoints: Int,
5150
seed: Int,
52-
dataModel: NaiveBayesModels= NaiveBayesModels.Multinomial,
51+
dataModel: NaiveBayes.ModelType = NaiveBayes.Multinomial,
5352
sample: Int = 10): Seq[LabeledPoint] = {
5453
val D = theta(0).length
5554
val rnd = new Random(seed)
@@ -60,10 +59,10 @@ object NaiveBayesSuite {
6059
for (i <- 0 until nPoints) yield {
6160
val y = calcLabel(rnd.nextDouble(), _pi)
6261
val xi = dataModel match {
63-
case NaiveBayesModels.Bernoulli => Array.tabulate[Double] (D) {j =>
62+
case NaiveBayes.Bernoulli => Array.tabulate[Double] (D) {j =>
6463
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
6564
}
66-
case NaiveBayesModels.Multinomial =>
65+
case NaiveBayes.Multinomial =>
6766
val mult = Multinomial(BDV(_theta(y)))
6867
val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
6968
val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
@@ -78,7 +77,7 @@ object NaiveBayesSuite {
7877

7978
/** Binary labels, 3 features */
8079
private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
81-
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayesModels.Bernoulli)
80+
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli)
8281
}
8382

8483
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -121,7 +120,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
121120
).map(_.map(math.log))
122121

123122
val testData = NaiveBayesSuite.generateNaiveBayesInput(
124-
pi, theta, nPoints, 42, NaiveBayesModels.Multinomial)
123+
pi, theta, nPoints, 42, NaiveBayes.Multinomial)
125124
val testRDD = sc.parallelize(testData, 2)
126125
testRDD.cache()
127126

@@ -133,7 +132,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
133132
theta,
134133
nPoints,
135134
17,
136-
NaiveBayesModels.Multinomial)
135+
NaiveBayes.Multinomial)
137136
val validationRDD = sc.parallelize(validationData, 2)
138137

139138
// Test prediction on RDD.
@@ -158,7 +157,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
158157
theta,
159158
nPoints,
160159
45,
161-
NaiveBayesModels.Bernoulli)
160+
NaiveBayes.Bernoulli)
162161
val testRDD = sc.parallelize(testData, 2)
163162
testRDD.cache()
164163

@@ -170,7 +169,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
170169
theta,
171170
nPoints,
172171
20,
173-
NaiveBayesModels.Bernoulli)
172+
NaiveBayes.Bernoulli)
174173
val validationRDD = sc.parallelize(validationData, 2)
175174

176175
// Test prediction on RDD.

0 commit comments

Comments
 (0)