18
18
package org .apache .spark .mllib .export .pmml
19
19
20
20
import org .apache .spark .mllib .clustering .KMeansModel
21
+ import org .dmg .pmml .DataDictionary
22
+ import org .dmg .pmml .FieldName
23
+ import org .dmg .pmml .DataField
24
+ import org .dmg .pmml .OpType
25
+ import org .dmg .pmml .DataType
26
+ import org .dmg .pmml .MiningSchema
27
+ import org .dmg .pmml .MiningField
28
+ import org .dmg .pmml .FieldUsageType
29
+ import org .dmg .pmml .ComparisonMeasure
30
+ import org .dmg .pmml .ComparisonMeasure .Kind
31
+ import org .dmg .pmml .SquaredEuclidean
32
+ import org .dmg .pmml .ClusteringModel
33
+ import org .dmg .pmml .MiningFunctionType
34
+ import org .dmg .pmml .ClusteringModel .ModelClass
35
+ import org .dmg .pmml .ClusteringField
36
+ import org .dmg .pmml .CompareFunctionType
37
+ import org .dmg .pmml .Cluster
38
+ import org .dmg .pmml .Array .Type
21
39
22
40
/**
23
41
* PMML Model Export for KMeansModel class
@@ -30,9 +48,48 @@ class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
30
48
populateKMeansPMML(model);
31
49
32
50
/**
 * Fills the `pmml` document (declared outside this method) with a center-based
 * clustering model built from the centroids of the given k-means model.
 *
 * The header description is always set. The data dictionary, mining schema and
 * clustering model are only generated when the model has at least one cluster
 * center; otherwise nothing else is added to the document.
 *
 * @param model the k-means model whose cluster centers are exported
 */
private def populateKMeansPMML(model: KMeansModel): Unit = {

  pmml.getHeader().setDescription(" k-means clustering")

  if (model.clusterCenters.length > 0) {

    // The first centroid determines the number of input fields.
    // NOTE(review): assumes every centroid has the same size as the first one - confirm.
    val clusterCenter = model.clusterCenters(0)
    val numFields = clusterCenter.size

    val fields = new Array[FieldName](numFields)
    val dataDictionary = new DataDictionary()
    val miningSchema = new MiningSchema()

    // One continuous double field per centroid coordinate, registered both in
    // the data dictionary and (as an active field) in the mining schema.
    for (i <- 0 until numFields) {
      fields(i) = FieldName.create(" field_" + i)
      dataDictionary.withDataFields(
        new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
      miningSchema.withMiningFields(
        new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE))
    }

    val comparisonMeasure = new ComparisonMeasure()
      .withKind(Kind.DISTANCE)
      .withMeasure(new SquaredEuclidean())

    dataDictionary.withNumberOfFields(dataDictionary.getDataFields().size())

    pmml.setDataDictionary(dataDictionary)

    val clusteringModel = new ClusteringModel(
      miningSchema,
      comparisonMeasure,
      MiningFunctionType.CLUSTERING,
      ModelClass.CENTER_BASED,
      model.clusterCenters.length)
      .withModelName(" k-means")

    // Clustering fields are per input dimension ...
    for (i <- 0 until numFields) {
      clusteringModel.withClusteringFields(
        new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF))
    }

    // ... whereas clusters are per centroid. Iterating the centroids by their own
    // indices (rather than by dimension) exports every cluster exactly once and
    // never indexes past the end of clusterCenters when k differs from the
    // number of dimensions.
    for (i <- model.clusterCenters.indices) {
      val cluster = new Cluster()
        .withName(" cluster_" + i)
        .withArray(
          new org.dmg.pmml.Array()
            .withType(Type.REAL)
            .withN(numFields)
            .withValue(model.clusterCenters(i).toArray.mkString(" ")))
      // cluster.withSize(value) // we don't have the size of the single cluster but only the centroids (withValue)
      clusteringModel.withClusters(cluster)
    }

    pmml.withModels(clusteringModel)

  }

}
37
94
38
95
}
0 commit comments