Skip to content

Commit 5f07e9c

Browse files
author
Andrew Or
committed
Remove more return statements from scopes
1 parent 5e388ea commit 5f07e9c

File tree

1 file changed

+21
-17
lines changed
  • core/src/main/scala/org/apache/spark/rdd

1 file changed

+21
-17
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 21 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -434,10 +434,11 @@ abstract class RDD[T: ClassTag](
434434
* @param seed seed for the random number generator
435435
* @return sample of specified size in an array
436436
*/
437+
// TODO: rewrite this without return statements so we can wrap it in a scope
437438
def takeSample(
438439
withReplacement: Boolean,
439440
num: Int,
440-
seed: Long = Utils.random.nextLong): Array[T] = withScope {
441+
seed: Long = Utils.random.nextLong): Array[T] = {
441442
val numStDev = 10.0
442443

443444
if (num < 0) {
@@ -1027,23 +1028,26 @@ abstract class RDD[T: ClassTag](
10271028
depth: Int = 2): U = withScope {
10281029
require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
10291030
if (partitions.length == 0) {
1030-
return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
1031-
}
1032-
val cleanSeqOp = context.clean(seqOp)
1033-
val cleanCombOp = context.clean(combOp)
1034-
val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
1035-
var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
1036-
var numPartitions = partiallyAggregated.partitions.length
1037-
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
1038-
// If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
1039-
while (numPartitions > scale + numPartitions / scale) {
1040-
numPartitions /= scale
1041-
val curNumPartitions = numPartitions
1042-
partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
1043-
iter.map((i % curNumPartitions, _))
1044-
}.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
1031+
Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
1032+
} else {
1033+
val cleanSeqOp = context.clean(seqOp)
1034+
val cleanCombOp = context.clean(combOp)
1035+
val aggregatePartition =
1036+
(it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
1037+
var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
1038+
var numPartitions = partiallyAggregated.partitions.length
1039+
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
1040+
// If creating an extra level doesn't help reduce
1041+
// the wall-clock time, we stop tree aggregation.
1042+
while (numPartitions > scale + numPartitions / scale) {
1043+
numPartitions /= scale
1044+
val curNumPartitions = numPartitions
1045+
partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
1046+
(i, iter) => iter.map((i % curNumPartitions, _))
1047+
}.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
1048+
}
1049+
partiallyAggregated.reduce(cleanCombOp)
10451050
}
1046-
partiallyAggregated.reduce(cleanCombOp)
10471051
}
10481052

10491053
/**

0 commit comments

Comments (0)