Commit 2466322

refactor Bucketizer
1 parent 11fb00a commit 2466322

2 files changed: +88 -51 lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 81 additions & 49 deletions
@@ -25,46 +25,55 @@ import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

 /**
  * :: AlphaComponent ::
  * `Bucketizer` maps a column of continuous features to a column of feature buckets.
  */
 @AlphaComponent
-final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
+private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
   extends Model[Bucketizer] with HasInputCol with HasOutputCol {

-  /**
-   * The given buckets should match 1) its size is larger than zero; 2) it is ordered in a non-DESC
-   * way.
-   */
-  private def checkBuckets(buckets: Array[Double]): Boolean = {
-    if (buckets.size == 0) false
-    else if (buckets.size == 1) true
-    else {
-      buckets.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
-        if (validator & prevValue <= currValue) {
-          (true, currValue)
-        } else {
-          (false, currValue)
-        }
-      }._1
-    }
-  }
+  def this() = this(null)

   /**
-   * Parameter for mapping continuous features into buckets.
+   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
+   * A bucket defined by splits x,y holds values in the range (x,y].
    * @group param
    */
-  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
-    "Split points for mapping continuous features into buckets.", checkBuckets)
+  val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
+    "Split points for mapping continuous features into buckets. With n splits, there are n+1" +
+      " buckets. A bucket defined by splits x,y holds values in the range (x,y].",
+    Bucketizer.checkSplits)

   /** @group getParam */
-  def getBuckets: Array[Double] = $(buckets)
+  def getSplits: Array[Double] = $(splits)

   /** @group setParam */
-  def setBuckets(value: Array[Double]): this.type = set(buckets, value)
+  def setSplits(value: Array[Double]): this.type = set(splits, value)
+
+  /** @group Param */
+  val lowerInclusive: BooleanParam = new BooleanParam(this, "lowerInclusive",
+    "An indicator of the inclusiveness of negative infinite.")
+  setDefault(lowerInclusive -> true)
+
+  /** @group getParam */
+  def getLowerInclusive: Boolean = $(lowerInclusive)
+
+  /** @group setParam */
+  def setLowerInclusive(value: Boolean): this.type = set(lowerInclusive, value)
+
+  /** @group Param */
+  val upperInclusive: BooleanParam = new BooleanParam(this, "upperInclusive",
+    "An indicator of the inclusiveness of positive infinite.")
+  setDefault(upperInclusive -> true)
+
+  /** @group getParam */
+  def getUpperInclusive: Boolean = $(upperInclusive)
+
+  /** @group setParam */
+  def setUpperInclusive(value: Boolean): this.type = set(upperInclusive, value)

   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
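
Aside (not part of the diff): the splits semantics documented in the hunk above, combined with the wrapping that transform does in the next hunk, work out as in this small illustrative sketch; the concrete split points are made up for the example.

// Illustrative sketch only: with n = 3 splits, transform() wraps them with
// Double.MinValue / Double.MaxValue and produces n + 1 = 4 buckets,
// where bucket i covers the range (wrapped(i), wrapped(i + 1)].
val splits = Array(-0.5, 0.0, 0.5)
val wrapped = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
// bucket 0: (Double.MinValue, -0.5]   bucket 1: (-0.5, 0.0]
// bucket 2: (0.0, 0.5]                bucket 3: (0.5, Double.MaxValue]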
@@ -74,45 +83,68 @@ final class Bucketizer(override val parent: Estimator[Bucketizer] = null)

   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double => binarySearchForBuckets($(buckets), feature) }
-    val outputColName = $(outputCol)
-    val metadata = NominalAttribute.defaultAttr
-      .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
-    dataset.select(col("*"), bucketizer(dataset($(inputCol))).as(outputColName, metadata))
+    val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
+    val bucketizer = udf { feature: Double =>
+      Bucketizer.binarySearchForBuckets(wrappedSplits, feature) }
+    val newCol = bucketizer(dataset($(inputCol)))
+    val newField = prepOutputField(dataset.schema)
+    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+  }
+
+  private def prepOutputField(schema: StructType): StructField = {
+    val attr = new NominalAttribute(
+      name = Some($(outputCol)),
+      isOrdinal = Some(true),
+      numValues = Some($(splits).size),
+      values = Some($(splits).map(_.toString)))
+
+    attr.toStructField()
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    require(schema.fields.forall(_.name != $(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    StructType(schema.fields :+ prepOutputField(schema))
+  }
+}
+
+object Bucketizer {
+  /**
+   * The given splits should match 1) its size is larger than zero; 2) it is ordered in a strictly
+   * increasing way.
+   */
+  private def checkSplits(splits: Array[Double]): Boolean = {
+    if (splits.size == 0) false
+    else if (splits.size == 1) true
+    else {
+      splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
+        if (validator && prevValue < currValue) {
+          (true, currValue)
+        } else {
+          (false, currValue)
+        }
+      }._1
+    }
   }

   /**
    * Binary searching in several buckets to place each data point.
    */
-  private def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
+  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
     var left = 0
-    var right = wrappedSplits.length - 2
+    var right = splits.length - 2
     while (left <= right) {
       val mid = left + (right - left) / 2
-      val split = wrappedSplits(mid)
-      if ((feature > split) && (feature <= wrappedSplits(mid + 1))) {
+      val split = splits(mid)
+      if ((feature > split) && (feature <= splits(mid + 1))) {
         return mid
       } else if (feature <= split) {
         right = mid - 1
       } else {
         left = mid + 1
       }
     }
-    -1
-  }
-
-  override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-
-    val inputFields = schema.fields
-    val outputColName = $(outputCol)
-
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-
-    val attr = NominalAttribute.defaultAttr.withName(outputColName)
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    throw new Exception("Failed to find a bucket.")
   }
 }
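
For context, a minimal usage sketch of the refactored API (not part of this commit). At this commit the class is private[ml], so code like this only compiles from inside the org.apache.spark.ml package, as the test suite below does; the SQLContext, data, and column names here are illustrative assumptions.

import org.apache.spark.ml.feature.Bucketizer

// Assumes a SQLContext named sqlContext is already in scope (e.g. in a suite
// extending MLlibTestSparkContext); data and column names are made up.
val data = Array(-0.9, -0.2, 0.3, 0.8)
val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("feature")

val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("result")
  .setSplits(Array(-0.5, 0.0, 0.5))

// "result" holds the bucket index of each feature: 0.0, 1.0, 2.0, 3.0 here.
bucketizer.transform(dataFrame).show()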

mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 7 additions & 2 deletions
@@ -17,9 +17,10 @@

 package org.apache.spark.ml.feature

+import org.scalatest.FunSuite
+
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.scalatest.FunSuite

 class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

@@ -34,11 +35,15 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
     val bucketizer: Bucketizer = new Bucketizer()
       .setInputCol("feature")
       .setOutputCol("result")
-      .setBuckets(buckets)
+      .setSplits(buckets)

     bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
       case Row(x: Double, y: Double) =>
         assert(x === y, "The feature value is not correct after bucketing.")
     }
   }
+
+  test("Binary search for finding buckets") {
+
+  }
 }
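
The newly added test is left empty in this commit. A possible body, sketched here purely as an illustration (it is not in the diff), could exercise binarySearchForBuckets directly, which is reachable from this suite because the helper is private[feature]; the split points and expected buckets below are made up.

test("Binary search for finding buckets") {
  // Splits are pre-wrapped with Double.MinValue / Double.MaxValue here,
  // mirroring what Bucketizer.transform does before calling the search.
  val wrappedSplits = Array(Double.MinValue, -0.5, 0.0, 0.5, Double.MaxValue)
  val features = Array(-0.9, -0.5, -0.2, 0.0, 0.3, 0.8)
  val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 2.0, 3.0)
  features.zip(expectedBuckets).foreach { case (feature, expected) =>
    assert(Bucketizer.binarySearchForBuckets(wrappedSplits, feature) === expected,
      s"Feature $feature was not placed in bucket $expected.")
  }
}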
