Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 6ee70de

Browse files
committed
drop tmp col from OneVsRest output
1 parent ad06727 commit 6ee70de

File tree

2 files changed: 13 additions and 2 deletions.

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
131 131       // output label and label metadata as prediction
132 132       val labelUdf = callUDF(label, DoubleType, col(accColName))
133 133       aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
    134 +       .drop(accColName)
134 135     }
135 136   }
136 137

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,8 @@
 17  17
 18  18   package org.apache.spark.ml.classification
 19  19
 20     - import org.apache.spark.SparkFunSuite
     20 + import org.scalatest.FunSuite
     21 +
 21  22   import org.apache.spark.ml.attribute.NominalAttribute
 22  23   import org.apache.spark.ml.util.MetadataUtils
 23  24   import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -29,7 +30,7 @@ import org.apache.spark.mllib.util.TestingUtils._
 29  30   import org.apache.spark.rdd.RDD
 30  31   import org.apache.spark.sql.DataFrame
 31  32
 32     - class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     33 + class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
 33  34
 34  35     @transient var dataset: DataFrame = _
 35  36     @transient var rdd: RDD[LabeledPoint] = _
@@ -93,6 +94,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
 93  94       val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
 94  95       ova.fit(datasetWithLabelMetadata)
 95  96     }
     97 +
     98 +   test("SPARK-8049: OneVsRest shouldn't output temp columns") {
     99 +     val logReg = new LogisticRegression()
    100 +       .setMaxIter(1)
    101 +     val ovr = new OneVsRest()
    102 +       .setClassifier(logReg)
    103 +     val output = ovr.fit(dataset).transform(dataset)
    104 +     assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
    105 +   }
 96 106   }
 97 107
 98 108   private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {

0 commit comments

Comments
 (0)