@@ -434,10 +434,11 @@ abstract class RDD[T: ClassTag](
434
434
* @param seed seed for the random number generator
435
435
* @return sample of specified size in an array
436
436
*/
437
+ // TODO: rewrite this without `return` statements so that it can be wrapped in `withScope` again
437
438
def takeSample (
438
439
withReplacement : Boolean ,
439
440
num : Int ,
440
- seed : Long = Utils .random.nextLong): Array [T ] = withScope {
441
+ seed : Long = Utils .random.nextLong): Array [T ] = {
441
442
val numStDev = 10.0
442
443
443
444
if (num < 0 ) {
@@ -1027,23 +1028,26 @@ abstract class RDD[T: ClassTag](
1027
1028
depth : Int = 2 ): U = withScope {
1028
1029
require(depth >= 1 , s " Depth must be greater than or equal to 1 but got $depth. " )
1029
1030
if (partitions.length == 0 ) {
1030
- return Utils .clone(zeroValue, context.env.closureSerializer.newInstance())
1031
- }
1032
- val cleanSeqOp = context.clean(seqOp)
1033
- val cleanCombOp = context.clean(combOp)
1034
- val aggregatePartition = (it : Iterator [T ]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
1035
- var partiallyAggregated = mapPartitions(it => Iterator (aggregatePartition(it)))
1036
- var numPartitions = partiallyAggregated.partitions.length
1037
- val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2 )
1038
- // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
1039
- while (numPartitions > scale + numPartitions / scale) {
1040
- numPartitions /= scale
1041
- val curNumPartitions = numPartitions
1042
- partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
1043
- iter.map((i % curNumPartitions, _))
1044
- }.reduceByKey(new HashPartitioner (curNumPartitions), cleanCombOp).values
1031
+ Utils .clone(zeroValue, context.env.closureSerializer.newInstance())
1032
+ } else {
1033
+ val cleanSeqOp = context.clean(seqOp)
1034
+ val cleanCombOp = context.clean(combOp)
1035
+ val aggregatePartition =
1036
+ (it : Iterator [T ]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
1037
+ var partiallyAggregated = mapPartitions(it => Iterator (aggregatePartition(it)))
1038
+ var numPartitions = partiallyAggregated.partitions.length
1039
+ val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2 )
1040
+ // If creating an extra level doesn't help reduce
1041
+ // the wall-clock time, we stop tree aggregation.
1042
+ while (numPartitions > scale + numPartitions / scale) {
1043
+ numPartitions /= scale
1044
+ val curNumPartitions = numPartitions
1045
+ partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
1046
+ (i, iter) => iter.map((i % curNumPartitions, _))
1047
+ }.reduceByKey(new HashPartitioner (curNumPartitions), cleanCombOp).values
1048
+ }
1049
+ partiallyAggregated.reduce(cleanCombOp)
1045
1050
}
1046
- partiallyAggregated.reduce(cleanCombOp)
1047
1051
}
1048
1052
1049
1053
/**
0 commit comments