Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 226e184

Browse files
committed
added javadoc and export model type in case there is a need to support
other types of export (not just PMML)
1 parent a0e3679 commit 226e184

File tree

5 files changed

+38
-8
lines changed

5 files changed

+38
-8
lines changed

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ import java.io.OutputStream
2121

2222
trait ModelExport {
2323

24+
/**
25+
* Write the exported model to the output stream specified
26+
*/
2427
def save(outputStream: OutputStream): Unit
2528

2629
}

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,22 @@ package org.apache.spark.mllib.export
1919

2020
import org.apache.spark.mllib.clustering.KMeansModel
2121
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
22+
import org.apache.spark.mllib.export.ModelExportType._
2223

2324
object ModelExportFactory {
2425

25-
//TODO: introduce model export typed
26-
27-
def createModelExport(model: Any): ModelExport = model match {
28-
case kmeans: KMeansModel => new KMeansPMMLModelExport
29-
case _ => throw new IllegalArgumentException("Export not supported for model " + model.getClass)
26+
/**
27+
* Factory object to help creating the necessary ModelExport implementation
28+
* taking as input the ModelExportType (for example PMML) and the machine learning model (for example KMeansModel).
29+
*/
30+
def createModelExport(model: Any, exportType: ModelExportType): ModelExport = {
31+
return exportType match{
32+
case PMML => model match{
33+
case kmeans: KMeansModel => new KMeansPMMLModelExport(kmeans)
34+
case _ => throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
35+
}
36+
case _ => throw new IllegalArgumentException("Export type not supported:" + exportType)
37+
}
3038
}
3139

3240
}

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
package org.apache.spark.mllib.export
1919

20+
/**
21+
* Defines export types.
22+
* - PMML exports the machine learning models in an XML-based file format called Predictive Model Markup Language developed by the Data Mining Group (www.dmg.org).
23+
*/
2024
object ModelExportType extends Enumeration{
2125

2226
type ModelExportType = Value

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,19 @@
1717

1818
package org.apache.spark.mllib.export.pmml
1919

20-
class KMeansPMMLModelExport extends PMMLModelExport{
20+
import org.apache.spark.mllib.clustering.KMeansModel
2121

22-
populateKMeansPMML();
22+
/**
23+
* PMML Model Export for KMeansModel class
24+
*/
25+
class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
26+
27+
/**
28+
* Export the input KMeansModel model to PMML format
29+
*/
30+
populateKMeansPMML(model);
2331

24-
def populateKMeansPMML(): Unit = {
32+
private def populateKMeansPMML(model : KMeansModel): Unit = {
2533
//TODO: set here header description
2634
pmml.setVersion("testing... kmeans...");
2735
//TODO: generate the model...

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,17 @@ import scala.beans.BeanProperty
2626

2727
trait PMMLModelExport extends ModelExport{
2828

29+
/**
30+
* Holder of the exported model in PMML format
31+
*/
2932
@BeanProperty
3033
var pmml: PMML = new PMML();
3134
//TODO: set here header app copyright and timestamp
3235

36+
/**
37+
* Write the exported model (in PMML XML) to the output stream specified
38+
*/
39+
@Override
3340
def save(outputStream: OutputStream): Unit = {
3441
JAXBUtil.marshalPMML(pmml, new StreamResult(outputStream));
3542
}

0 commit comments

Comments
 (0)