Skip to content

Commit e2d925e

Browse files
committed
fixed nonserializable error that was causing naivebayes test failures
1 parent 2d0c1ba commit e2d925e

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,19 @@ class NaiveBayesModel private[mllib] (
4646
val labels: Array[Double],
4747
val pi: Array[Double],
4848
val theta: Array[Array[Double]],
49-
val modelType: NaiveBayes.ModelType)
49+
val modelType: String)
5050
extends ClassificationModel with Serializable with Saveable {
5151

5252
def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
53-
this(labels, pi, theta, NaiveBayes.Multinomial)
53+
this(labels, pi, theta, NaiveBayes.Multinomial.toString)
5454

5555
private val brzPi = new BDV[Double](pi)
5656
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
5757

5858
// Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
5959
// this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application
6060
// of this condition in predict function
61-
private val (brzNegTheta, brzNegThetaSum) = modelType match {
61+
private val (brzNegTheta, brzNegThetaSum) = NaiveBayes.ModelType.fromString(modelType) match {
6262
case NaiveBayes.Multinomial => (None, None)
6363
case NaiveBayes.Bernoulli =>
6464
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
@@ -74,7 +74,7 @@ class NaiveBayesModel private[mllib] (
7474
}
7575

7676
override def predict(testData: Vector): Double = {
77-
modelType match {
77+
NaiveBayes.ModelType.fromString(modelType) match {
7878
case NaiveBayes.Multinomial =>
7979
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
8080
case NaiveBayes.Bernoulli =>
@@ -84,7 +84,7 @@ class NaiveBayesModel private[mllib] (
8484
}
8585

8686
override def save(sc: SparkContext, path: String): Unit = {
87-
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
87+
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType)
8888
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
8989
}
9090

@@ -137,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
137137
val labels = data.getAs[Seq[Double]](0).toArray
138138
val pi = data.getAs[Seq[Double]](1).toArray
139139
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
140-
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
140+
val modelType = NaiveBayes.ModelType.fromString(data.getString(3)).toString
141141
new NaiveBayesModel(labels, pi, theta, modelType)
142142
}
143143
}
144144

145145
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
146-
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
146+
def getModelType(metadata: JValue): String = {
147147
implicit val formats = DefaultFormats
148-
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
148+
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String]).toString
149149
}
150150
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
151151
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -265,7 +265,7 @@ class NaiveBayes private (
265265
i += 1
266266
}
267267

268-
new NaiveBayesModel(labels, pi, theta, modelType)
268+
new NaiveBayesModel(labels, pi, theta, modelType.toString)
269269
}
270270
}
271271

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ object NaiveBayesSuite {
5252
sample: Int = 10): Seq[LabeledPoint] = {
5353
val D = theta(0).length
5454
val rnd = new Random(seed)
55-
55+
c
5656
val _pi = pi.map(math.pow(math.E, _))
5757
val _theta = theta.map(row => row.map(math.pow(math.E, _)))
5858

@@ -77,7 +77,7 @@ object NaiveBayesSuite {
7777

7878
/** Binary labels, 3 features */
7979
private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
80-
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli)
80+
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli.toString)
8181
}
8282

8383
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -111,7 +111,6 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
111111

112112
test("Naive Bayes Multinomial") {
113113
val nPoints = 1000
114-
115114
val pi = Array(0.5, 0.1, 0.4).map(math.log)
116115
val theta = Array(
117116
Array(0.70, 0.10, 0.10, 0.10), // label 0
@@ -120,7 +119,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
120119
).map(_.map(math.log))
121120

122121
val testData = NaiveBayesSuite.generateNaiveBayesInput(
123-
pi, theta, nPoints, 42, NaiveBayes.Multinomial)
122+
pi,
123+
theta,
124+
nPoints,
125+
42,
126+
NaiveBayes.Multinomial)
124127
val testRDD = sc.parallelize(testData, 2)
125128
testRDD.cache()
126129

@@ -144,7 +147,6 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
144147

145148
test("Naive Bayes Bernoulli") {
146149
val nPoints = 10000
147-
148150
val pi = Array(0.5, 0.3, 0.2).map(math.log)
149151
val theta = Array(
150152
Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0

0 commit comments

Comments
 (0)