Skip to content

Commit 8e71b8d

Browse files
committed
kmeans pmml export implementation
1 parent 9bc494f commit 8e71b8d

File tree

3 files changed

+79
-5
lines changed

3 files changed

+79
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,24 @@
1818
package org.apache.spark.mllib.export.pmml
1919

2020
import org.apache.spark.mllib.clustering.KMeansModel
21+
import org.dmg.pmml.DataDictionary
22+
import org.dmg.pmml.FieldName
23+
import org.dmg.pmml.DataField
24+
import org.dmg.pmml.OpType
25+
import org.dmg.pmml.DataType
26+
import org.dmg.pmml.MiningSchema
27+
import org.dmg.pmml.MiningField
28+
import org.dmg.pmml.FieldUsageType
29+
import org.dmg.pmml.ComparisonMeasure
30+
import org.dmg.pmml.ComparisonMeasure.Kind
31+
import org.dmg.pmml.SquaredEuclidean
32+
import org.dmg.pmml.ClusteringModel
33+
import org.dmg.pmml.MiningFunctionType
34+
import org.dmg.pmml.ClusteringModel.ModelClass
35+
import org.dmg.pmml.ClusteringField
36+
import org.dmg.pmml.CompareFunctionType
37+
import org.dmg.pmml.Cluster
38+
import org.dmg.pmml.Array.Type
2139

2240
/**
2341
* PMML Model Export for KMeansModel class
@@ -30,9 +48,48 @@ class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
3048
populateKMeansPMML(model);
3149

3250
/**
 * Populates the PMML document held by this exporter with a center-based
 * clustering model built from the given k-means model.
 *
 * The header description is always set. The data dictionary, mining schema and
 * clustering model are only generated when the model has at least one cluster
 * center; the first center determines how many input fields are declared.
 *
 * @param model the k-means model whose cluster centers are exported
 */
private def populateKMeansPMML(model : KMeansModel): Unit = {

  pmml.getHeader().setDescription("k-means clustering")

  if (model.clusterCenters.length > 0) {

    // The first centroid is used as the reference for the number of dimensions.
    val clusterCenter = model.clusterCenters(0)

    val fields = new Array[FieldName](clusterCenter.size)
    val dataDictionary = new DataDictionary()
    val miningSchema = new MiningSchema()

    // One continuous double input field per dimension.
    for (i <- 0 until clusterCenter.size) {
      fields(i) = FieldName.create("field_" + i)
      dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
      miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE))
    }

    val comparisonMeasure = new ComparisonMeasure()
      .withKind(Kind.DISTANCE)
      .withMeasure(new SquaredEuclidean())

    dataDictionary.withNumberOfFields(dataDictionary.getDataFields().size())

    pmml.setDataDictionary(dataDictionary)

    val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure,
      MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length)
      .withModelName("k-means")

    // Clustering fields are per dimension ...
    for (i <- 0 until clusterCenter.size) {
      clusteringModel.withClusteringFields(
        new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF))
    }

    // ... while clusters are per centroid. Iterating the centroids by their own
    // index (instead of the dimension index) exports every cluster exactly once
    // and never reads past the end of clusterCenters.
    for (i <- 0 until model.clusterCenters.length) {
      val cluster = new Cluster()
        .withName("cluster_" + i)
        .withArray(new org.dmg.pmml.Array()
          .withType(Type.REAL)
          .withN(clusterCenter.size)
          .withValue(model.clusterCenters(i).toArray.mkString(" ")))
      // Cluster size is not set: only the centroids are available, not the
      // number of points assigned to each cluster.
      clusteringModel.withClusters(cluster)
    }

    pmml.withModels(clusteringModel)

  }

}
3794

3895
}

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ import org.jpmml.model.JAXBUtil
2323
import org.dmg.pmml.PMML
2424
import javax.xml.transform.stream.StreamResult
2525
import scala.beans.BeanProperty
26+
import org.dmg.pmml.Application
27+
import org.dmg.pmml.Timestamp
28+
import org.dmg.pmml.Header
29+
import java.text.SimpleDateFormat
30+
import java.util.Date
2631

2732
trait PMMLModelExport extends ModelExport{
2833

@@ -31,7 +36,19 @@ trait PMMLModelExport extends ModelExport{
3136
*/
3237
@BeanProperty
3338
var pmml: PMML = new PMML();
34-
//TODO: set here header app copyright and timestamp
39+
40+
setHeader(pmml);
41+
42+
/**
 * Builds the PMML header (copyright, producing application and creation
 * timestamp) and installs it on the given document.
 *
 * The application version is read from the implementation version of this
 * class's package; presumably it is null when the package carries no such
 * metadata (see java.lang.Package) -- TODO confirm how the writer handles that.
 *
 * @param pmml the PMML document whose header is replaced
 */
private def setHeader(pmml : PMML): Unit = {
  val implementationVersion = getClass().getPackage().getImplementationVersion()
  val application = new Application()
    .withName("Apache Spark MLlib")
    .withVersion(implementationVersion)
  // Creation time of the export, rendered without a time-zone designator.
  val creationTime = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())
  val header = new Header()
    .withCopyright("www.dmg.org")
    .withApplication(application)
    .withTimestamp(new Timestamp().withContent(creationTime))
  pmml.setHeader(header)
}
3552

3653
/**
3754
* Write the exported model (in PMML XML) to the output stream specified

mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ModelExportFactorySuite extends FunSuite{
4040

4141
}
4242

43-
test("ModelExportFactory generate IllegalArgumentException when passing an unsupported model") {
43+
test("ModelExportFactory throws IllegalArgumentException when passing an unsupported model") {
4444

4545
val invalidModel = new Object;
4646

0 commit comments

Comments
 (0)