
Commit b04b96a

address comments
1 parent 9bcc5d3 commit b04b96a

2 files changed: +18 -16 lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 3 additions & 3 deletions
@@ -80,7 +80,7 @@ class RowMatrix(
   private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = {
     val n = numCols().toInt
     val vbr = rows.context.broadcast(v)
-    rows.aggregate(BDV.zeros[Double](n))(
+    rows.treeAggregate(BDV.zeros[Double](n))(
       seqOp = (U, r) => {
         val rBrz = r.toBreeze
         val a = rBrz.dot(vbr.value)
@@ -93,8 +93,8 @@ class RowMatrix(
         }
         U
       },
-      combOp = (U1, U2) => U1 += U2
-    )
+      combOp = (U1, U2) => U1 += U2,
+      depth = 2)
   }
 
   /**
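With plain aggregate, the driver alone combines one n-dimensional partial result per partition; treeAggregate with depth = 2 first merges those partials on the executors, so the driver only combines a handful of vectors. A minimal usage sketch of the same seqOp/combOp/depth call shape, assuming a live SparkContext named sc (the data and partition count are hypothetical, not part of this commit):

import org.apache.spark.mllib.rdd.RDDFunctions._

// Sketch only: sum of squares with a two-level combine, mirroring the
// call shape multiplyGramianMatrixBy now uses.
val data = sc.parallelize(1 to 1000000, 100)
val sumSq = data.treeAggregate(0.0)(
  seqOp = (acc, x) => acc + x.toDouble * x, // fold each partition locally
  combOp = (a, b) => a + b,                 // merge partials on executors
  depth = 2)                                // same depth the diff hardcodes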

mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala

Lines changed: 15 additions & 13 deletions
@@ -54,7 +54,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
    * @see [[org.apache.spark.rdd.RDD#reduce]]
    */
   def treeReduce(f: (T, T) => T, depth: Int): T = {
-    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     val cleanF = self.context.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
       if (iter.hasNext) {
@@ -63,7 +63,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
         None
       }
     }
-    val local = self.mapPartitions(it => Iterator(reducePartition(it)))
+    val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
     val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
       if (c.isDefined && x.isDefined) {
         Some(cleanF(c.get, x.get))
@@ -75,7 +75,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
         None
       }
     }
-    RDDFunctions.fromRDD(local).treeAggregate(Option.empty[T])(op, op, depth)
+    RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
       .getOrElse(throw new UnsupportedOperationException("empty collection"))
   }
 
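treeReduce works without a zero value because each partition is first reduced to an Option[T] (None for an empty partition), and treeAggregate then merges the Options. A usage sketch, again assuming a live SparkContext named sc (hypothetical data):

import org.apache.spark.mllib.rdd.RDDFunctions._

// Sketch only: max over 64 partitions, merged through a two-level tree.
val xs = sc.parallelize(1 to 1000, 64)
val m = xs.treeReduce((a, b) => math.max(a, b), depth = 2) // returns 1000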
@@ -85,26 +85,28 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
    * @see [[org.apache.spark.rdd.RDD#aggregate]]
    */
   def treeAggregate[U: ClassTag](zeroValue: U)(
-    seqOp: (U, T) => U,
-    combOp: (U, U) => U,
-    depth: Int): U = {
-    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     if (self.partitions.size == 0) {
       return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
     }
     val cleanSeqOp = self.context.clean(seqOp)
     val cleanCombOp = self.context.clean(combOp)
     val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
-    var local = self.mapPartitions(it => Iterator(aggregatePartition(it)))
-    var numPartitions = local.partitions.size
+    var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
     val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
     while (numPartitions > scale + numPartitions / scale) {
       numPartitions /= scale
-      local = local.mapPartitionsWithIndex { (i, iter) =>
-        iter.map((i % numPartitions, _))
-      }.reduceByKey(new HashPartitioner(numPartitions), cleanCombOp).values
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
     }
-    local.reduce(cleanCombOp)
+    partiallyAggregated.reduce(cleanCombOp)
   }
 }
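The comment added above marks the stopping rule: another level is built only while numPartitions > scale + numPartitions / scale, i.e. while one more shuffle round still shrinks the number of partial results enough to pay for itself. A worked sketch of the level math (the input size is chosen purely for illustration):

// With 100 input partitions and depth = 2:
val depth = 2
var numPartitions = 100
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // 10
// Round 1: 100 > 10 + 100/10 = 20, so aggregate 100 partials down to 10.
// Round 2: 10 <= 10 + 10/10 = 11, so stop and reduce the last 10 at the driver.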

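The other fix, val curNumPartitions = numPartitions, matters because Scala closures capture variables rather than values, and RDD transformations are lazy: each mapPartitionsWithIndex closure would otherwise read numPartitions only when the final reduce triggers the job, by which point the loop has already divided it down to its last value, collapsing every tree level to the same key space. A self-contained sketch of the pitfall (plain Scala, not Spark code):

var n = 8
val fs = Seq.newBuilder[() => Int]
for (_ <- 0 until 2) {
  n /= 2
  val cur = n       // stable per-iteration copy, like curNumPartitions
  fs += (() => n)   // captures the variable: sees n's final value when run
  fs += (() => cur) // captures the copy: sees the value at this iteration
}
println(fs.result().map(_())) // List(2, 4, 2, 2): both n-closures report 2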