Skip to content

Commit 78515ec

Browse files
committed
[SPARK-1406] added pmml export for LinearRegressionModel,
RidgeRegressionModel and LassoModel
1 parent e29dfb9 commit 78515ec

File tree

5 files changed

+226
-2
lines changed

5 files changed

+226
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ package org.apache.spark.mllib.export
2020
import org.apache.spark.mllib.clustering.KMeansModel
2121
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
2222
import org.apache.spark.mllib.export.ModelExportType._
23+
import org.apache.spark.mllib.regression.LinearRegressionModel
24+
import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport
25+
import org.apache.spark.mllib.regression.RidgeRegressionModel
26+
import org.apache.spark.mllib.regression.LassoModel
2327

2428
private[mllib] object ModelExportFactory {
2529

@@ -31,7 +35,14 @@ private[mllib] object ModelExportFactory {
3135
def createModelExport(model: Any, exportType: ModelExportType): ModelExport = {
3236
return exportType match{
3337
case PMML => model match{
34-
case kmeans: KMeansModel => new KMeansPMMLModelExport(kmeans)
38+
case kmeans: KMeansModel =>
39+
new KMeansPMMLModelExport(kmeans)
40+
case linearRegression: LinearRegressionModel =>
41+
new GeneralizedLinearPMMLModelExport(linearRegression, "linear regression")
42+
case ridgeRegression: RidgeRegressionModel =>
43+
new GeneralizedLinearPMMLModelExport(ridgeRegression, "ridge regression")
44+
case lassoRegression: LassoModel =>
45+
new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression")
3546
case _ =>
3647
throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
3748
}
mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.export.pmml
19+
20+
import org.dmg.pmml.Array.Type
21+
import org.dmg.pmml.Cluster
22+
import org.dmg.pmml.ClusteringField
23+
import org.dmg.pmml.ClusteringModel
24+
import org.dmg.pmml.ClusteringModel.ModelClass
25+
import org.dmg.pmml.CompareFunctionType
26+
import org.dmg.pmml.ComparisonMeasure
27+
import org.dmg.pmml.ComparisonMeasure.Kind
28+
import org.dmg.pmml.DataDictionary
29+
import org.dmg.pmml.DataField
30+
import org.dmg.pmml.DataType
31+
import org.dmg.pmml.FieldName
32+
import org.dmg.pmml.FieldUsageType
33+
import org.dmg.pmml.MiningField
34+
import org.dmg.pmml.MiningFunctionType
35+
import org.dmg.pmml.MiningSchema
36+
import org.dmg.pmml.OpType
37+
import org.dmg.pmml.SquaredEuclidean
38+
import org.apache.spark.mllib.clustering.KMeansModel
39+
import org.apache.spark.mllib.regression.LinearRegressionModel
40+
import org.apache.spark.mllib.regression.GeneralizedLinearModel
41+
import org.dmg.pmml.RegressionModel
42+
import org.dmg.pmml.RegressionTable
43+
import org.dmg.pmml.NumericPredictor
44+
45+
/**
46+
* PMML Model Export for GeneralizedLinear abstract class
47+
*/
48+
private[mllib] class GeneralizedLinearPMMLModelExport(
49+
model : GeneralizedLinearModel,
50+
description : String)
51+
extends PMMLModelExport{
52+
53+
/**
54+
* Export the input GeneralizedLinearModel model to PMML format
55+
*/
56+
populateGeneralizedLinearPMML(model)
57+
58+
private def populateGeneralizedLinearPMML(model : GeneralizedLinearModel): Unit = {
59+
60+
pmml.getHeader().setDescription(description)
61+
62+
if(model.weights.size > 0){
63+
64+
val fields = new Array[FieldName](model.weights.size)
65+
66+
val dataDictionary = new DataDictionary()
67+
68+
val miningSchema = new MiningSchema()
69+
70+
val regressionTable = new RegressionTable(model.intercept)
71+
72+
val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.REGRESSION)
73+
.withModelName(description).withRegressionTables(regressionTable)
74+
75+
for ( i <- 0 until model.weights.size) {
76+
fields(i) = FieldName.create("field_" + i)
77+
dataDictionary
78+
.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
79+
miningSchema
80+
.withMiningFields(new MiningField(fields(i))
81+
.withUsageType(FieldUsageType.ACTIVE))
82+
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
83+
}
84+
85+
dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())
86+
87+
pmml.setDataDictionary(dataDictionary)
88+
pmml.withModels(regressionModel)
89+
90+
}
91+
92+
}
93+
94+
}

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode
7171
MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length)
7272
.withModelName("k-means")
7373

74-
for ( i <- 0 to (clusterCenter.size - 1)) {
74+
for ( i <- 0 until clusterCenter.size) {
7575
fields(i) = FieldName.create("field_" + i)
7676
dataDictionary
7777
.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))

mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ import org.apache.spark.mllib.clustering.KMeansModel
2222
import org.apache.spark.mllib.linalg.Vectors
2323
import org.scalatest.FunSuite
2424
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
25+
import org.apache.spark.mllib.util.LinearDataGenerator
26+
import org.apache.spark.mllib.regression.LinearRegressionModel
27+
import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport
28+
import org.apache.spark.mllib.regression.LassoModel
29+
import org.apache.spark.mllib.regression.RidgeRegressionModel
2530

2631
class ModelExportFactorySuite extends FunSuite{
2732

@@ -43,6 +48,33 @@ class ModelExportFactorySuite extends FunSuite{
4348

4449
}
4550

51+
test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " +
  "LinearRegressionModel, RidgeRegressionModel or LassoModel") {

  // arrange: one generated labeled point provides the constructor arguments of every model
  val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
  val features = linearInput(0).features
  val label = linearInput(0).label
  val supportedModels = Seq[Any](
    new LinearRegressionModel(features, label),
    new RidgeRegressionModel(features, label),
    new LassoModel(features, label))

  supportedModels.foreach { model =>
    // act
    val modelExport = ModelExportFactory.createModelExport(model, ModelExportType.PMML)
    // assert: every supported linear model is handled by the same PMML exporter
    assert(modelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
  }

}
77+
4678
test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
4779

4880
//arrange
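mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala

Lines changed: 87 additions & 0 deletions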
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.export.pmml
19+
20+
import org.apache.spark.annotation.DeveloperApi
21+
import org.apache.spark.mllib.export.ModelExportFactory
22+
import org.apache.spark.mllib.export.ModelExportType
23+
import org.apache.spark.mllib.regression.LassoModel
24+
import org.apache.spark.mllib.regression.LinearRegressionModel
25+
import org.apache.spark.mllib.regression.RidgeRegressionModel
26+
import org.apache.spark.mllib.util.LinearDataGenerator
27+
import org.scalatest.FunSuite
28+
import org.dmg.pmml.RegressionModel
29+
30+
/**
 * Checks the PMML document produced for LinearRegressionModel, RidgeRegressionModel and
 * LassoModel when they are exported through ModelExportFactory.
 */
class GeneralizedLinearPMMLModelExportSuite extends FunSuite {

  test("GeneralizedLinearPMMLModelExport generate PMML format") {

    // arrange models to test: one generated labeled point provides the constructor arguments
    val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
    val features = linearInput(0).features
    val label = linearInput(0).label
    val linearRegressionModel = new LinearRegressionModel(features, label)
    val ridgeRegressionModel = new RidgeRegressionModel(features, label)
    val lassoModel = new LassoModel(features, label)

    // act and assert, once per supported model
    assertExportedAsRegression(linearRegressionModel, "linear regression")
    assertExportedAsRegression(ridgeRegressionModel, "ridge regression")
    assertExportedAsRegression(lassoModel, "lasso regression")

    // manual checking
    // ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml")
    // ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml")
    // ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml")

  }

  /**
   * Exports `model` to PMML through the factory and asserts that the result is as expected.
   *
   * @param model the model to export
   * @param expectedDescription the description the PMML header must carry
   */
  private def assertExportedAsRegression(
      model: org.apache.spark.mllib.regression.GeneralizedLinearModel,
      expectedDescription: String): Unit = {
    val numWeights = model.weights.size

    // act by exporting the model to the PMML format
    val modelExport = ModelExportFactory.createModelExport(model, ModelExportType.PMML)

    // assert that the PMML format is as expected
    val pmml = modelExport match {
      case pmmlExport: PMMLModelExport => pmmlExport.getPmml()
      case other => fail("Expected a PMMLModelExport but got " + other.getClass)
    }
    assert(pmml.getHeader().getDescription() === expectedDescription)

    // check that the number of fields matches the weights size
    assert(pmml.getDataDictionary().getNumberOfFields() === numWeights)

    // verify that a model is attached to the pmml object, that it is a regression model, and
    // that its first regression table has as many numeric predictors as the model has weights
    pmml.getModels().get(0) match {
      case regression: RegressionModel =>
        val predictors = regression.getRegressionTables().get(0).getNumericPredictors()
        assert(predictors.size() === numWeights)
      case other =>
        fail("Expected a PMML RegressionModel but got " + other.getClass)
    }
  }

}

0 commit comments

Comments
 (0)