Skip to content

Commit 3730572

Browse files
committed
modified NB model type to be more Java-friendly
1 parent b61b5e2 commit 3730572

File tree

3 files changed

+117
-31
lines changed

3 files changed

+117
-31
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.mllib;
19+
20+
import org.apache.spark.SparkConf;
21+
import org.apache.spark.api.java.JavaRDD;
22+
import org.apache.spark.api.java.JavaSparkContext;
23+
import org.apache.spark.api.java.function.Function;
24+
import org.apache.spark.mllib.classification.NaiveBayes;
25+
import org.apache.spark.mllib.linalg.Vectors;
26+
import org.apache.spark.mllib.regression.LabeledPoint;
27+
28+
import java.util.regex.Pattern;
29+
30+
public final class JavaNaiveBayes {
31+
32+
static class ParsePoint implements Function<String, LabeledPoint> {
33+
private static final Pattern COMMA = Pattern.compile(",");
34+
private static final Pattern SPACE = Pattern.compile(" ");
35+
36+
@Override
37+
public LabeledPoint call(String line) {
38+
String[] parts = COMMA.split(line);
39+
double y = Double.parseDouble(parts[0]);
40+
String[] tok = SPACE.split(parts[1]);
41+
double[] x = new double[tok.length];
42+
for (int i = 0; i < tok.length; ++i) {
43+
x[i] = Double.parseDouble(tok[i]);
44+
}
45+
return new LabeledPoint(y, Vectors.dense(x));
46+
}
47+
}
48+
49+
public static void main(String[] args) {
50+
if (args.length != 3) {
51+
System.err.println("Usage: JavaLR <input_dir> <step_size> <niters>");
52+
System.exit(1);
53+
}
54+
SparkConf sparkConf = new SparkConf().setAppName("JavaLR");
55+
JavaSparkContext sc = new JavaSparkContext(sparkConf);
56+
JavaRDD<String> lines = sc.textFile(args[0]);
57+
JavaRDD<LabeledPoint> points = lines.map(new ParsePoint()).cache();
58+
double stepSize = Double.parseDouble(args[1]);
59+
int iterations = Integer.parseInt(args[2]);
60+
61+
// Example which compiles. (Don't actually include!)
62+
NaiveBayes nb = new NaiveBayes();
63+
nb.setModelType(NaiveBayes.Bernoulli());
64+
65+
sc.stop();
66+
}
67+
}

examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ object SparseNaiveBayes {
8989

9090
println(s"numTraining = $numTraining, numTest = $numTest.")
9191

92+
// Example which compiles. (Don't actually include!)
93+
val nb = new NaiveBayes()
94+
nb.setModelType(NaiveBayes.Bernoulli)
95+
9296
val model = new NaiveBayes().setLambda(params.lambda).run(training)
9397

9498
val prediction = model.predict(test.map(_.features))

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,11 @@ import org.json4s.{DefaultFormats, JValue}
2727
import org.apache.spark.{Logging, SparkContext, SparkException}
2828
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
2929
import org.apache.spark.mllib.regression.LabeledPoint
30-
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
3130
import org.apache.spark.mllib.util.{Loader, Saveable}
3231
import org.apache.spark.rdd.RDD
3332
import org.apache.spark.sql.{DataFrame, SQLContext}
3433

3534

36-
/**
37-
*
38-
*/
39-
object NaiveBayesModels extends Enumeration {
40-
type NaiveBayesModels = Value
41-
val Multinomial, Bernoulli = Value
42-
43-
implicit def toString(model: NaiveBayesModels): String = {
44-
model.toString
45-
}
46-
}
47-
4835
/**
4936
* Model for Naive Bayes Classifiers.
5037
*
@@ -60,17 +47,18 @@ class NaiveBayesModel private[mllib] (
6047
val labels: Array[Double],
6148
val pi: Array[Double],
6249
val theta: Array[Array[Double]],
63-
val modelType: NaiveBayesModels) extends ClassificationModel with Serializable with Saveable {
50+
val modelType: NaiveBayes.ModelType)
51+
extends ClassificationModel with Serializable with Saveable {
6452

6553
def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
66-
this(labels, pi, theta, NaiveBayesModels.Multinomial)
54+
this(labels, pi, theta, NaiveBayes.Multinomial)
6755

6856
private val brzPi = new BDV[Double](pi)
6957
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
7058

7159
private val brzNegTheta: Option[BDM[Double]] = modelType match {
72-
case NaiveBayesModels.Multinomial => None
73-
case NaiveBayesModels.Bernoulli =>
60+
case NaiveBayes.Multinomial => None
61+
case NaiveBayes.Bernoulli =>
7462
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
7563
Option(negTheta)
7664
}
@@ -85,17 +73,17 @@ class NaiveBayesModel private[mllib] (
8573

8674
override def predict(testData: Vector): Double = {
8775
modelType match {
88-
case NaiveBayesModels.Multinomial =>
76+
case NaiveBayes.Multinomial =>
8977
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
90-
case NaiveBayesModels.Bernoulli =>
78+
case NaiveBayes.Bernoulli =>
9179
labels (brzArgmax (brzPi +
9280
(brzTheta - brzNegTheta.get) * testData.toBreeze +
9381
brzSum(brzNegTheta.get, Axis._1)))
9482
}
9583
}
9684

9785
override def save(sc: SparkContext, path: String): Unit = {
98-
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType)
86+
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
9987
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
10088
}
10189

@@ -147,15 +135,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
147135
val labels = data.getAs[Seq[Double]](0).toArray
148136
val pi = data.getAs[Seq[Double]](1).toArray
149137
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
150-
val modelType: NaiveBayesModels = NaiveBayesModels.withName(data.getAs[String](3))
138+
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
151139
new NaiveBayesModel(labels, pi, theta, modelType)
152140
}
153141
}
154142

155143
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
156-
def getModelType(metadata: JValue): NaiveBayesModels = {
144+
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
157145
implicit val formats = DefaultFormats
158-
NaiveBayesModels.withName((metadata \ "modelType").extract[String])
146+
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
159147
}
160148
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
161149
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -191,12 +179,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
191179
* document classification. By making every vector a 0-1 vector, it can also be used as
192180
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
193181
*/
194-
class NaiveBayes private (private var lambda: Double,
195-
var modelType: NaiveBayesModels) extends Serializable with Logging {
182+
class NaiveBayes private (
183+
private var lambda: Double,
184+
var modelType: NaiveBayes.ModelType) extends Serializable with Logging {
196185

197-
def this(lambda: Double) = this(lambda, NaiveBayesModels.Multinomial)
186+
def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
198187

199-
def this() = this(1.0, NaiveBayesModels.Multinomial)
188+
def this() = this(1.0, NaiveBayes.Multinomial)
200189

201190
/** Set the smoothing parameter. Default: 1.0. */
202191
def setLambda(lambda: Double): NaiveBayes = {
@@ -205,7 +194,7 @@ class NaiveBayes private (private var lambda: Double,
205194
}
206195

207196
/** Set the model type. Default: Multinomial. */
208-
def setModelType(model: NaiveBayesModels): NaiveBayes = {
197+
def setModelType(model: NaiveBayes.ModelType): NaiveBayes = {
209198
this.modelType = model
210199
this
211200
}
@@ -262,8 +251,8 @@ class NaiveBayes private (private var lambda: Double,
262251
labels(i) = label
263252
pi(i) = math.log(n + lambda) - piLogDenom
264253
val thetaLogDenom = modelType match {
265-
case NaiveBayesModels.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
266-
case NaiveBayesModels.Bernoulli => math.log(n + 2.0 * lambda)
254+
case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
255+
case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda)
267256
}
268257
var j = 0
269258
while (j < numFeatures) {
@@ -330,6 +319,32 @@ object NaiveBayes {
330319
* Multinomial or Bernoulli
331320
*/
332321
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
333-
new NaiveBayes(lambda, NaiveBayesModels.withName(modelType)).run(input)
322+
new NaiveBayes(lambda, Multinomial).run(input)
334323
}
324+
325+
sealed abstract class ModelType
326+
327+
object MODELTYPE {
328+
final val MULTINOMIAL_STRING = "multinomial"
329+
final val BERNOULLI_STRING = "bernoulli"
330+
331+
def fromString(modelType: String): ModelType = modelType match {
332+
case MULTINOMIAL_STRING => Multinomial
333+
case BERNOULLI_STRING => Bernoulli
334+
case _ =>
335+
throw new IllegalArgumentException(s"Cannot recognize NaiveBayes ModelType: $modelType")
336+
}
337+
}
338+
339+
final val ModelType = MODELTYPE
340+
341+
final val Multinomial: ModelType = new ModelType {
342+
override def toString: String = ModelType.MULTINOMIAL_STRING
343+
}
344+
345+
final val Bernoulli: ModelType = new ModelType {
346+
override def toString: String = ModelType.BERNOULLI_STRING
347+
}
348+
335349
}
350+

0 commit comments

Comments
 (0)