Skip to content

Commit 9bc494f

Browse files
committed
added scala suite tests
added saveLocalFile to ModelExport trait
1 parent 226e184 commit 9bc494f

File tree

3 files changed

+113
-0
lines changed

3 files changed

+113
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,21 @@
1818
package org.apache.spark.mllib.export
1919

2020
import java.io.OutputStream
21+
import java.io.FileOutputStream
22+
import java.io.File
2123

2224
trait ModelExport {
2325

2426
/**
2527
* Write the exported model to the output stream specified
2628
*/
2729
def save(outputStream: OutputStream): Unit
30+
31+
/**
32+
* Write the exported model to the local file specified
33+
*/
34+
def saveLocalFile(path: String): Unit = {
35+
save(new FileOutputStream(new File(path)));
36+
}
2837

2938
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.export
19+
20+
import org.apache.spark.mllib.clustering.KMeansModel
21+
import org.apache.spark.mllib.linalg.Vectors
22+
import org.scalatest.FunSuite
23+
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
24+
25+
class ModelExportFactorySuite extends FunSuite{
26+
27+
test("ModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
28+
29+
val clusterCenters = Array(
30+
Vectors.dense(1.0, 2.0, 6.0),
31+
Vectors.dense(1.0, 3.0, 0.0),
32+
Vectors.dense(1.0, 4.0, 6.0)
33+
)
34+
35+
val kmeansModel = new KMeansModel(clusterCenters);
36+
37+
val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML)
38+
39+
assert(modelExport.isInstanceOf[KMeansPMMLModelExport])
40+
41+
}
42+
43+
test("ModelExportFactory generate IllegalArgumentException when passing an unsupported model") {
44+
45+
val invalidModel = new Object;
46+
47+
intercept[IllegalArgumentException] {
48+
ModelExportFactory.createModelExport(invalidModel, ModelExportType.PMML)
49+
}
50+
51+
}
52+
53+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.export.pmml
19+
20+
import org.scalatest.FunSuite
21+
import org.apache.spark.mllib.linalg.Vectors
22+
import org.apache.spark.mllib.export.ModelExportFactory
23+
import org.apache.spark.mllib.clustering.KMeansModel
24+
import org.apache.spark.mllib.export.ModelExportType
25+
26+
class KMeansPMMLModelExportSuite extends FunSuite{
27+
28+
test("KMeansPMMLModelExport generate PMML format") {
29+
30+
val clusterCenters = Array(
31+
Vectors.dense(1.0, 2.0, 6.0),
32+
Vectors.dense(1.0, 3.0, 0.0),
33+
Vectors.dense(1.0, 4.0, 6.0)
34+
)
35+
36+
val kmeansModel = new KMeansModel(clusterCenters);
37+
38+
val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML)
39+
40+
assert(modelExport.isInstanceOf[PMMLModelExport])
41+
42+
//TODO: asserts
43+
//compare pmml fields to strings
44+
modelExport.asInstanceOf[PMMLModelExport].getPmml()
45+
//use document builder to load the xml generated and validated the notes by looking for them
46+
modelExport.asInstanceOf[PMMLModelExport].save(System.out)
47+
//saveLocalFile too??? search how to unit test file creating in java
48+
49+
}
50+
51+
}

0 commit comments

Comments
 (0)