Skip to content

Commit 25dce33

Browse files
committed
[SPARK-1406] Update code to latest pmml model
1 parent dea98ca commit 25dce33

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
4444
val dataDictionary = new DataDictionary
4545
val miningSchema = new MiningSchema
4646
val regressionTable = new RegressionTable(model.intercept)
47-
val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.REGRESSION)
47+
val regressionModel = new RegressionModel()
48+
.withFunctionName(MiningFunctionType.REGRESSION)
49+
.withMiningSchema(miningSchema)
4850
.withModelName(description)
4951
.withRegressionTables(regressionTable)
5052

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode
4444
val comparisonMeasure = new ComparisonMeasure()
4545
.withKind(ComparisonMeasure.Kind.DISTANCE)
4646
.withMeasure(new SquaredEuclidean())
47-
val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure,
48-
MiningFunctionType.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED,
49-
model.clusterCenters.length)
47+
val clusteringModel = new ClusteringModel()
5048
.withModelName("k-means")
49+
.withMiningSchema(miningSchema)
50+
.withComparisonMeasure(comparisonMeasure)
51+
.withFunctionName(MiningFunctionType.CLUSTERING)
52+
.withModelClass(ClusteringModel.ModelClass.CENTER_BASED)
53+
.withNumberOfClusters(model.clusterCenters.length)
5154

5255
for (i <- 0 until clusterCenter.size) {
5356
fields(i) = FieldName.create("field_" + i)

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ private[mllib] class LogisticRegressionPMMLModelExport(
4545
val miningSchema = new MiningSchema
4646
val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1")
4747
val regressionTableNO = new RegressionTable(0.0).withTargetCategory("0")
48-
val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.CLASSIFICATION)
48+
val regressionModel = new RegressionModel()
49+
.withFunctionName(MiningFunctionType.CLASSIFICATION)
50+
.withMiningSchema(miningSchema)
4951
.withModelName(description)
5052
.withNormalizationMethod(RegressionNormalizationMethodType.LOGIT)
5153
.withRegressionTables(regressionTableYES, regressionTableNO)

0 commit comments

Comments
 (0)