
Commit b01df54 (parent: 4addc50)

allow to change or clear threshold in LR and SVM
add more comments to MLUtils.fastSquaredDistance

8 files changed: +108 −44 lines

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 24 additions & 12 deletions

@@ -110,8 +110,8 @@ class PythonMLLibAPI extends Serializable {
 
   private def trainRegressionModel(
       trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
-      dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
-      java.util.LinkedList[java.lang.Object] = {
+      dataBytesJRDD: JavaRDD[Array[Byte]],
+      initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
     val data = dataBytesJRDD.rdd.map(xBytes => {
       val x = deserializeDoubleVector(xBytes)
       LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
@@ -238,9 +238,9 @@ class PythonMLLibAPI extends Serializable {
   /**
    * Java stub for NaiveBayes.train()
    */
-  def trainNaiveBayes(dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double)
-    : java.util.List[java.lang.Object] =
-  {
+  def trainNaiveBayes(
+      dataBytesJRDD: JavaRDD[Array[Byte]],
+      lambda: Double): java.util.List[java.lang.Object] = {
     val data = dataBytesJRDD.rdd.map(xBytes => {
       val x = deserializeDoubleVector(xBytes)
       LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
@@ -256,9 +256,12 @@ class PythonMLLibAPI extends Serializable {
   /**
    * Java stub for Python mllib KMeans.train()
    */
-  def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
-      maxIterations: Int, runs: Int, initializationMode: String):
-      java.util.List[java.lang.Object] = {
+  def trainKMeansModel(
+      dataBytesJRDD: JavaRDD[Array[Byte]],
+      k: Int,
+      maxIterations: Int,
+      runs: Int,
+      initializationMode: String): java.util.List[java.lang.Object] = {
     val data = dataBytesJRDD.rdd.map(xBytes => Vectors.dense(deserializeDoubleVector(xBytes)))
     val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
     val ret = new java.util.LinkedList[java.lang.Object]()
@@ -311,8 +314,12 @@ class PythonMLLibAPI extends Serializable {
    * needs to be taken in the Python code to ensure it gets freed on exit; see
    * the Py4J documentation.
    */
-  def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
-      iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
+  def trainALSModel(
+      ratingsBytesJRDD: JavaRDD[Array[Byte]],
+      rank: Int,
+      iterations: Int,
+      lambda: Double,
+      blocks: Int): MatrixFactorizationModel = {
     val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
     ALS.train(ratings, rank, iterations, lambda, blocks)
   }
@@ -323,8 +330,13 @@ class PythonMLLibAPI extends Serializable {
    * Extra care needs to be taken in the Python code to ensure it gets freed on
    * exit; see the Py4J documentation.
    */
-  def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
-      iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
+  def trainImplicitALSModel(
+      ratingsBytesJRDD: JavaRDD[Array[Byte]],
+      rank: Int,
+      iterations: Int,
+      lambda: Double,
+      blocks: Int,
+      alpha: Double): MatrixFactorizationModel = {
     val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
     ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
   }

mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala

Lines changed: 7 additions & 3 deletions

@@ -17,23 +17,27 @@
 
 package org.apache.spark.mllib.classification
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.rdd.RDD
 
+/**
+ * Represents a classification model that predicts to which of a set of categories an example
+ * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc.
+ */
 trait ClassificationModel extends Serializable {
   /**
    * Predict values for the given data set using the model trained.
    *
    * @param testData RDD representing data points to be predicted
-   * @return RDD[Int] where each entry contains the corresponding prediction
+   * @return an RDD[Double] where each entry contains the corresponding prediction
    */
   def predict(testData: RDD[Vector]): RDD[Double]
 
   /**
    * Predict values for a single data point using the model trained.
    *
    * @param testData array representing a single data point
-   * @return Int prediction from the trained model
+   * @return predicted category from the trained model
    */
   def predict(testData: Vector): Double
 }
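
The trait fixes only the prediction contract; any model that maps feature vectors to double-encoded categories can implement it. As a hypothetical illustration (not part of this commit), a degenerate classifier that always predicts one category shows both required overloads:

  import org.apache.spark.mllib.linalg.Vector
  import org.apache.spark.rdd.RDD

  // Always predicts `category` (e.g. the majority class from training data).
  class ConstantClassificationModel(category: Double) extends ClassificationModel {
    // Bulk prediction delegates to the single-point overload.
    override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(v => predict(v))
    override def predict(testData: Vector): Double = category
  }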

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 34 additions & 15 deletions

@@ -17,15 +17,12 @@
 
 package org.apache.spark.mllib.classification
 
-import scala.math.round
-
 import org.apache.spark.SparkContext
-import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.mllib.util.DataValidators
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.rdd.RDD
 
 /**
  * Classification model trained using Logistic Regression.
@@ -36,13 +33,36 @@ import org.apache.spark.mllib.linalg.Vector
 class LogisticRegressionModel(
     override val weights: Vector,
     override val intercept: Double)
-  extends GeneralizedLinearModel(weights, intercept)
-  with ClassificationModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+
+  private var threshold: Option[Double] = Some(0.5)
+
+  /**
+   * Sets the threshold that separates positive predictions from negative predictions. An example
+   * with a prediction score greater than or equal to this threshold is identified as positive,
+   * and negative otherwise. The default value is 0.5.
+   */
+  def setThreshold(threshold: Double): this.type = {
+    this.threshold = Some(threshold)
+    this
+  }
+
+  /**
+   * Clears the threshold so that `predict` will output raw prediction scores.
+   */
+  def clearThreshold(): this.type = {
+    threshold = None
+    this
+  }
 
   override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
-    round(1.0 / (1.0 + math.exp(margin * -1)))
+    val score = 1.0 / (1.0 + math.exp(-margin))
+    threshold match {
+      case Some(t) => if (score < t) 0.0 else 1.0
+      case None => score
+    }
   }
 }
@@ -55,16 +75,15 @@ class LogisticRegressionWithSGD private (
     var numIterations: Int,
     var regParam: Double,
     var miniBatchFraction: Double)
-  extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
-  with Serializable {
+  extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
 
   val gradient = new LogisticGradient()
   val updater = new SimpleUpdater()
   override val optimizer = new GradientDescent(gradient, updater)
-    .setStepSize(stepSize)
-    .setNumIterations(numIterations)
-    .setRegParam(regParam)
-    .setMiniBatchFraction(miniBatchFraction)
+      .setStepSize(stepSize)
+      .setNumIterations(numIterations)
+      .setRegParam(regParam)
+      .setMiniBatchFraction(miniBatchFraction)
   override val validators = List(DataValidators.classificationLabels)
 
   /**
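
The new mutators let callers switch between hard 0/1 labels and raw sigmoid scores on the same model. A minimal usage sketch, assuming a hypothetical RDD[LabeledPoint] named training and a feature vector testPoint (neither is part of this commit):

  val model = LogisticRegressionWithSGD.train(training, 100)

  model.setThreshold(0.7)   // stricter cutoff: predict returns 1.0 only when score >= 0.7
  val hardLabel = model.predict(testPoint)

  model.clearThreshold()    // predict now returns the raw score 1 / (1 + exp(-margin))
  val score = model.predict(testPoint)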

mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala

Lines changed: 28 additions & 7 deletions

@@ -18,12 +18,11 @@
 package org.apache.spark.mllib.classification
 
 import org.apache.spark.SparkContext
-import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.mllib.util.DataValidators
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.rdd.RDD
 
 /**
  * Model for Support Vector Machines (SVMs).
@@ -34,13 +33,35 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 class SVMModel(
     override val weights: Vector,
     override val intercept: Double)
-  extends GeneralizedLinearModel(weights, intercept)
-  with ClassificationModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+
+  private var threshold: Option[Double] = Some(0.0)
+
+  /**
+   * Sets the threshold that separates positive predictions from negative predictions. An example
+   * with a prediction score greater than or equal to this threshold is identified as positive,
+   * and negative otherwise. The default value is 0.0.
+   */
+  def setThreshold(threshold: Double): this.type = {
+    this.threshold = Some(threshold)
+    this
+  }
+
+  /**
+   * Clears the threshold so that `predict` will output raw prediction scores.
+   */
+  def clearThreshold(): this.type = {
+    threshold = None
+    this
+  }
 
   override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
-    if (margin < 0) 0.0 else 1.0
+    threshold match {
+      case Some(t) => if (margin < t) 0.0 else 1.0
+      case None => margin
+    }
   }
 }
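
SVMModel gains the same pair of mutators; clearing the threshold exposes raw margins, which ranking metrics such as ROC/AUC need. A sketch under the same hypothetical training/test data assumption as above:

  val svm = SVMWithSGD.train(training, 200)

  svm.clearThreshold()      // predict now returns the raw margin w^T x + b
  val scoreAndLabels = test.map(p => (svm.predict(p.features), p.label))

  svm.setThreshold(0.0)     // back to hard 0/1 labels at the default cutoff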

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 1 addition & 2 deletions

@@ -42,8 +42,7 @@ class KMeans private (
     var runs: Int,
     var initializationMode: String,
     var initializationSteps: Int,
-    var epsilon: Double)
-  extends Serializable with Logging {
+    var epsilon: Double) extends Serializable with Logging {
   def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
 
   /** Set the number of clusters to create (k). Default: 2. */

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 12 additions & 0 deletions

@@ -150,6 +150,18 @@ object MLUtils {
     val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
     val normDiff = norm1 - norm2
     var sqDist = 0.0
+    /*
+     * The relative error is
+     * <pre>
+     * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 + 2 |a^T b| ) / ( \|a - b\|_2^2 ),
+     * </pre>
+     * which is bounded by
+     * <pre>
+     * 2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).
+     * </pre>
+     * The bound doesn't need the inner product, so we can use it as a sufficient condition to
+     * check quickly whether the inner product approach is accurate.
+     */
     val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
     if (precisionBound1 < precision) {
       sqDist = sumSquaredNorm - 2.0 * v1.dot(v2)
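
The second bound follows from |a^T b| <= \|a\|_2 \|b\|_2 <= ( \|a\|_2^2 + \|b\|_2^2 ) / 2 together with \|a - b\|_2 >= | \|a\|_2 - \|b\|_2 |. A self-contained sketch of the guarded fast path, written against Breeze directly; the object name, signature, and EPSILON computation here are illustrative assumptions, not the exact MLlib internals:

  import breeze.linalg.{DenseVector => BDV}

  object FastSquaredDistanceSketch {
    // Machine epsilon: halve until 1 + eps/2 is indistinguishable from 1.
    private val EPSILON: Double = {
      var eps = 1.0
      while ((1.0 + eps / 2.0) != 1.0) eps /= 2.0
      eps
    }

    /** Squared Euclidean distance given precomputed norms, guarded by the precision bound. */
    def fastSquaredDistance(
        v1: BDV[Double], norm1: Double,
        v2: BDV[Double], norm2: Double,
        precision: Double = 1e-6): Double = {
      val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
      val normDiff = norm1 - norm2
      // Cheap sufficient condition: if this bound already meets the target precision,
      // the inner-product formula is accurate enough and we skip forming v1 - v2.
      val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
      if (precisionBound1 < precision) {
        sumSquaredNorm - 2.0 * (v1 dot v2)  // \|a\|^2 + \|b\|^2 - 2 a^T b
      } else {
        val d = v1 - v2                     // exact path: \|a - b\|^2
        d dot d
      }
    }
  }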

mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java

Lines changed: 0 additions & 3 deletions

@@ -17,7 +17,6 @@
 
 package org.apache.spark.mllib.classification;
 
-
 import java.io.Serializable;
 import java.util.List;
 
@@ -28,7 +27,6 @@
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-
 import org.apache.spark.mllib.regression.LabeledPoint;
 
@@ -94,5 +92,4 @@ public void runSVMUsingStaticMethods() {
     int numAccurate = validatePrediction(validationData, model);
     Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
   }
-
 }

mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -150,10 +150,10 @@ class SVMSuite extends FunSuite with LocalSparkContext {
     }
 
     intercept[SparkException] {
-      val model = SVMWithSGD.train(testRDDInvalid, 100)
+      SVMWithSGD.train(testRDDInvalid, 100)
     }
 
     // Turning off data validation should not throw an exception
-    val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
+    new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
   }
 }
