@@ -46,19 +46,19 @@ class NaiveBayesModel private[mllib] (
46
46
val labels : Array [Double ],
47
47
val pi : Array [Double ],
48
48
val theta : Array [Array [Double ]],
49
- val modelType : NaiveBayes . ModelType )
49
+ val modelType : String )
50
50
extends ClassificationModel with Serializable with Saveable {
51
51
52
52
def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
53
- this (labels, pi, theta, NaiveBayes .Multinomial )
53
+ this (labels, pi, theta, NaiveBayes .Multinomial .toString )
54
54
55
55
private val brzPi = new BDV [Double ](pi)
56
56
private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
57
57
58
58
// Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
59
59
// this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application
60
60
// of this condition in predict function
61
- private val (brzNegTheta, brzNegThetaSum) = modelType match {
61
+ private val (brzNegTheta, brzNegThetaSum) = NaiveBayes . ModelType .fromString( modelType) match {
62
62
case NaiveBayes .Multinomial => (None , None )
63
63
case NaiveBayes .Bernoulli =>
64
64
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
@@ -74,7 +74,7 @@ class NaiveBayesModel private[mllib] (
74
74
}
75
75
76
76
override def predict (testData : Vector ): Double = {
77
- modelType match {
77
+ NaiveBayes . ModelType .fromString( modelType) match {
78
78
case NaiveBayes .Multinomial =>
79
79
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
80
80
case NaiveBayes .Bernoulli =>
@@ -84,7 +84,7 @@ class NaiveBayesModel private[mllib] (
84
84
}
85
85
86
86
override def save (sc : SparkContext , path : String ): Unit = {
87
- val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType.toString )
87
+ val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType)
88
88
NaiveBayesModel .SaveLoadV1_0 .save(sc, path, data)
89
89
}
90
90
@@ -137,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
137
137
val labels = data.getAs[Seq [Double ]](0 ).toArray
138
138
val pi = data.getAs[Seq [Double ]](1 ).toArray
139
139
val theta = data.getAs[Seq [Seq [Double ]]](2 ).map(_.toArray).toArray
140
- val modelType = NaiveBayes .ModelType .fromString(data.getString(3 ))
140
+ val modelType = NaiveBayes .ModelType .fromString(data.getString(3 )).toString
141
141
new NaiveBayesModel (labels, pi, theta, modelType)
142
142
}
143
143
}
144
144
145
145
override def load (sc : SparkContext , path : String ): NaiveBayesModel = {
146
- def getModelType (metadata : JValue ): NaiveBayes . ModelType = {
146
+ def getModelType (metadata : JValue ): String = {
147
147
implicit val formats = DefaultFormats
148
- NaiveBayes .ModelType .fromString((metadata \ " modelType" ).extract[String ])
148
+ NaiveBayes .ModelType .fromString((metadata \ " modelType" ).extract[String ]).toString
149
149
}
150
150
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
151
151
val classNameV1_0 = SaveLoadV1_0 .thisClassName
@@ -265,7 +265,7 @@ class NaiveBayes private (
265
265
i += 1
266
266
}
267
267
268
- new NaiveBayesModel (labels, pi, theta, modelType)
268
+ new NaiveBayesModel (labels, pi, theta, modelType.toString )
269
269
}
270
270
}
271
271
0 commit comments