@@ -54,7 +54,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
    * @see [[org.apache.spark.rdd.RDD#reduce]]
    */
   def treeReduce(f: (T, T) => T, depth: Int): T = {
-    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     val cleanF = self.context.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
       if (iter.hasNext) {
@@ -63,7 +63,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
         None
       }
     }
-    val local = self.mapPartitions(it => Iterator(reducePartition(it)))
+    val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
     val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
       if (c.isDefined && x.isDefined) {
         Some(cleanF(c.get, x.get))
@@ -75,7 +75,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
         None
       }
     }
-    RDDFunctions.fromRDD(local).treeAggregate(Option.empty[T])(op, op, depth)
+    RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
       .getOrElse(throw new UnsupportedOperationException("empty collection"))
   }
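treeReduce first reduces inside each partition, then hands the per-partition Options to treeAggregate so partial results are merged level by level rather than all at once on the driver. As a rough illustration of how the method is called (a minimal sketch, assuming a local Spark context and the mllib implicit conversion; the object name, sample data, and depth are illustrative, not from this patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.rdd.RDDFunctions._ // implicit fromRDD enables rdd.treeReduce

object TreeReduceSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("treeReduceSketch").setMaster("local[4]"))
    val rdd = sc.parallelize(1 to 1000, numSlices = 64)
    // depth = 2 inserts one intermediate level of partial merges before the final reduce.
    val sum = rdd.treeReduce(_ + _, depth = 2)
    println(sum) // 500500
    sc.stop()
  }
}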
@@ -85,26 +85,28 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
    * @see [[org.apache.spark.rdd.RDD#aggregate]]
    */
   def treeAggregate[U: ClassTag](zeroValue: U)(
-    seqOp: (U, T) => U,
-    combOp: (U, U) => U,
-    depth: Int): U = {
-    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     if (self.partitions.size == 0) {
       return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
     }
     val cleanSeqOp = self.context.clean(seqOp)
     val cleanCombOp = self.context.clean(combOp)
     val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
-    var local = self.mapPartitions(it => Iterator(aggregatePartition(it)))
-    var numPartitions = local.partitions.size
+    var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
     val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
     while (numPartitions > scale + numPartitions / scale) {
       numPartitions /= scale
-      local = local.mapPartitionsWithIndex { (i, iter) =>
-        iter.map((i % numPartitions, _))
-      }.reduceByKey(new HashPartitioner(numPartitions), cleanCombOp).values
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
     }
-    local.reduce(cleanCombOp)
+    partiallyAggregated.reduce(cleanCombOp)
   }
 }
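Besides the rename, the loop body now copies the mutable numPartitions into a per-iteration val curNumPartitions before the closure captures it. The closure passed to mapPartitionsWithIndex runs only when an action finally evaluates the lazily built RDD chain, so capturing the var directly means the closure sees its final value at job-submission time, which no longer matches the HashPartitioner that was constructed eagerly with the per-iteration value. A hypothetical mini-demo of that capture pitfall in plain Scala (not from the patch):

var n = 4
val thunks = Seq.newBuilder[() => Int]
while (n > 1) {
  n /= 2
  thunks += (() => n) // buggy: every thunk shares the same var
}
println(thunks.result().map(_())) // List(1, 1) -- with a per-iteration val copy, it would be List(2, 1)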
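With the fix in place, a call like the following exercises the multi-level merge (a minimal sketch, reusing the sc from the earlier sketch; the zeroValue, seqOp, and combOp are illustrative):

import org.apache.spark.mllib.rdd.RDDFunctions._ // implicit fromRDD enables rdd.treeAggregate

val rdd = sc.parallelize(1 to 1000, numSlices = 64)
// One pass computing (sum, count); combOp merges partial pairs up the tree
// instead of bringing all 64 partition results straight to the driver.
val (sum, count) = rdd.treeAggregate((0L, 0L))(
  seqOp = (acc, x) => (acc._1 + x, acc._2 + 1),
  combOp = (a, b) => (a._1 + b._1, a._2 + b._2),
  depth = 2)
println(sum.toDouble / count) // 500.5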