Skip to content

Commit aaca19c

Browse files
selvinsourcejeanlyn
authored andcommitted
[SPARK-1406] Mllib pmml model export
See PDF attached to the JIRA issue 1406. The contribution is my original work and I license the work to the project under the project's open source license. Author: Vincenzo Selvaggio <[email protected]> Author: Xiangrui Meng <[email protected]> Author: selvinsource <[email protected]> Closes apache#3062 from selvinsource/mllib_pmml_model_export_SPARK-1406 and squashes the following commits: 852aac6 [Vincenzo Selvaggio] [SPARK-1406] Update JPMML version to 1.1.15 in LICENSE file 085cf42 [Vincenzo Selvaggio] [SPARK-1406] Added Double Min and Max Fixed scala style 30165c4 [Vincenzo Selvaggio] [SPARK-1406] Fixed extreme cases for logit 7a5e0ec [Vincenzo Selvaggio] [SPARK-1406] Binary classification for SVM and Logistic Regression cfcb596 [Vincenzo Selvaggio] [SPARK-1406] Throw IllegalArgumentException when exporting a multinomial logistic regression 25dce33 [Vincenzo Selvaggio] [SPARK-1406] Update code to latest pmml model dea98ca [Vincenzo Selvaggio] [SPARK-1406] Exclude transitive dependency for pmml model 66b7c12 [Vincenzo Selvaggio] [SPARK-1406] Updated pmml model lib to 1.1.15, latest Java 6 compatible a0a55f7 [Vincenzo Selvaggio] Merge pull request #2 from mengxr/SPARK-1406 3c22f79 [Xiangrui Meng] more code style e2313df [Vincenzo Selvaggio] Merge pull request #1 from mengxr/SPARK-1406 472d757 [Xiangrui Meng] fix code style 1676e15 [Vincenzo Selvaggio] fixed scala issue e2ffae8 [Vincenzo Selvaggio] fixed scala style b8823b0 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 b25bbf7 [Vincenzo Selvaggio] [SPARK-1406] Added export of pmml to distributed file system using the spark context 7a949d0 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style f46c75c [Vincenzo Selvaggio] [SPARK-1406] Added PMMLExportable to supported models 7b33b4e [Vincenzo Selvaggio] [SPARK-1406] Added a PMMLExportable interface Restructured code in a new package mllib.pmml Supported models implements the new PMMLExportable interface: LogisticRegression, SVM, KMeansModel, LinearRegression, RidgeRegression, Lasso d559ec5 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 8fe12bb [Vincenzo Selvaggio] [SPARK-1406] Adjusted logistic regression export description and target categories 03bc3a5 [Vincenzo Selvaggio] added logistic regression da2ec11 [Vincenzo Selvaggio] [SPARK-1406] added linear SVM PMML export 82f2131 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 19adf29 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style 1faf985 [Vincenzo Selvaggio] [SPARK-1406] Added target field to the regression model for completeness Adjusted unit test to deal with this change 3ae8ae5 [Vincenzo Selvaggio] [SPARK-1406] Adjusted imported order according to the guidelines c67ce81 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 78515ec [Vincenzo Selvaggio] [SPARK-1406] added pmml export for LinearRegressionModel, RidgeRegressionModel and LassoModel e29dfb9 [Vincenzo Selvaggio] removed version, by default is set to 4.2 (latest from jpmml) removed copyright ae8b993 [Vincenzo Selvaggio] updated some commented tests to use the new ModelExporter object reordered the imports df8a89e [Vincenzo Selvaggio] added pmml version to pmml model changed the copyright to spark a1b4dc3 [Vincenzo Selvaggio] updated imports 834ca44 [Vincenzo Selvaggio] reordered the import accordingly to the guidelines 349a76b [Vincenzo Selvaggio] new helper object to serialize the models to pmml format c3ef9b8 [Vincenzo Selvaggio] set it to private 6357b98 [Vincenzo Selvaggio] set it to private e1eb251 [Vincenzo Selvaggio] removed serialization part, this will be part of the ModelExporter helper object aba5ee1 [Vincenzo Selvaggio] fixed cluster export cd6c07c [Vincenzo Selvaggio] fixed scala style to run tests f75b988 [Vincenzo Selvaggio] Merge remote-tracking branch 'origin/master' into mllib_pmml_model_export_SPARK-1406 07a29bf [selvinsource] Update LICENSE 8841439 [Vincenzo Selvaggio] adjust scala style in order to compile 1433b11 [Vincenzo Selvaggio] complete suite tests 8e71b8d [Vincenzo Selvaggio] kmeans pmml export implementation 9bc494f [Vincenzo Selvaggio] added scala suite tests added saveLocalFile to ModelExport trait 226e184 [Vincenzo Selvaggio] added javadoc and export model type in case there is a need to support other types of export (not just PMML) a0e3679 [Vincenzo Selvaggio] export and pmml export traits kmeans test implementation
1 parent f787b18 commit aaca19c

