Skip to content

Commit 7a5e0ec

Browse files
committed
[SPARK-1406] Binary classification for SVM and Logistic Regression
1 parent cfcb596 commit 7a5e0ec

File tree

6 files changed

+77
-48
lines changed

6 files changed

+77
-48
lines changed

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala renamed to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,46 @@ import scala.{Array => SArray}
2121

2222
import org.dmg.pmml._
2323

24-
import org.apache.spark.mllib.classification.LogisticRegressionModel
24+
import org.apache.spark.mllib.regression.GeneralizedLinearModel
2525

2626
/**
27-
* PMML Model Export for LogisticRegressionModel class
27+
* PMML Model Export for a GeneralizedLinearModel that is a binary classifier (logistic regression or linear SVM)
2828
*/
29-
private[mllib] class LogisticRegressionPMMLModelExport(
30-
model : LogisticRegressionModel,
31-
description : String)
29+
private[mllib] class BinaryClassificationPMMLModelExport(
30+
model : GeneralizedLinearModel,
31+
description : String,
32+
normalizationMethod : RegressionNormalizationMethodType,
33+
threshold: Double)
3234
extends PMMLModelExport {
3335

34-
populateLogisticRegressionPMML(model)
36+
populateBinaryClassificationPMML()
3537

3638
/**
37-
* Export the input LogisticRegressionModel model to PMML format
39+
* Export the input LogisticRegressionModel or SVMModel to PMML format.
3840
*/
39-
private def populateLogisticRegressionPMML(model : LogisticRegressionModel): Unit = {
41+
private def populateBinaryClassificationPMML(): Unit = {
4042
pmml.getHeader.setDescription(description)
4143

4244
if (model.weights.size > 0) {
4345
val fields = new SArray[FieldName](model.weights.size)
4446
val dataDictionary = new DataDictionary
4547
val miningSchema = new MiningSchema
4648
val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1")
47-
val regressionTableNO = new RegressionTable(0.0).withTargetCategory("0")
49+
var interceptNO = threshold
50+
if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) {
51+
if (threshold <= 0)
52+
interceptNO = 1000
53+
else if (threshold >= 1)
54+
interceptNO = -1000
55+
else
56+
interceptNO = -math.log(1/threshold -1)
57+
}
58+
val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0")
4859
val regressionModel = new RegressionModel()
4960
.withFunctionName(MiningFunctionType.CLASSIFICATION)
5061
.withMiningSchema(miningSchema)
5162
.withModelName(description)
52-
.withNormalizationMethod(RegressionNormalizationMethodType.LOGIT)
63+
.withNormalizationMethod(normalizationMethod)
5364
.withRegressionTables(regressionTableYES, regressionTableNO)
5465

5566
for (i <- 0 until model.weights.size) {

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.mllib.pmml.export
1919

20+
import org.dmg.pmml.RegressionNormalizationMethodType
21+
2022
import org.apache.spark.mllib.classification.LogisticRegressionModel
2123
import org.apache.spark.mllib.classification.SVMModel
2224
import org.apache.spark.mllib.clustering.KMeansModel
@@ -41,11 +43,14 @@ private[mllib] object PMMLModelExportFactory {
4143
case lasso: LassoModel =>
4244
new GeneralizedLinearPMMLModelExport(lasso, "lasso regression")
4345
case svm: SVMModel =>
44-
new GeneralizedLinearPMMLModelExport(svm,
45-
"linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
46+
new BinaryClassificationPMMLModelExport(
47+
svm, "linear SVM", RegressionNormalizationMethodType.NONE,
48+
svm.getThreshold.getOrElse(0.0))
4649
case logistic: LogisticRegressionModel =>
47-
if(logistic.numClasses == 2)
48-
new LogisticRegressionPMMLModelExport(logistic, "logistic regression")
50+
if (logistic.numClasses == 2)
51+
new BinaryClassificationPMMLModelExport(
52+
logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT,
53+
logistic.getThreshold.getOrElse(0.5))
4954
else
5055
throw new IllegalArgumentException(
5156
"PMML Export not supported for Multinomial Logistic Regression")

mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala renamed to mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@
1818
package org.apache.spark.mllib.pmml.export
1919

2020
import org.dmg.pmml.RegressionModel
21+
import org.dmg.pmml.RegressionNormalizationMethodType
2122
import org.scalatest.FunSuite
2223

2324
import org.apache.spark.mllib.classification.LogisticRegressionModel
25+
import org.apache.spark.mllib.classification.SVMModel
2426
import org.apache.spark.mllib.util.LinearDataGenerator
2527

26-
class LogisticRegressionPMMLModelExportSuite extends FunSuite {
28+
class BinaryClassificationPMMLModelExportSuite extends FunSuite {
2729

28-
test("LogisticRegressionPMMLModelExport generate PMML format") {
30+
test("logistic regression PMML export") {
2931
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
3032
val logisticRegressionModel =
3133
new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
@@ -48,5 +50,35 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite {
4850
// verify that there is a second table with target category 0 and no predictors
4951
assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
5052
assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
53+
// ensure logistic regression has normalization method set to LOGIT
54+
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT)
5155
}
56+
57+
test("linear SVM PMML export") {
58+
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
59+
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
60+
61+
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
62+
63+
// assert that the PMML format is as expected
64+
assert(svmModelExport.isInstanceOf[PMMLModelExport])
65+
val pmml = svmModelExport.getPmml
66+
assert(pmml.getHeader.getDescription
67+
=== "linear SVM")
68+
// check that the number of fields matches the weights size plus one
69+
assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1)
70+
// This verifies that there is a model attached to the pmml object and the model is a regression
71+
// one. It also verifies that the pmml model has a regression table (for target category 1)
72+
// with the same number of predictors as the model weights.
73+
val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
74+
assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1")
75+
assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
76+
=== svmModel.weights.size)
77+
// verify that there is a second table with target category 0 and no predictors
78+
assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
79+
assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
80+
// ensure linear SVM has normalization method set to NONE
81+
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE)
82+
}
83+
5284
}

mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ package org.apache.spark.mllib.pmml.export
2020
import org.dmg.pmml.RegressionModel
2121
import org.scalatest.FunSuite
2222

23-
import org.apache.spark.mllib.classification.SVMModel
2423
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
2524
import org.apache.spark.mllib.util.LinearDataGenerator
2625

2726
class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
2827

29-
test("linear regression pmml export") {
28+
test("linear regression PMML export") {
3029
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
3130
val linearRegressionModel =
3231
new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
@@ -45,7 +44,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
4544
=== linearRegressionModel.weights.size)
4645
}
4746

48-
test("ridge regression pmml export") {
47+
test("ridge regression PMML export") {
4948
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
5049
val ridgeRegressionModel =
5150
new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
@@ -64,7 +63,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
6463
=== ridgeRegressionModel.weights.size)
6564
}
6665

67-
test("lasso pmml export") {
66+
test("lasso PMML export") {
6867
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
6968
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
7069
val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
@@ -82,22 +81,4 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
8281
=== lassoModel.weights.size)
8382
}
8483

85-
test("svm pmml export") {
86-
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
87-
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
88-
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
89-
// assert that the PMML format is as expected
90-
assert(svmModelExport.isInstanceOf[PMMLModelExport])
91-
val pmml = svmModelExport.getPmml
92-
assert(pmml.getHeader.getDescription
93-
=== "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
94-
// check that the number of fields match the weights size
95-
assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1)
96-
// This verify that there is a model attached to the pmml object and the model is a regression
97-
// one. It also verifies that the pmml model has a regression table with the same number of
98-
// predictors of the model weights.
99-
val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
100-
assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
101-
=== svmModel.weights.size)
102-
}
10384
}

mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite {
4545
val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
4646
assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
4747
}
48+
4849
}

mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class PMMLModelExportFactorySuite extends FunSuite {
4040
}
4141

4242
test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
43-
+ "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") {
43+
+ "LinearRegressionModel, RidgeRegressionModel or LassoModel") {
4444
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
4545

4646
val linearRegressionModel =
@@ -56,22 +56,21 @@ class PMMLModelExportFactorySuite extends FunSuite {
5656
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
5757
val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
5858
assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
59-
60-
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
61-
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
62-
assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
6359
}
6460

65-
test("PMMLModelExportFactory create LogisticRegressionPMMLModelExport "
66-
+ "when passing a LogisticRegressionModel") {
61+
test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport "
62+
+ "when passing a LogisticRegressionModel or SVMModel") {
6763
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
64+
6865
val logisticRegressionModel =
6966
new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
70-
7167
val logisticRegressionModelExport =
7268
PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
73-
74-
assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport])
69+
assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
70+
71+
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
72+
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
73+
assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
7574
}
7675

7776
test("PMMLModelExportFactory throw IllegalArgumentException "

0 commit comments

Comments
 (0)