@@ -22,29 +22,38 @@ import org.apache.spark.mllib.linalg.Vectors
22
22
import org .apache .spark .mllib .export .ModelExportFactory
23
23
import org .apache .spark .mllib .clustering .KMeansModel
24
24
import org .apache .spark .mllib .export .ModelExportType
25
+ import org .dmg .pmml .ClusteringModel
26
+ import javax .xml .parsers .DocumentBuilderFactory
27
+ import java .io .ByteArrayOutputStream
25
28
26
29
class KMeansPMMLModelExportSuite extends FunSuite {
27
30
28
31
test(" KMeansPMMLModelExport generate PMML format" ) {
29
32
33
+ // arrange model to test
30
34
val clusterCenters = Array (
31
35
Vectors .dense(1.0 , 2.0 , 6.0 ),
32
36
Vectors .dense(1.0 , 3.0 , 0.0 ),
33
37
Vectors .dense(1.0 , 4.0 , 6.0 )
34
38
)
35
-
36
39
val kmeansModel = new KMeansModel (clusterCenters);
37
40
41
+ // act by exporting the model to the PMML format
38
42
val modelExport = ModelExportFactory .createModelExport(kmeansModel, ModelExportType .PMML )
39
-
43
+
44
+ // assert that the PMML format is as expected
40
45
assert(modelExport.isInstanceOf [PMMLModelExport ])
46
+ var pmml = modelExport.asInstanceOf [PMMLModelExport ].getPmml()
47
+ assert(pmml.getHeader().getDescription() === " k-means clustering" )
48
+ // check that the number of fields match the single vector size
49
+ assert(pmml.getDataDictionary().getNumberOfFields() === clusterCenters(0 ).size)
50
+ // this verify that there is a model attached to the pmml object and the model is a clustering one
51
+ // it also verifies that the pmml model has the same number of clusters of the spark model
52
+ assert(pmml.getModels().get(0 ).asInstanceOf [ClusteringModel ].getNumberOfClusters() === clusterCenters.size)
41
53
42
- // TODO: asserts
43
- // compare pmml fields to strings
44
- modelExport.asInstanceOf [PMMLModelExport ].getPmml()
45
- // use document builder to load the xml generated and validated the notes by looking for them
46
- modelExport.asInstanceOf [PMMLModelExport ].save(System .out)
47
- // saveLocalFile too??? search how to unit test file creating in java
54
+ // manual checking
55
+ // modelExport.asInstanceOf[PMMLModelExport].save(System.out)
56
+ // modelExport.asInstanceOf[PMMLModelExport].saveLocalFile("/tmp/kmeans.xml")
48
57
49
58
}
50
59
0 commit comments