18 files changed

+774
-6
lines changed

LICENSE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@ BSD-style licenses
814814
The following components are provided under a BSD-style license. See project link for details.
815815

816816
(BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
817+
(BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model)
817818
(BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/)
818819
(BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
819820
(BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)

mllib/pom.xml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,21 @@
109109
<type>test-jar</type>
110110
<scope>test</scope>
111111
</dependency>
112+
<dependency>
113+
<groupId>org.jpmml</groupId>
114+
<artifactId>pmml-model</artifactId>
115+
<version>1.1.15</version>
116+
<exclusions>
117+
<exclusion>
118+
<groupId>com.sun.xml.fastinfoset</groupId>
119+
<artifactId>FastInfoset</artifactId>
120+
</exclusion>
121+
<exclusion>
122+
<groupId>com.sun.istack</groupId>
123+
<artifactId>istack-commons-runtime</artifactId>
124+
</exclusion>
125+
</exclusions>
126+
</dependency>
112127
</dependencies>
113128
<profiles>
114129
<profile>

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
2323
import org.apache.spark.mllib.linalg.BLAS.dot
2424
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
2525
import org.apache.spark.mllib.optimization._
26+
import org.apache.spark.mllib.pmml.PMMLExportable
2627
import org.apache.spark.mllib.regression._
2728
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
2829
import org.apache.spark.rdd.RDD
@@ -46,7 +47,7 @@ class LogisticRegressionModel (
4647
val numFeatures: Int,
4748
val numClasses: Int)
4849
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
49-
with Saveable {
50+
with Saveable with PMMLExportable {
5051

5152
if (numClasses == 2) {
5253
require(weights.size == numFeatures,

mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.annotation.Experimental
2222
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
2323
import org.apache.spark.mllib.linalg.Vector
2424
import org.apache.spark.mllib.optimization._
25+
import org.apache.spark.mllib.pmml.PMMLExportable
2526
import org.apache.spark.mllib.regression._
2627
import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
2728
import org.apache.spark.rdd.RDD
@@ -36,7 +37,7 @@ class SVMModel (
3637
override val weights: Vector,
3738
override val intercept: Double)
3839
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
39-
with Saveable {
40+
with Saveable with PMMLExportable {
4041

4142
private var threshold: Option[Double] = Some(0.0)
4243

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._
2525

2626
import org.apache.spark.api.java.JavaRDD
2727
import org.apache.spark.mllib.linalg.Vector
28+
import org.apache.spark.mllib.pmml.PMMLExportable
2829
import org.apache.spark.mllib.util.{Loader, Saveable}
2930
import org.apache.spark.rdd.RDD
3031
import org.apache.spark.SparkContext
@@ -34,7 +35,8 @@ import org.apache.spark.sql.Row
3435
/**
3536
* A clustering model for K-means. Each point belongs to the cluster with the closest center.
3637
*/
37-
class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
38+
class KMeansModel (
39+
val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable {
3840

3941
/** A Java-friendly constructor that takes an Iterable of Vectors. */
4042
def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.pmml
19+
20+
import java.io.{File, OutputStream, StringWriter}
21+
import javax.xml.transform.stream.StreamResult
22+
23+
import org.jpmml.model.JAXBUtil
24+
25+
import org.apache.spark.SparkContext
26+
import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
27+
28+
/**
29+
* Export model to the PMML format
30+
* Predictive Model Markup Language (PMML) is an XML-based file format
31+
* developed by the Data Mining Group (www.dmg.org).
32+
*/
33+
trait PMMLExportable {
34+
35+
/**
36+
* Export the model to the stream result in PMML format
37+
*/
38+
private def toPMML(streamResult: StreamResult): Unit = {
39+
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
40+
JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
41+
}
42+
43+
/**
44+
* Export the model to a local file in PMML format
45+
*/
46+
def toPMML(localPath: String): Unit = {
47+
toPMML(new StreamResult(new File(localPath)))
48+
}
49+
50+
/**
51+
* Export the model to a directory on a distributed file system in PMML format
52+
*/
53+
def toPMML(sc: SparkContext, path: String): Unit = {
54+
val pmml = toPMML()
55+
sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
56+
}
57+
58+
/**
59+
* Export the model to the OutputStream in PMML format
60+
*/
61+
def toPMML(outputStream: OutputStream): Unit = {
62+
toPMML(new StreamResult(outputStream))
63+
}
64+
65+
/**
66+
* Export the model to a String in PMML format
67+
*/
68+
def toPMML(): String = {
69+
val writer = new StringWriter
70+
toPMML(new StreamResult(writer))
71+
writer.toString
72+
}
73+
74+
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.pmml.export
19+
20+
import scala.{Array => SArray}
21+
22+
import org.dmg.pmml._
23+
24+
import org.apache.spark.mllib.regression.GeneralizedLinearModel
25+
26+
/**
27+
* PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
28+
*/
29+
private[mllib] class BinaryClassificationPMMLModelExport(
30+
model : GeneralizedLinearModel,
31+
description : String,
32+
normalizationMethod : RegressionNormalizationMethodType,
33+
threshold: Double)
34+
extends PMMLModelExport {
35+
36+
populateBinaryClassificationPMML()
37+
38+
/**
39+
* Export the input LogisticRegressionModel or SVMModel to PMML format.
40+
*/
41+
private def populateBinaryClassificationPMML(): Unit = {
42+
pmml.getHeader.setDescription(description)
43+
44+
if (model.weights.size > 0) {
45+
val fields = new SArray[FieldName](model.weights.size)
46+
val dataDictionary = new DataDictionary
47+
val miningSchema = new MiningSchema
48+
val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1")
49+
var interceptNO = threshold
50+
if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) {
51+
if (threshold <= 0) {
52+
interceptNO = Double.MinValue
53+
} else if (threshold >= 1) {
54+
interceptNO = Double.MaxValue
55+
} else {
56+
interceptNO = -math.log(1 / threshold - 1)
57+
}
58+
}
59+
val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0")
60+
val regressionModel = new RegressionModel()
61+
.withFunctionName(MiningFunctionType.CLASSIFICATION)
62+
.withMiningSchema(miningSchema)
63+
.withModelName(description)
64+
.withNormalizationMethod(normalizationMethod)
65+
.withRegressionTables(regressionTableYES, regressionTableNO)
66+
67+
for (i <- 0 until model.weights.size) {
68+
fields(i) = FieldName.create("field_" + i)
69+
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
70+
miningSchema
71+
.withMiningFields(new MiningField(fields(i))
72+
.withUsageType(FieldUsageType.ACTIVE))
73+
regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
74+
}
75+
76+
// add target field
77+
val targetField = FieldName.create("target")
78+
dataDictionary
79+
.withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING))
80+
miningSchema
81+
.withMiningFields(new MiningField(targetField)
82+
.withUsageType(FieldUsageType.TARGET))
83+
84+
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
85+
86+
pmml.setDataDictionary(dataDictionary)
87+
pmml.withModels(regressionModel)
88+
}
89+
}
90+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.pmml.export
19+
20+
import scala.{Array => SArray}
21+
22+
import org.dmg.pmml._
23+
24+
import org.apache.spark.mllib.regression.GeneralizedLinearModel
25+
26+
/**
27+
* PMML Model Export for GeneralizedLinearModel abstract class
28+
*/
29+
private[mllib] class GeneralizedLinearPMMLModelExport(
30+
model: GeneralizedLinearModel,
31+
description: String)
32+
extends PMMLModelExport {
33+
34+
populateGeneralizedLinearPMML(model)
35+
36+
/**
37+
* Export the input GeneralizedLinearModel model to PMML format.
38+
*/
39+
private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {
40+
pmml.getHeader.setDescription(description)
41+
42+
if (model.weights.size > 0) {
43+
val fields = new SArray[FieldName](model.weights.size)
44+
val dataDictionary = new DataDictionary
45+
val miningSchema = new MiningSchema
46+
val regressionTable = new RegressionTable(model.intercept)
47+
val regressionModel = new RegressionModel()
48+
.withFunctionName(MiningFunctionType.REGRESSION)
49+
.withMiningSchema(miningSchema)
50+
.withModelName(description)
51+
.withRegressionTables(regressionTable)
52+
53+
for (i <- 0 until model.weights.size) {
54+
fields(i) = FieldName.create("field_" + i)
55+
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
56+
miningSchema
57+
.withMiningFields(new MiningField(fields(i))
58+
.withUsageType(FieldUsageType.ACTIVE))
59+
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
60+
}
61+
62+
// for completeness add target field
63+
val targetField = FieldName.create("target")
64+
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
65+
miningSchema
66+
.withMiningFields(new MiningField(targetField)
67+
.withUsageType(FieldUsageType.TARGET))
68+
69+
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
70+
71+
pmml.setDataDictionary(dataDictionary)
72+
pmml.withModels(regressionModel)
73+
}
74+
}
75+
}

0 commit comments

Comments
 (0)