Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 8fe12bb

Browse files
committed
[SPARK-1406] Adjusted logistic regression export description and target categories
1 parent 03bc3a5 commit 8fe12bb

File tree

3 files changed

9 additions & 12 deletions

3 files changed

+9
-12
lines changed

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,7 @@ private[mllib] object ModelExportFactory {
5252
svm,
5353
"linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
5454
case logisticRegression: LogisticRegressionModel =>
55-
new LogisticRegressionPMMLModelExport(
56-
logisticRegression,
57-
"logistic regression: if predicted value > 0.5, "
58-
+ "the outcome is positive, or negative otherwise")
55+
new LogisticRegressionPMMLModelExport(logisticRegression, "logistic regression")
5956
case _ =>
6057
throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
6158
}

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ private[mllib] class LogisticRegressionPMMLModelExport(
5858
val miningSchema = new MiningSchema()
5959

6060
val regressionTableYES = new RegressionTable(model.intercept)
61-
.withTargetCategory("YES")
61+
.withTargetCategory("1")
6262

6363
val regressionTableNO = new RegressionTable(0.0)
64-
.withTargetCategory("NO")
64+
.withTargetCategory("0")
6565

6666
val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.CLASSIFICATION)
6767
.withModelName(description)
@@ -83,7 +83,7 @@ private[mllib] class LogisticRegressionPMMLModelExport(
8383
val targetField = FieldName.create("target");
8484
dataDictionary
8585
.withDataFields(
86-
new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)
86+
new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)
8787
)
8888
miningSchema
8989
.withMiningFields(new MiningField(targetField)

mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,18 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite{
3939
//assert that the PMML format is as expected
4040
assert(logisticModelExport.isInstanceOf[PMMLModelExport])
4141
var pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml()
42-
assert(pmml.getHeader().getDescription() === "logistic regression: if predicted value > 0.5, the outcome is positive, or negative otherwise")
42+
assert(pmml.getHeader().getDescription() === "logistic regression")
4343
//check that the number of fields match the weights size
4444
assert(pmml.getDataDictionary().getNumberOfFields() === logisticRegressionModel.weights.size + 1)
4545
//this verify that there is a model attached to the pmml object and the model is a regression one
46-
//it also verifies that the pmml model has a regression table (for target category YES) with the same number of predictors of the model weights
46+
//it also verifies that the pmml model has a regression table (for target category 1) with the same number of predictors of the model weights
4747
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
48-
.getRegressionTables().get(0).getTargetCategory() === "YES")
48+
.getRegressionTables().get(0).getTargetCategory() === "1")
4949
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
5050
.getRegressionTables().get(0).getNumericPredictors().size() === logisticRegressionModel.weights.size)
51-
//verify if there is a second table with target category NO and no predictors
51+
//verify if there is a second table with target category 0 and no predictors
5252
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
53-
.getRegressionTables().get(1).getTargetCategory() === "NO")
53+
.getRegressionTables().get(1).getTargetCategory() === "0")
5454
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
5555
.getRegressionTables().get(1).getNumericPredictors().size() === 0)
5656

0 commit comments

Comments
 (0)