
Commit 49bbdcb

yingjieMiao authored and rxin committed
[Spark] RDD take() method: overestimate too much
The comment at line 1083 says: "Otherwise, interpolate the number of partitions we need to try, but overestimate it by 50%." Here `(1.5 * num * partsScanned / buf.size).toInt` is the estimate of the *total* number of partitions needed, so each iteration should add only the increment `(1.5 * num * partsScanned / buf.size).toInt - partsScanned`. The existing implementation instead re-adds the whole estimate, growing `partsScanned` exponentially (roughly `x_{n+1} >= (1.5 + 1) * x_n`). This could be a performance problem (unless it is the intended behavior).

Author: yingjieMiao <[email protected]>

Closes #2648 from yingjieMiao/rdd_take and squashes the following commits:

d758218 [yingjieMiao] scala style fix
a8e74bb [yingjieMiao] python style fix
4b6e777 [yingjieMiao] infix operator style fix
4391d3b [yingjieMiao] typo fix
692f4e6 [yingjieMiao] cap numPartsToTry
c4483dc [yingjieMiao] style fix
1d2c410 [yingjieMiao] also change in rdd.py and AsyncRDD
d31ff7e [yingjieMiao] handle the edge case after 1 iteration
a2aa36b [yingjieMiao] RDD take method: overestimate too much
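The growth-rate difference described above can be checked with a small standalone simulation (a sketch, not Spark code; the workload where `buf.size` equals `num` at every step is hypothetical, chosen to isolate the per-iteration growth factor):

```python
def next_parts_scanned(parts_scanned, num, buf_size, patched):
    """One iteration of take()'s partition estimate (interpolation branch).

    With patched=False, the whole interpolated total is re-added each
    iteration (the behavior this commit fixes); with patched=True, only
    the increment is added, clamped to [1, 4 * partsScanned].
    """
    estimate = int(1.5 * num * parts_scanned / buf_size)
    if patched:
        return parts_scanned + min(max(estimate - parts_scanned, 1), parts_scanned * 4)
    return parts_scanned + estimate

def growth(steps, patched, num=100):
    # Hypothetical workload where buf.size == num at every step, so the
    # interpolated total is exactly 1.5 * partsScanned.
    x, seq = 2, []
    for _ in range(steps):
        x = next_parts_scanned(x, num, buf_size=num, patched=patched)
        seq.append(x)
    return seq

print(growth(4, patched=False))  # roughly 2.5x per step: [5, 12, 30, 75]
print(growth(4, patched=True))   # roughly 1.5x per step: [3, 4, 6, 9]
```

The unpatched rule satisfies `x_{n+1} = x_n + int(1.5 * x_n)` at this boundary, matching the `x_{n+1} >= (1.5 + 1) * x_n` claim in the commit message; the patched rule grows by roughly the intended 50% overestimate instead.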
1 parent 39ccaba commit 49bbdcb

File tree

3 files changed: +16 −9 lines changed

core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala

Lines changed: 7 additions & 5 deletions

@@ -78,16 +78,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
       // greater than totalParts because we actually cap it at totalParts in runJob.
       var numPartsToTry = 1
       if (partsScanned > 0) {
-        // If we didn't find any rows after the first iteration, just try all partitions next.
+        // If we didn't find any rows after the previous iteration, quadruple and retry.
         // Otherwise, interpolate the number of partitions we need to try, but overestimate it
-        // by 50%.
+        // by 50%. We also cap the estimation in the end.
         if (results.size == 0) {
-          numPartsToTry = totalParts - 1
+          numPartsToTry = partsScanned * 4
         } else {
-          numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
+          // the left side of max is >=1 whenever partsScanned >= 2
+          numPartsToTry = Math.max(1,
+            (1.5 * num * partsScanned / results.size).toInt - partsScanned)
+          numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
         }
       }
-      numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

       val left = num - results.size
       val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 5 additions & 3 deletions

@@ -1079,15 +1079,17 @@ abstract class RDD[T: ClassTag](
       // greater than totalParts because we actually cap it at totalParts in runJob.
       var numPartsToTry = 1
       if (partsScanned > 0) {
-        // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
+        // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
         // interpolate the number of partitions we need to try, but overestimate it by 50%.
+        // We also cap the estimation in the end.
         if (buf.size == 0) {
           numPartsToTry = partsScanned * 4
         } else {
-          numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+          // the left side of max is >=1 whenever partsScanned >= 2
+          numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
+          numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
         }
       }
-      numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

       val left = num - buf.size
       val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
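The `Math.max(..., 1)` clamp in this change matters right after the first iteration: with `partsScanned == 1` the interpolated increment can round down to zero, and without the clamp the loop would request an empty partition range and never make progress. A quick check of the rule, transcribed to Python for illustration (the numeric inputs are hypothetical):

```python
def increment(num, parts_scanned, buf_size):
    # The patched estimation rule from RDD.scala, transcribed to Python:
    # interpolate the total, subtract what was already scanned, then
    # clamp to [1, 4 * parts_scanned].
    return min(max(int(1.5 * num * parts_scanned / buf_size) - parts_scanned, 1),
               parts_scanned * 4)

# Edge case after one iteration: 4 of 5 requested rows found in the first
# partition gives int(1.5 * 5 * 1 / 4) - 1 == 0; without the max(..., 1)
# clamp the next scan would cover an empty partition range forever.
print(increment(num=5, parts_scanned=1, buf_size=4))   # -> 1
# For partsScanned >= 2 (and buf_size < num while the loop runs) the raw
# term is already >= 1, matching the comment in the diff.
print(increment(num=10, parts_scanned=2, buf_size=5))  # -> 4
```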

python/pyspark/rdd.py

Lines changed: 4 additions & 1 deletion

@@ -1070,10 +1070,13 @@ def take(self, num):
             # If we didn't find any rows after the previous iteration,
             # quadruple and retry. Otherwise, interpolate the number of
             # partitions we need to try, but overestimate it by 50%.
+            # We also cap the estimation in the end.
             if len(items) == 0:
                 numPartsToTry = partsScanned * 4
             else:
-                numPartsToTry = int(1.5 * num * partsScanned / len(items))
+                # the first parameter of max is >=1 whenever partsScanned >= 2
+                numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned
+                numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4)

             left = num - len(items)
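Putting the pieces together, the patched scan loop behaves like this single-machine sketch (the list-of-lists "partitions" and the function name are illustrative stand-ins, not PySpark's API; the real implementation dispatches the partition range through `runJob` and caps it at `totalParts` there):

```python
def take_sketch(partitions, num):
    """Mimic the patched take() scan loop over in-memory 'partitions'."""
    items, parts_scanned, total_parts = [], 0, len(partitions)
    while len(items) < num and parts_scanned < total_parts:
        num_parts_to_try = 1
        if parts_scanned > 0:
            if len(items) == 0:
                # Nothing found yet: quadruple and retry.
                num_parts_to_try = parts_scanned * 4
            else:
                # Interpolate, overestimate by 50%, add only the increment,
                # and cap it at 4x the partitions scanned so far.
                num_parts_to_try = int(1.5 * num * parts_scanned / len(items)) - parts_scanned
                num_parts_to_try = min(max(num_parts_to_try, 1), parts_scanned * 4)
        for part in partitions[parts_scanned:parts_scanned + num_parts_to_try]:
            items.extend(part[:num - len(items)])  # take only what's still needed
        parts_scanned += num_parts_to_try
    return items[:num]

twenty = [[i] for i in range(20)]   # 20 partitions, one row each
print(take_sketch(twenty, 5))       # [0, 1, 2, 3, 4]
sparse = [[], [], [], ['row']]      # rows only in the last partition
print(take_sketch(sparse, 1))       # ['row'], via the quadruple branch
```

With one row per partition, the second pass interpolates a need for 7 partitions but the 4x cap limits it to 4 more, which is exactly enough here; before the fix, the interpolated total would have been re-added on top of the partitions already scanned.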
