Skip to content

Commit 1faf985

Browse files
committed
[SPARK-1406] Added target field to the regression model for completeness
Adjusted unit test to deal with this change
1 parent 3ae8ae5 commit 1faf985

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
7272
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
7373
}
7474

75+
//for completeness add target field
76+
val targetField = FieldName.create("target");
77+
dataDictionary
78+
.withDataFields(
79+
new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)
80+
)
81+
miningSchema
82+
.withMiningFields(new MiningField(targetField)
83+
.withUsageType(FieldUsageType.TARGET))
84+
7585
dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())
7686

7787
pmml.setDataDictionary(dataDictionary)

mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
4545
var pmml = linearModelExport.asInstanceOf[PMMLModelExport].getPmml()
4646
assert(pmml.getHeader().getDescription() === "linear regression")
4747
//check that the number of fields matches the weights size plus one for the target field
48-
assert(pmml.getDataDictionary().getNumberOfFields() === linearRegressionModel.weights.size)
48+
assert(pmml.getDataDictionary().getNumberOfFields() === linearRegressionModel.weights.size + 1)
4949
//this verifies that there is a model attached to the pmml object and that the model is a regression one
5050
//it also verifies that the pmml model has a regression table with the same number of predictors as the model has weights
5151
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
@@ -58,7 +58,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
5858
pmml = ridgeModelExport.asInstanceOf[PMMLModelExport].getPmml()
5959
assert(pmml.getHeader().getDescription() === "ridge regression")
6060
//check that the number of fields matches the weights size plus one for the target field
61-
assert(pmml.getDataDictionary().getNumberOfFields() === ridgeRegressionModel.weights.size)
61+
assert(pmml.getDataDictionary().getNumberOfFields() === ridgeRegressionModel.weights.size + 1)
6262
//this verifies that there is a model attached to the pmml object and that the model is a regression one
6363
//it also verifies that the pmml model has a regression table with the same number of predictors as the model has weights
6464
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
@@ -71,7 +71,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
7171
pmml = lassoModelExport.asInstanceOf[PMMLModelExport].getPmml()
7272
assert(pmml.getHeader().getDescription() === "lasso regression")
7373
//check that the number of fields matches the weights size plus one for the target field
74-
assert(pmml.getDataDictionary().getNumberOfFields() === lassoModel.weights.size)
74+
assert(pmml.getDataDictionary().getNumberOfFields() === lassoModel.weights.size + 1)
7575
//this verifies that there is a model attached to the pmml object and that the model is a regression one
7676
//it also verifies that the pmml model has a regression table with the same number of predictors as the model has weights
7777
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]

0 commit comments

Comments
 (0)