Skip to content

Commit 758ebf7

Browse files
committed
SPARK-6480 [CORE] histogram() bucket function is wrong in some simple edge cases
Fix fastBucketFunction for histogram() to handle edge conditions more correctly. Add a test, and fix existing one accordingly Author: Sean Owen <sowen@cloudera.com> Closes #5148 from srowen/SPARK-6480 and squashes the following commits: 974a0a0 [Sean Owen] Additional test of huge ranges, and a few more comments (and comment fixes) 23ec01e [Sean Owen] Fix fastBucketFunction for histogram() to handle edge conditions more correctly. Add a test, and fix existing one accordingly (cherry picked from commit fe15ea9) Signed-off-by: Sean Owen <sowen@cloudera.com>
1 parent 61c059a commit 758ebf7

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -192,25 +192,23 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
192192
}
193193
}
194194
// Determine the bucket function in constant time. Requires that buckets are evenly spaced
195-
def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = {
195+
def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = {
196196
// If our input is not a number unless the increment is also NaN then we fail fast
197-
if (e.isNaN()) {
198-
return None
199-
}
200-
val bucketNumber = (e - min)/(increment)
201-
// We do this rather than buckets.lengthCompare(bucketNumber)
202-
// because Array[Double] fails to override it (for now).
203-
if (bucketNumber > count || bucketNumber < 0) {
197+
if (e.isNaN || e < min || e > max) {
204198
None
205199
} else {
206-
Some(bucketNumber.toInt.min(count - 1))
200+
// Compute ratio of e's distance along range to total range first, for better precision
201+
val bucketNumber = (((e - min) / (max - min)) * count).toInt
202+
// should be less than count, but will equal count if e == max, in which case
203+
// it's part of the last end-range-inclusive bucket, so return count-1
204+
Some(math.min(bucketNumber, count - 1))
207205
}
208206
}
209207
// Decide which bucket function to pass to histogramPartition. We decide here
210-
// rather than having a general function so that the decission need only be made
208+
// rather than having a general function so that the decision need only be made
211209
// once rather than once per shard
212210
val bucketFunction = if (evenBuckets) {
213-
fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _
211+
fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _
214212
} else {
215213
basicBucketFunction _
216214
}

core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext {
233233
assert(histogramBuckets === expectedHistogramBuckets)
234234
}
235235

236+
test("WorksWithDoubleValuesAtMinMax") {
237+
val rdd = sc.parallelize(Seq(1, 1, 1, 2, 3, 3))
238+
assert(Array(3, 0, 1, 2) === rdd.map(_.toDouble).histogram(4)._2)
239+
assert(Array(3, 1, 2) === rdd.map(_.toDouble).histogram(3)._2)
240+
}
241+
236242
test("WorksWithoutBucketsWithMoreRequestedThanElements") {
237243
// Verify the basic case of one bucket and all elements in that bucket works
238244
val rdd = sc.parallelize(Seq(1, 2))
@@ -246,7 +252,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext {
246252
}
247253

248254
test("WorksWithoutBucketsForLargerDatasets") {
249-
// Verify the case of slighly larger datasets
255+
// Verify the case of slightly larger datasets
250256
val rdd = sc.parallelize(6 to 99)
251257
val (histogramBuckets, histogramResults) = rdd.histogram(8)
252258
val expectedHistogramResults =
@@ -257,17 +263,27 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext {
257263
assert(histogramBuckets === expectedHistogramBuckets)
258264
}
259265

260-
test("WorksWithoutBucketsWithIrrationalBucketEdges") {
261-
// Verify the case of buckets with irrational edges. See #SPARK-2862.
266+
test("WorksWithoutBucketsWithNonIntegralBucketEdges") {
267+
// Verify the case of buckets with nonintegral edges. See #SPARK-2862.
262268
val rdd = sc.parallelize(6 to 99)
263269
val (histogramBuckets, histogramResults) = rdd.histogram(9)
270+
// Buckets are 6.0, 16.333333333333336, 26.666666666666668, 37.0, 47.333333333333336 ...
264271
val expectedHistogramResults =
265-
Array(11, 10, 11, 10, 10, 11, 10, 10, 11)
272+
Array(11, 10, 10, 11, 10, 10, 11, 10, 11)
266273
assert(histogramResults === expectedHistogramResults)
267274
assert(histogramBuckets(0) === 6.0)
268275
assert(histogramBuckets(9) === 99.0)
269276
}
270277

278+
test("WorksWithHugeRange") {
279+
val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30))
280+
val histogramResults = rdd.histogram(1000000)._2
281+
assert(histogramResults(0) === 1)
282+
assert(histogramResults(1) === 1)
283+
assert(histogramResults.last === 1)
284+
assert((2 to histogramResults.length - 2).forall(i => histogramResults(i) == 0))
285+
}
286+
271287
// Test the failure mode with an invalid RDD
272288
test("ThrowsExceptionOnInvalidRDDs") {
273289
// infinity

0 commit comments

Comments (0)