Skip to content

Commit 03bc3a5

Browse files
committed
added logistic regression
1 parent da2ec11 commit 03bc3a5

File tree

6 files changed

+196
-8
lines changed

6 files changed

+196
-8
lines changed

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717

1818
package org.apache.spark.mllib.export
1919

20+
import org.apache.spark.mllib.classification.LogisticRegressionModel
2021
import org.apache.spark.mllib.classification.SVMModel
2122
import org.apache.spark.mllib.clustering.KMeansModel
2223
import org.apache.spark.mllib.export.ModelExportType.ModelExportType
2324
import org.apache.spark.mllib.export.ModelExportType.PMML
2425
import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport
2526
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
27+
import org.apache.spark.mllib.export.pmml.LogisticRegressionPMMLModelExport
2628
import org.apache.spark.mllib.regression.LassoModel
2729
import org.apache.spark.mllib.regression.LinearRegressionModel
2830
import org.apache.spark.mllib.regression.RidgeRegressionModel
@@ -46,7 +48,14 @@ private[mllib] object ModelExportFactory {
4648
case lassoRegression: LassoModel =>
4749
new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression")
4850
case svm: SVMModel =>
49-
new GeneralizedLinearPMMLModelExport(svm, "linear SVM")
51+
new GeneralizedLinearPMMLModelExport(
52+
svm,
53+
"linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
54+
case logisticRegression: LogisticRegressionModel =>
55+
new LogisticRegressionPMMLModelExport(
56+
logisticRegression,
57+
"logistic regression: if predicted value > 0.5, "
58+
+ "the outcome is positive, or negative otherwise")
5059
case _ =>
5160
throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
5261
}

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import org.dmg.pmml.RegressionTable
3333
import org.apache.spark.mllib.regression.GeneralizedLinearModel
3434

3535
/**
36-
* PMML Model Export for GeneralizedLinear abstract class
36+
* PMML Model Export for GeneralizedLinearModel abstract class
3737
*/
3838
private[mllib] class GeneralizedLinearPMMLModelExport(
3939
model : GeneralizedLinearModel,
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.export.pmml
19+
20+
import org.dmg.pmml.DataDictionary
21+
import org.dmg.pmml.DataField
22+
import org.dmg.pmml.DataType
23+
import org.dmg.pmml.FieldName
24+
import org.dmg.pmml.FieldUsageType
25+
import org.dmg.pmml.MiningField
26+
import org.dmg.pmml.MiningFunctionType
27+
import org.dmg.pmml.MiningSchema
28+
import org.dmg.pmml.NumericPredictor
29+
import org.dmg.pmml.OpType
30+
import org.dmg.pmml.RegressionModel
31+
import org.dmg.pmml.RegressionTable
32+
import org.apache.spark.mllib.classification.LogisticRegressionModel
33+
import org.dmg.pmml.RegressionNormalizationMethodType
34+
35+
/**
36+
* PMML Model Export for LogisticRegressionModel class
37+
*/
38+
private[mllib] class LogisticRegressionPMMLModelExport(
39+
model : LogisticRegressionModel,
40+
description : String)
41+
extends PMMLModelExport{
42+
43+
/**
44+
* Export the input LogisticRegressionModel model to PMML format
45+
*/
46+
populateLogisticRegressionPMML(model)
47+
48+
private def populateLogisticRegressionPMML(model : LogisticRegressionModel): Unit = {
49+
50+
pmml.getHeader().setDescription(description)
51+
52+
if(model.weights.size > 0){
53+
54+
val fields = new Array[FieldName](model.weights.size)
55+
56+
val dataDictionary = new DataDictionary()
57+
58+
val miningSchema = new MiningSchema()
59+
60+
val regressionTableYES = new RegressionTable(model.intercept)
61+
.withTargetCategory("YES")
62+
63+
val regressionTableNO = new RegressionTable(0.0)
64+
.withTargetCategory("NO")
65+
66+
val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.CLASSIFICATION)
67+
.withModelName(description)
68+
.withNormalizationMethod(RegressionNormalizationMethodType.LOGIT)
69+
.withRegressionTables(regressionTableYES, regressionTableNO)
70+
71+
for ( i <- 0 until model.weights.size) {
72+
fields(i) = FieldName.create("field_" + i)
73+
dataDictionary
74+
.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
75+
miningSchema
76+
.withMiningFields(new MiningField(fields(i))
77+
.withUsageType(FieldUsageType.ACTIVE))
78+
regressionTableYES
79+
.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
80+
}
81+
82+
// add target field
83+
val targetField = FieldName.create("target");
84+
dataDictionary
85+
.withDataFields(
86+
new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)
87+
)
88+
miningSchema
89+
.withMiningFields(new MiningField(targetField)
90+
.withUsageType(FieldUsageType.TARGET))
91+
92+
dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())
93+
94+
pmml.setDataDictionary(dataDictionary)
95+
pmml.withModels(regressionModel)
96+
97+
}
98+
99+
}
100+
101+
}

mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.export
1919

2020
import org.scalatest.FunSuite
2121

22+
import org.apache.spark.mllib.classification.LogisticRegressionModel
2223
import org.apache.spark.mllib.classification.SVMModel
2324
import org.apache.spark.mllib.clustering.KMeansModel
2425
import org.apache.spark.mllib.linalg.Vectors
@@ -28,6 +29,7 @@ import org.apache.spark.mllib.regression.RidgeRegressionModel
2829
import org.apache.spark.mllib.util.LinearDataGenerator
2930
import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport
3031
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
32+
import org.apache.spark.mllib.export.pmml.LogisticRegressionPMMLModelExport
3133

3234
class ModelExportFactorySuite extends FunSuite{
3335

@@ -55,10 +57,10 @@ class ModelExportFactorySuite extends FunSuite{
5557
//arrange
5658
val linearInput = LinearDataGenerator.generateLinearInput(
5759
3.0, Array(10.0, 10.0), 1, 17)
58-
val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label);
59-
val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label);
60-
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label);
61-
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label);
60+
val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
61+
val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
62+
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
63+
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
6264

6365
//act
6466
val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML)
@@ -82,6 +84,20 @@ class ModelExportFactorySuite extends FunSuite{
8284

8385
}
8486

87+
test("ModelExportFactory create LogisticRegressionPMMLModelExport when passing a LogisticRegressionModel") {
88+
89+
//arrange
90+
val linearInput = LinearDataGenerator.generateLinearInput(
91+
3.0, Array(10.0, 10.0), 1, 17)
92+
val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label);
93+
94+
//act
95+
val logisticRegressionModelExport = ModelExportFactory.createModelExport(logisticRegressionModel, ModelExportType.PMML)
96+
//assert
97+
assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport])
98+
99+
}
100+
85101
test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
86102

87103
//arrange

mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
8484
//assert that the PMML format is as expected
8585
assert(svmModelExport.isInstanceOf[PMMLModelExport])
8686
pmml = svmModelExport.asInstanceOf[PMMLModelExport].getPmml()
87-
assert(pmml.getHeader().getDescription() === "linear SVM")
87+
assert(pmml.getHeader().getDescription() === "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
8888
//check that the number of fields match the weights size
8989
assert(pmml.getDataDictionary().getNumberOfFields() === svmModel.weights.size + 1)
9090
//this verify that there is a model attached to the pmml object and the model is a regression one
@@ -96,7 +96,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
9696
//ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml")
9797
//ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml")
9898
//ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml")
99-
//ModelExporter.toPMML(svmModel,"/tmp/svm.xml")
99+
//ModelExporter.toPMML(svmModel,"/tmp/linearsvm.xml")
100100

101101
}
102102

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.export.pmml
19+
20+
import org.dmg.pmml.RegressionModel
21+
import org.scalatest.FunSuite
22+
23+
import org.apache.spark.mllib.classification.LogisticRegressionModel
24+
import org.apache.spark.mllib.export.ModelExportFactory
25+
import org.apache.spark.mllib.export.ModelExportType
26+
import org.apache.spark.mllib.util.LinearDataGenerator
27+
28+
class LogisticRegressionPMMLModelExportSuite extends FunSuite{
29+
30+
test("LogisticRegressionPMMLModelExport generate PMML format") {
31+
32+
//arrange models to test
33+
val linearInput = LinearDataGenerator.generateLinearInput(
34+
3.0, Array(10.0, 10.0), 1, 17)
35+
val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label);
36+
37+
//act by exporting the model to the PMML format
38+
val logisticModelExport = ModelExportFactory.createModelExport(logisticRegressionModel, ModelExportType.PMML)
39+
//assert that the PMML format is as expected
40+
assert(logisticModelExport.isInstanceOf[PMMLModelExport])
41+
var pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml()
42+
assert(pmml.getHeader().getDescription() === "logistic regression: if predicted value > 0.5, the outcome is positive, or negative otherwise")
43+
//check that the number of fields match the weights size
44+
assert(pmml.getDataDictionary().getNumberOfFields() === logisticRegressionModel.weights.size + 1)
45+
//this verify that there is a model attached to the pmml object and the model is a regression one
46+
//it also verifies that the pmml model has a regression table (for target category YES) with the same number of predictors of the model weights
47+
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
48+
.getRegressionTables().get(0).getTargetCategory() === "YES")
49+
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
50+
.getRegressionTables().get(0).getNumericPredictors().size() === logisticRegressionModel.weights.size)
51+
//verify if there is a second table with target category NO and no predictors
52+
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
53+
.getRegressionTables().get(1).getTargetCategory() === "NO")
54+
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
55+
.getRegressionTables().get(1).getNumericPredictors().size() === 0)
56+
57+
//manual checking
58+
//ModelExporter.toPMML(logisticRegressionModel,"/tmp/logisticregression.xml")
59+
60+
}
61+
62+
}

0 commit comments

Comments
 (0)