@@ -25,7 +25,6 @@ import scala.util.Random
25
25
import org .scalatest .FunSuite
26
26
27
27
import org .apache .spark .SparkException
28
- import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
29
28
import org .apache .spark .mllib .linalg .Vectors
30
29
import org .apache .spark .mllib .regression .LabeledPoint
31
30
import org .apache .spark .mllib .util .{LocalClusterSparkContext , MLlibTestSparkContext }
@@ -49,7 +48,7 @@ object NaiveBayesSuite {
49
48
theta : Array [Array [Double ]], // CXD
50
49
nPoints : Int ,
51
50
seed : Int ,
52
- dataModel : NaiveBayesModels = NaiveBayesModels .Multinomial ,
51
+ dataModel : NaiveBayes . ModelType = NaiveBayes .Multinomial ,
53
52
sample : Int = 10 ): Seq [LabeledPoint ] = {
54
53
val D = theta(0 ).length
55
54
val rnd = new Random (seed)
@@ -60,10 +59,10 @@ object NaiveBayesSuite {
60
59
for (i <- 0 until nPoints) yield {
61
60
val y = calcLabel(rnd.nextDouble(), _pi)
62
61
val xi = dataModel match {
63
- case NaiveBayesModels .Bernoulli => Array .tabulate[Double ] (D ) {j =>
62
+ case NaiveBayes .Bernoulli => Array .tabulate[Double ] (D ) {j =>
64
63
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
65
64
}
66
- case NaiveBayesModels .Multinomial =>
65
+ case NaiveBayes .Multinomial =>
67
66
val mult = Multinomial (BDV (_theta(y)))
68
67
val emptyMap = (0 until D ).map(x => (x, 0.0 )).toMap
69
68
val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
@@ -78,7 +77,7 @@ object NaiveBayesSuite {
78
77
79
78
/** Binary labels, 3 features */
80
79
private val binaryModel = new NaiveBayesModel (labels = Array (0.0 , 1.0 ), pi = Array (0.2 , 0.8 ),
81
- theta = Array (Array (0.1 , 0.3 , 0.6 ), Array (0.2 , 0.4 , 0.4 )), NaiveBayesModels .Bernoulli )
80
+ theta = Array (Array (0.1 , 0.3 , 0.6 ), Array (0.2 , 0.4 , 0.4 )), NaiveBayes .Bernoulli )
82
81
}
83
82
84
83
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -121,7 +120,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
121
120
).map(_.map(math.log))
122
121
123
122
val testData = NaiveBayesSuite .generateNaiveBayesInput(
124
- pi, theta, nPoints, 42 , NaiveBayesModels .Multinomial )
123
+ pi, theta, nPoints, 42 , NaiveBayes .Multinomial )
125
124
val testRDD = sc.parallelize(testData, 2 )
126
125
testRDD.cache()
127
126
@@ -133,7 +132,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
133
132
theta,
134
133
nPoints,
135
134
17 ,
136
- NaiveBayesModels .Multinomial )
135
+ NaiveBayes .Multinomial )
137
136
val validationRDD = sc.parallelize(validationData, 2 )
138
137
139
138
// Test prediction on RDD.
@@ -158,7 +157,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
158
157
theta,
159
158
nPoints,
160
159
45 ,
161
- NaiveBayesModels .Bernoulli )
160
+ NaiveBayes .Bernoulli )
162
161
val testRDD = sc.parallelize(testData, 2 )
163
162
testRDD.cache()
164
163
@@ -170,7 +169,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
170
169
theta,
171
170
nPoints,
172
171
20 ,
173
- NaiveBayesModels .Bernoulli )
172
+ NaiveBayes .Bernoulli )
174
173
val validationRDD = sc.parallelize(validationData, 2 )
175
174
176
175
// Test prediction on RDD.
0 commit comments