Skip to content

Commit 1433b11

Browse files
committed
complete suite tests
1 parent 8e71b8d commit 1433b11

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,30 @@ class ModelExportFactorySuite extends FunSuite{
2626

2727
test("ModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
2828

29+
//arrange
2930
val clusterCenters = Array(
3031
Vectors.dense(1.0, 2.0, 6.0),
3132
Vectors.dense(1.0, 3.0, 0.0),
3233
Vectors.dense(1.0, 4.0, 6.0)
3334
)
34-
3535
val kmeansModel = new KMeansModel(clusterCenters);
3636

37+
//act
3738
val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML)
3839

40+
//assert
3941
assert(modelExport.isInstanceOf[KMeansPMMLModelExport])
4042

4143
}
4244

43-
test("ModelExportFactory throws IllegalArgumentException when passing an unsupported model") {
45+
test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
4446

47+
//arrange
4548
val invalidModel = new Object;
4649

50+
//assert
4751
intercept[IllegalArgumentException] {
52+
//act
4853
ModelExportFactory.createModelExport(invalidModel, ModelExportType.PMML)
4954
}
5055

mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,29 +22,38 @@ import org.apache.spark.mllib.linalg.Vectors
2222
import org.apache.spark.mllib.export.ModelExportFactory
2323
import org.apache.spark.mllib.clustering.KMeansModel
2424
import org.apache.spark.mllib.export.ModelExportType
25+
import org.dmg.pmml.ClusteringModel
26+
import javax.xml.parsers.DocumentBuilderFactory
27+
import java.io.ByteArrayOutputStream
2528

2629
class KMeansPMMLModelExportSuite extends FunSuite{
2730

2831
test("KMeansPMMLModelExport generate PMML format") {
2932

33+
//arrange model to test
3034
val clusterCenters = Array(
3135
Vectors.dense(1.0, 2.0, 6.0),
3236
Vectors.dense(1.0, 3.0, 0.0),
3337
Vectors.dense(1.0, 4.0, 6.0)
3438
)
35-
3639
val kmeansModel = new KMeansModel(clusterCenters);
3740

41+
//act by exporting the model to the PMML format
3842
val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML)
39-
43+
44+
//assert that the PMML format is as expected
4045
assert(modelExport.isInstanceOf[PMMLModelExport])
46+
var pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml()
47+
assert(pmml.getHeader().getDescription() === "k-means clustering")
48+
//check that the number of fields matches the size of a single cluster center vector
49+
assert(pmml.getDataDictionary().getNumberOfFields() === clusterCenters(0).size)
50+
//this verifies that there is a model attached to the pmml object and that the model is a clustering one
51+
//it also verifies that the pmml model has the same number of clusters as the spark model
52+
assert(pmml.getModels().get(0).asInstanceOf[ClusteringModel].getNumberOfClusters() === clusterCenters.size)
4153

42-
//TODO: asserts
43-
//compare pmml fields to strings
44-
modelExport.asInstanceOf[PMMLModelExport].getPmml()
45-
//use document builder to load the xml generated and validated the notes by looking for them
46-
modelExport.asInstanceOf[PMMLModelExport].save(System.out)
47-
//saveLocalFile too??? search how to unit test file creating in java
54+
//manual checking
55+
//modelExport.asInstanceOf[PMMLModelExport].save(System.out)
56+
//modelExport.asInstanceOf[PMMLModelExport].saveLocalFile("/tmp/kmeans.xml")
4857

4958
}
5059

0 commit comments

Comments
 (0)