Skip to content

Commit 7622b0c

Browse files
committed
added comments and fixed style as per rb
1 parent b61b5e2 commit 7622b0c

File tree

2 files changed

+55
-29
lines changed

2 files changed

+55
-29
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,17 @@ import org.json4s.jackson.JsonMethods._
2525
import org.json4s.{DefaultFormats, JValue}
2626

2727
import org.apache.spark.{Logging, SparkContext, SparkException}
28+
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
2829
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
2930
import org.apache.spark.mllib.regression.LabeledPoint
30-
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
3131
import org.apache.spark.mllib.util.{Loader, Saveable}
3232
import org.apache.spark.rdd.RDD
3333
import org.apache.spark.sql.{DataFrame, SQLContext}
3434

3535

3636
/**
37-
*
37+
* Model types supported in Naive Bayes:
38+
* multinomial and Bernoulli currently supported
3839
*/
3940
object NaiveBayesModels extends Enumeration {
4041
type NaiveBayesModels = Value
@@ -45,6 +46,8 @@ object NaiveBayesModels extends Enumeration {
4546
}
4647
}
4748

49+
50+
4851
/**
4952
* Model for Naive Bayes Classifiers.
5053
*
@@ -55,7 +58,6 @@ object NaiveBayesModels extends Enumeration {
5558
* @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
5659
* Multinomial or Bernoulli
5760
*/
58-
5961
class NaiveBayesModel private[mllib] (
6062
val labels: Array[Double],
6163
val pi: Array[Double],
@@ -68,11 +70,14 @@ class NaiveBayesModel private[mllib] (
6870
private val brzPi = new BDV[Double](pi)
6971
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
7072

71-
private val brzNegTheta: Option[BDM[Double]] = modelType match {
72-
case NaiveBayesModels.Multinomial => None
73+
//Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
74+
//precomputing log(1.0 - exp(theta)) and its sum for linear algebra application
75+
//of this condition in predict function
76+
private val (brzNegTheta, brzNegThetaSum) = modelType match {
77+
case NaiveBayesModels.Multinomial => (None, None)
7378
case NaiveBayesModels.Bernoulli =>
7479
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
75-
Option(negTheta)
80+
(Option(negTheta), Option(brzSum(brzNegTheta, Axis._1)))
7681
}
7782

7883
override def predict(testData: RDD[Vector]): RDD[Double] = {
@@ -89,8 +94,7 @@ class NaiveBayesModel private[mllib] (
8994
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
9095
case NaiveBayesModels.Bernoulli =>
9196
labels (brzArgmax (brzPi +
92-
(brzTheta - brzNegTheta.get) * testData.toBreeze +
93-
brzSum(brzNegTheta.get, Axis._1)))
97+
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
9498
}
9599
}
96100

@@ -114,10 +118,11 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
114118
def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"
115119

116120
/** Model data for model import/export */
117-
case class Data(labels: Array[Double],
118-
pi: Array[Double],
119-
theta: Array[Array[Double]],
120-
modelType: String)
121+
case class Data(
122+
labels: Array[Double],
123+
pi: Array[Double],
124+
theta: Array[Array[Double]],
125+
modelType: String)
121126

122127
def save(sc: SparkContext, path: String, data: Data): Unit = {
123128
val sqlContext = new SQLContext(sc)
@@ -192,7 +197,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
192197
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
193198
*/
194199
class NaiveBayes private (private var lambda: Double,
195-
var modelType: NaiveBayesModels) extends Serializable with Logging {
200+
private var modelType: NaiveBayesModels) extends Serializable with Logging {
196201

197202
def this(lambda: Double) = this(lambda, NaiveBayesModels.Multinomial)
198203

@@ -284,7 +289,7 @@ object NaiveBayes {
284289
/**
285290
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
286291
*
287-
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
292+
* This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
288293
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
289294
* document classification.
290295
*
@@ -300,7 +305,7 @@ object NaiveBayes {
300305
/**
301306
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
302307
*
303-
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
308+
* This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
304309
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
305310
* document classification.
306311
*
@@ -316,11 +321,13 @@ object NaiveBayes {
316321
/**
317322
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
318323
*
319-
* This is by default the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle
320-
* all kinds of discrete data. For example, by converting documents into TF-IDF vectors,
321-
* it can be used for document classification. By making every vector a 0-1 vector and
322-
* setting the model type to NaiveBayesModels.Bernoulli, it fits and predicts as
323-
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
324+
* The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]])
325+
* or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle
326+
* discrete count data and can be called by setting the model type to "Multinomial".
327+
* For example, it can be used with word counts or TF_IDF vectors of documents.
328+
* The Bernoulli model fits presence or absence (0-1) counts. By making every vector a
329+
* 0-1 vector and setting the model type to "Bernoulli", the fits and predicts as
330+
* Bernoulli NB.
324331
*
325332
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
326333
* vector or a count vector.

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ package org.apache.spark.mllib.classification
1919

2020
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
2121
import breeze.stats.distributions.Multinomial
22-
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
2322

2423
import scala.util.Random
2524

2625
import org.scalatest.FunSuite
2726

2827
import org.apache.spark.SparkException
28+
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
2929
import org.apache.spark.mllib.linalg.Vectors
3030
import org.apache.spark.mllib.regression.LabeledPoint
3131
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -49,7 +49,7 @@ object NaiveBayesSuite {
4949
theta: Array[Array[Double]], // CXD
5050
nPoints: Int,
5151
seed: Int,
52-
dataModel: NaiveBayesModels = NaiveBayesModels.Multinomial,
52+
dataModel: NaiveBayesModels= NaiveBayesModels.Multinomial,
5353
sample: Int = 10): Seq[LabeledPoint] = {
5454
val D = theta(0).length
5555
val rnd = new Random(seed)
@@ -92,7 +92,10 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
9292
assert(numOfPredictions < input.length / 5)
9393
}
9494

95-
def validateModelFit(piData: Array[Double], thetaData: Array[Array[Double]], model: NaiveBayesModel) = {
95+
def validateModelFit(
96+
piData: Array[Double],
97+
thetaData: Array[Array[Double]],
98+
model: NaiveBayesModel) = {
9699
def closeFit(d1: Double, d2: Double, precision: Double): Boolean = {
97100
(d1 - d2).abs <= precision
98101
}
@@ -117,14 +120,20 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
117120
Array(0.10, 0.10, 0.70, 0.10) // label 2
118121
).map(_.map(math.log))
119122

120-
val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, NaiveBayesModels.Multinomial)
123+
val testData = NaiveBayesSuite.generateNaiveBayesInput(
124+
pi, theta, nPoints, 42, NaiveBayesModels.Multinomial)
121125
val testRDD = sc.parallelize(testData, 2)
122126
testRDD.cache()
123127

124128
val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
125129
validateModelFit(pi, theta, model)
126130

127-
val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17, NaiveBayesModels.Multinomial)
131+
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
132+
pi,
133+
theta,
134+
nPoints,
135+
17,
136+
NaiveBayesModels.Multinomial)
128137
val validationRDD = sc.parallelize(validationData, 2)
129138

130139
// Test prediction on RDD.
@@ -144,14 +153,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
144153
Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2
145154
).map(_.map(math.log))
146155

147-
val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 45, NaiveBayesModels.Bernoulli)
156+
val testData = NaiveBayesSuite.generateNaiveBayesInput(
157+
pi,
158+
theta,
159+
nPoints,
160+
45,
161+
NaiveBayesModels.Bernoulli)
148162
val testRDD = sc.parallelize(testData, 2)
149163
testRDD.cache()
150164

151165
val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
152166
validateModelFit(pi, theta, model)
153167

154-
val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 20, NaiveBayesModels.Bernoulli)
168+
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
169+
pi,
170+
theta,
171+
nPoints,
172+
20,
173+
NaiveBayesModels.Bernoulli)
155174
val validationRDD = sc.parallelize(validationData, 2)
156175

157176
// Test prediction on RDD.
@@ -218,8 +237,8 @@ class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
218237
LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble())))
219238
}
220239
}
221-
// If we serialize data directly in the task closure, the size of the serialized task would be
222-
// greater than 1MB and hence Spark would throw an error.
240+
// If we serialize data directly in the task closure, the size of the serialized task
241+
// would be greater than 1MB and hence Spark would throw an error.
223242
val model = NaiveBayes.train(examples)
224243
val predictions = model.predict(examples.map(_.features))
225244
}

0 commit comments

Comments
 (0)