Skip to content

Commit fe42a5e

Browse files
committed
add treeAggregate
1 parent 23a12ce commit fe42a5e

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,30 @@ abstract class RDD[T: ClassTag](
862862
jobResult
863863
}
864864

865+
/**
 * Aggregates the elements of this RDD in a multi-level tree pattern.
 *
 * Every partition is first folded down to a single value with `seqOp`/`combOp`, starting
 * from `zeroValue`. While the number of partial results is still large relative to the
 * computed fan-in (`scale`), the partial results are keyed by partition index, merged with
 * `combOp` through `reduceByKey` into fewer partitions, and the loop repeats. Whatever
 * remains is combined with a final `reduce`.
 *
 * @param zeroValue initial value for the per-partition aggregation; when this RDD has no
 *                  partitions, a copy made with the closure serializer is returned instead
 * @param seqOp     folds one element of type `T` into an accumulator of type `U`
 * @param combOp    merges two accumulators of type `U`
 * @param level     controls the fan-in: about `numPartitions ^ (1 / level)` partitions (at
 *                  least 2) are merged per round; must be greater than or equal to 1
 * @return the value obtained by combining all per-partition results with `combOp`
 * @throws IllegalArgumentException if `level` is less than 1
 */
@DeveloperApi
def treeAggregate[U: ClassTag](zeroValue: U)(
    seqOp: (U, T) => U,
    combOp: (U, U) => U,
    level: Int): U = {
  require(level >= 1, s"Level must be greater than or equal to 1 but got $level.")
  if (this.partitions.size == 0) {
    Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
  } else {
    val cleanSeqOp = sc.clean(seqOp)
    val cleanCombOp = sc.clean(combOp)
    val aggregatePartition =
      (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
    var local = this.mapPartitions(it => Iterator(aggregatePartition(it)))
    var numPartitions = local.partitions.size
    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / level)).toInt, 2)
    // Stop adding merge rounds once another round would not reduce the work of the final reduce.
    while (numPartitions > scale + numPartitions / scale) {
      numPartitions /= scale
      // Snapshot the current count in a val: the closure below runs lazily, after the loop
      // has finished mutating `numPartitions`, so capturing the var would make every level
      // key by the final value instead of its own.
      val curNumPartitions = numPartitions
      local = local.mapPartitionsWithIndex { (i, iter) =>
        iter.map((i % curNumPartitions, _))
      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
    }
    local.reduce(cleanCombOp)
  }
}
888+
865889
/**
866890
* Return the number of elements in the RDD.
867891
*/

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,4 +769,14 @@ class RDDSuite extends FunSuite with SharedSparkContext {
769769
mutableDependencies += dep
770770
}
771771
}
772+
773+
test("treeAggregate") {
  // The integers in [-1000, 1000) spread over 10 partitions; they sum to -1000.
  val numbers = sc.makeRDD(-1000 until 1000, 10)
  val addElement: (Long, Int) => Long = (acc, x) => acc + x
  val mergeSums: (Long, Long) => Long = (left, right) => left + right
  // The aggregated value must be the same for every tree level from 1 to 9.
  (1 until 10).foreach { level =>
    val total = numbers.treeAggregate(0L)(addElement, mergeSums, level)
    assert(total === -1000L)
  }
}
772782
}

0 commit comments

Comments
 (0)