@@ -25,46 +25,55 @@ import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /**
  * :: AlphaComponent ::
  * `Bucketizer` maps a column of continuous features to a column of feature buckets.
  */
 @AlphaComponent
-final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
+private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
   extends Model[Bucketizer] with HasInputCol with HasOutputCol {
 
-  /**
-   * The given buckets should match 1) its size is larger than zero; 2) it is ordered in a non-DESC
-   * way.
-   */
-  private def checkBuckets(buckets: Array[Double]): Boolean = {
-    if (buckets.size == 0) false
-    else if (buckets.size == 1) true
-    else {
-      buckets.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
-        if (validator & prevValue <= currValue) {
-          (true, currValue)
-        } else {
-          (false, currValue)
-        }
-      }._1
-    }
-  }
+  def this() = this(null)
 
   /**
-   * Parameter for mapping continuous features into buckets.
+   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
+   * A bucket defined by splits x,y holds values in the range (x,y].
    * @group param
    */
-  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
-    "Split points for mapping continuous features into buckets.", checkBuckets)
+  val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
+    "Split points for mapping continuous features into buckets. With n splits, there are n+1" +
+      " buckets. A bucket defined by splits x,y holds values in the range (x,y].",
+    Bucketizer.checkSplits)
 
   /** @group getParam */
-  def getBuckets: Array[Double] = $(buckets)
+  def getSplits: Array[Double] = $(splits)
 
   /** @group setParam */
-  def setBuckets(value: Array[Double]): this.type = set(buckets, value)
+  def setSplits(value: Array[Double]): this.type = set(splits, value)
+
+  /** @group Param */
+  val lowerInclusive: BooleanParam = new BooleanParam(this, "lowerInclusive",
+    "An indicator of the inclusiveness of negative infinite.")
+  setDefault(lowerInclusive -> true)
+
+  /** @group getParam */
+  def getLowerInclusive: Boolean = $(lowerInclusive)
+
+  /** @group setParam */
+  def setLowerInclusive(value: Boolean): this.type = set(lowerInclusive, value)
+
+  /** @group Param */
+  val upperInclusive: BooleanParam = new BooleanParam(this, "upperInclusive",
+    "An indicator of the inclusiveness of positive infinite.")
+  setDefault(upperInclusive -> true)
+
+  /** @group getParam */
+  def getUpperInclusive: Boolean = $(upperInclusive)
+
+  /** @group setParam */
+  def setUpperInclusive(value: Boolean): this.type = set(upperInclusive, value)
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
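
(A minimal usage sketch of the API added in the hunk above, not part of the patch: it assumes a DataFrame `df` with a Double column named "features", and, since the class is now private[ml], caller code like this would have to live under the org.apache.spark.ml package.)

// Hypothetical caller, assuming `df: DataFrame` with a Double "features" column.
val bucketizer = new Bucketizer()        // uses the new no-arg constructor
  .setInputCol("features")
  .setOutputCol("bucketedFeatures")
  .setSplits(Array(-0.5, 0.0, 0.5))      // 3 splits => 4 buckets once wrapped with MinValue/MaxValue
val bucketed = bucketizer.transform(df)  // appends "bucketedFeatures" holding bucket indices 0.0 to 3.0
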
@@ -74,45 +83,68 @@ final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double => binarySearchForBuckets($(buckets), feature) }
-    val outputColName = $(outputCol)
-    val metadata = NominalAttribute.defaultAttr
-      .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
-    dataset.select(col("*"), bucketizer(dataset($(inputCol))).as(outputColName, metadata))
+    val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
+    val bucketizer = udf { feature: Double =>
+      Bucketizer.binarySearchForBuckets(wrappedSplits, feature) }
+    val newCol = bucketizer(dataset($(inputCol)))
+    val newField = prepOutputField(dataset.schema)
+    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+  }
+
+  private def prepOutputField(schema: StructType): StructField = {
+    val attr = new NominalAttribute(
+      name = Some($(outputCol)),
+      isOrdinal = Some(true),
+      numValues = Some($(splits).size),
+      values = Some($(splits).map(_.toString)))
+
+    attr.toStructField()
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    require(schema.fields.forall(_.name != $(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    StructType(schema.fields :+ prepOutputField(schema))
+  }
+}
+
+object Bucketizer {
+  /**
+   * The given splits should match 1) its size is larger than zero; 2) it is ordered in a strictly
+   * increasing way.
+   */
+  private def checkSplits(splits: Array[Double]): Boolean = {
+    if (splits.size == 0) false
+    else if (splits.size == 1) true
+    else {
+      splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
+        if (validator && prevValue < currValue) {
+          (true, currValue)
+        } else {
+          (false, currValue)
+        }
+      }._1
+    }
   }
 
   /**
    * Binary searching in several buckets to place each data point.
    */
-  private def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
+  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
     var left = 0
-    var right = wrappedSplits.length - 2
+    var right = splits.length - 2
     while (left <= right) {
       val mid = left + (right - left) / 2
-      val split = wrappedSplits(mid)
-      if ((feature > split) && (feature <= wrappedSplits(mid + 1))) {
+      val split = splits(mid)
+      if ((feature > split) && (feature <= splits(mid + 1))) {
         return mid
       } else if (feature <= split) {
         right = mid - 1
       } else {
         left = mid + 1
      }
    }
-    -1
-  }
-
-  override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-
-    val inputFields = schema.fields
-    val outputColName = $(outputCol)
-
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-
-    val attr = NominalAttribute.defaultAttr.withName(outputColName)
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    throw new Exception("Failed to find a bucket.")
   }
 }
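
(The search above expects splits already wrapped with the Double.MinValue / Double.MaxValue sentinels, which transform now prepends and appends. Below is a self-contained sketch of the same lower-exclusive, upper-inclusive (x, y] bucketing logic, runnable without Spark; the object name and test values are illustrative, not part of the patch.)

object BucketizerSketch {
  // Mirrors binarySearchForBuckets: `splits` must already include the sentinel
  // endpoints, so n user-supplied splits yield n + 1 buckets, indexed 0 to n.
  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
    var left = 0
    var right = splits.length - 2
    while (left <= right) {
      val mid = left + (right - left) / 2
      // Bucket `mid` covers the range (splits(mid), splits(mid + 1)].
      if (feature > splits(mid) && feature <= splits(mid + 1)) return mid
      else if (feature <= splits(mid)) right = mid - 1
      else left = mid + 1
    }
    throw new Exception("Failed to find a bucket.")
  }

  def main(args: Array[String]): Unit = {
    val wrapped = Array(Double.MinValue, -0.5, 0.0, 0.5, Double.MaxValue)
    assert(binarySearchForBuckets(wrapped, -0.5) == 0.0)  // an upper edge lands in the lower bucket
    assert(binarySearchForBuckets(wrapped, 0.2) == 2.0)
    assert(binarySearchForBuckets(wrapped, 100.0) == 3.0)
  }
}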