Skip to content

Commit 63eff48

Browse files
committed
Added getLambda to Scala NaiveBayes
1 parent 7c7d2d5 commit 63eff48

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
166166
this
167167
}
168168

169+
/** Get the smoothing parameter. Default: 1.0. */
170+
def getLambda: Double = lambda
171+
169172
/**
170173
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
171174
*

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
8585
assert(numOfPredictions < input.length / 5)
8686
}
8787

88+
test("get, set params") {
89+
val nb = new NaiveBayes()
90+
nb.setLambda(2.0)
91+
assert(nb.getLambda == 2.0)
92+
nb.setLambda(3.0)
93+
assert(nb.getLambda == 3.0)
94+
}
95+
8896
test("Naive Bayes") {
8997
val nPoints = 10000
9098

0 commit comments

Comments
 (0)