
Commit 0214a76

cleanUp
Addressed reviewer comments and added better documentation of the code. Added commons-math3 as a dependency of spark (okay'ed by Matei). "mvn clean install" compiles. Recovered files that were accidentally reverted in the merge. TODOs: figure out the API for sampleByKeyExact and update the Java, Python, and markdown files accordingly.
1 parent: 90d94c0 · commit: 0214a76

File tree: 5 files changed, +129 -20 lines

core/pom.xml
core/src/main/scala/org/apache/spark/rdd/RDD.scala
core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala
core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

core/pom.xml

Lines changed: 0 additions & 2 deletions
@@ -70,8 +70,6 @@
     <dependency>
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-math3</artifactId>
-      <version>3.3</version>
-      <scope>test</scope>
     </dependency>
     <dependency>
       <groupId>com.google.code.findbugs</groupId>

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 21 additions & 0 deletions
@@ -874,6 +874,27 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
+  /**
+   * A version of {@link #aggregate()} that passes the TaskContext to the function that does
+   * aggregation for each partition.
+   */
+  def aggregateWithContext[U: ClassTag](zeroValue: U)(seqOp: ((TaskContext, U), T) => U,
+      combOp: (U, U) => U): U = {
+    // Clone the zero value since we will also be serializing it as part of tasks
+    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
+    // pad seqOp and combOp with the TaskContext to conform to aggregate's signature in TraversableOnce
+    val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
+    val paddedCombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
+      (arg1._1, combOp(arg1._2, arg2._2))
+    val cleanSeqOp = sc.clean(paddedSeqOp)
+    val cleanCombOp = sc.clean(paddedCombOp)
+    val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
+      it.aggregate((tc, zeroValue))(cleanSeqOp, cleanCombOp)._2
+    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+    sc.runJob(this, aggregatePartition, mergeResult)
+    jobResult
+  }
+
   /**
    * Return the number of elements in the RDD.
    */
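
A usage sketch of the new aggregateWithContext (illustrative only; the SparkContext `sc` and the toy data are assumed and are not part of this commit):

  import org.apache.spark.TaskContext

  // Sum the values of a pair RDD while also adding the id of the partition that
  // processed each element, taken from the TaskContext handed to the seq function.
  val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)), 2)
  val sum = pairs.aggregateWithContext(0)(
    (ctxAndAcc: (TaskContext, Int), kv: (String, Int)) =>
      ctxAndAcc._2 + kv._2 + ctxAndAcc._1.partitionId,  // seqOp sees (TaskContext, U)
    (a: Int, b: Int) => a + b)                          // combOp merges per-partition results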

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 31 additions & 5 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.util.random
 
-import org.apache.commons.math3.distribution.{PoissonDistribution, NormalDistribution}
+import org.apache.commons.math3.distribution.PoissonDistribution
 
 private[spark] object SamplingUtils {
 
@@ -43,7 +43,7 @@ private[spark] object SamplingUtils {
   * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
   */
  def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
-    withReplacement: Boolean): Double = {
+      withReplacement: Boolean): Double = {
    val fraction = sampleSizeLowerBound.toDouble / total
    if (withReplacement) {
      val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
@@ -56,12 +56,29 @@ private[spark] object SamplingUtils {
   }
 }
 
+/**
+ * Utility functions that help us determine bounds on the adjusted sampling rate to guarantee
+ * exact sample sizes with high confidence when sampling with replacement.
+ *
+ * The algorithm for guaranteeing sample size instantly accepts items whose associated value drawn
+ * from Pois(s) is less than the lower bound and puts items whose value is between the lower and
+ * upper bound on a waitlist. The final sample consists of all items accepted on the fly plus the
+ * portion of the waitlist needed to make up the exact sample size.
+ */
 private[spark] object PoissonBounds {
 
   val delta = 1e-4 / 3.0
-  val phi = new NormalDistribution().cumulativeProbability(1.0 - delta)
 
-  def getLambda1(s: Double): Double = {
+  /**
+   * Compute the threshold for accepting items on the fly. The threshold is a fairly small
+   * number, which means that an item with an associated value below it is highly likely to end
+   * up in the final sample. Hence we instantly accept items whose values are less than the value
+   * returned by this function.
+   *
+   * @param s sample size
+   * @return threshold for accepting items on the fly
+   */
+  def getLowerBound(s: Double): Double = {
     var lb = math.max(0.0, s - math.sqrt(s / delta)) // Chebyshev's inequality
     var ub = s
     while (lb < ub - 1.0) {
@@ -79,7 +96,16 @@ private[spark] object PoissonBounds {
     poisson.inverseCumulativeProbability(delta)
   }
 
-  def getLambda2(s: Double): Double = {
+  /**
+   * Compute the threshold for waitlisting items. An item is waitlisted if its associated value
+   * is greater than the lower bound determined above but below the upper bound computed here.
+   * The value is computed so that we only need to keep log(s) items on the waitlist and can
+   * still guarantee the sample size with high confidence.
+   *
+   * @param s sample size
+   * @return threshold for waitlisting items
+   */
+  def getUpperBound(s: Double): Double = {
     var lb = s
     var ub = s + math.sqrt(s / delta) // Chebyshev's inequality
     while (lb < ub - 1.0) {
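
To make the role of the two bounds concrete, here is a small driver-side sketch mirroring the getSeqOp logic in StratifiedSampler below. The numbers are made up, and PoissonBounds is a private[spark] helper, so this is illustrative rather than user-facing code:

  // Deriving per-stratum acceptance and waitlist rates (values are hypothetical).
  val n = 1000000L                              // items in this stratum
  val fraction = 0.01                           // requested sampling fraction
  val s = math.ceil(n * fraction).toLong        // target sample size
  val lmbd1 = PoissonBounds.getLowerBound(s)    // mean of the "accept on the fly" Poisson
  val minCount = PoissonBounds.getMinCount(lmbd1)
  val lmbd2 =
    if (lmbd1 == 0) PoissonBounds.getUpperBound(s) else PoissonBounds.getUpperBound(s - minCount)
  val q1 = lmbd1 / n                            // draws below q1 are accepted immediately
  val q2 = lmbd2 / n                            // draws between q1 and q2 are waitlisted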

core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala

Lines changed: 45 additions & 13 deletions
@@ -26,6 +26,9 @@ import scala.Some
 import org.apache.spark.rdd.RDD
 
 private[spark] object StratifiedSampler extends Logging {
+  /**
+   * Returns the function used by aggregate to collect sampling statistics for each partition.
+   */
   def getSeqOp[K, V](withReplacement: Boolean,
       fractionByKey: (K => Double),
       counts: Option[Map[K, Long]]): ((TaskContext, Result[K]),(K, V)) => Result[K] = {
@@ -43,9 +46,9 @@ private[spark] object StratifiedSampler extends Logging {
         if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
           val n = counts.get(item._1)
           val s = math.ceil(n * fraction).toLong
-          val lmbd1 = PB.getLambda1(s)
+          val lmbd1 = PB.getLowerBound(s)
           val minCount = PB.getMinCount(lmbd1)
-          val lmbd2 = if (lmbd1 == 0) PB.getLambda2(s) else PB.getLambda2(s - minCount)
+          val lmbd2 = if (lmbd1 == 0) PB.getUpperBound(s) else PB.getUpperBound(s - minCount)
           val q1 = lmbd1 / n
           val q2 = lmbd2 / n
           stratum.q1 = Some(q1)
@@ -60,6 +63,8 @@ private[spark] object StratifiedSampler extends Logging {
             stratum.addToWaitList(ArrayBuffer.fill(x2)(rng.nextUniform(0.0, 1.0)))
           }
         } else {
+          // We use the streaming version of the algorithm for sampling without replacement.
+          // Hence, q1 and q2 change on every iteration.
           val g1 = - math.log(delta) / stratum.numItems
           val g2 = (2.0 / 3.0) * g1
           val q1 = math.max(0, fraction + g2 - math.sqrt((g2 * g2 + 3 * g2 * fraction)))
@@ -79,7 +84,11 @@ private[spark] object StratifiedSampler extends Logging {
     }
   }
 
-  def getCombOp[K](): (Result[K], Result[K]) => Result[K] = {
+  /**
+   * Returns the function used by aggregate to combine results from different partitions, as
+   * returned by seqOp.
+   */
+  def getCombOp[K](): (Result[K], Result[K]) => Result[K] = {
     (r1: Result[K], r2: Result[K]) => {
       // take union of both key sets in case one partition doesn't contain all keys
       val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)
@@ -100,6 +109,10 @@ private[spark] object StratifiedSampler extends Logging {
     }
   }
 
+  /**
+   * Given the result returned by the aggregate function, determines the threshold used to
+   * accept items so that the exact sample size is generated.
+   */
   def computeThresholdByKey[K](finalResult: Map[K, Stratum], fractionByKey: (K => Double)):
       (K => Double) = {
     val thresholdByKey = new mutable.HashMap[K, Double]()
@@ -122,11 +135,15 @@ private[spark] object StratifiedSampler extends Logging {
     thresholdByKey
   }
 
-  def computeThresholdByKey[K](finalResult: Map[K, String]): (K => String) = {
-    finalResult
-  }
-
-  def getBernoulliSamplingFunction[K, V](rdd:RDD[(K, V)],
+  /**
+   * Returns the per-partition sampling function used for sampling without replacement.
+   *
+   * When an exact sample size is required, we make an additional pass over the RDD to determine
+   * the exact sampling rate that guarantees the sample size with high confidence.
+   *
+   * The sampling function has a unique seed per partition.
+   */
+  def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)],
       fractionByKey: K => Double,
       exact: Boolean,
       seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = {
@@ -146,6 +163,16 @@ private[spark] object StratifiedSampler extends Logging {
     }
   }
 
+  /**
+   * Returns the per-partition sampling function used for sampling with replacement.
+   *
+   * When an exact sample size is required, we make two additional passes over the RDD to
+   * determine the exact sampling rate that guarantees the sample size with high confidence. The
+   * first pass counts the number of items in each stratum (group of items with the same key) in
+   * the RDD, and the second pass uses the counts to determine the exact sampling rates.
+   *
+   * The sampling function has a unique seed per partition.
+   */
   def getPoissonSamplingFunction[K, V](rdd:RDD[(K, V)],
       fractionByKey: K => Double,
       exact: Boolean,
@@ -191,6 +218,10 @@ private[spark] object StratifiedSampler extends Logging {
   }
 }
 
+/**
+ * Object used by seqOp to keep track of the number of items accepted and items waitlisted per
+ * stratum, as well as the bounds for accepting and waitlisting items.
+ */
 private[random] class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0L)
   extends Serializable {
 
@@ -205,13 +236,14 @@ private[random] class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0
   def addToWaitList(elem: Double) = waitList += elem
 
   def addToWaitList(elems: ArrayBuffer[Double]) = waitList ++= elems
-
-  override def toString() = {
-    "numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
-      " waitListSize:" + waitList.size
-  }
 }
 
+/**
+ * Object used by seqOp and combOp to keep track of the sampling statistics for all strata.
+ *
+ * When used by seqOp for each partition, we also keep track of the partition ID in this object
+ * to make sure a single random number generator with a unique seed is used for each partition.
+ */
 private[random] class Result[K](var resultMap: Map[K, Stratum],
     var cachedPartitionId: Option[Int] = None,
     val seed: Long)
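
The doc comments above stress that each per-partition sampling function uses a unique seed. A minimal, self-contained sketch of that design choice (not the commit's implementation, and without the exact-sample-size machinery) might look like:

  import java.util.Random
  import org.apache.spark.rdd.RDD

  // Simple per-key Bernoulli sampling: every partition gets its own RNG seeded from a base
  // seed plus the partition index, so the sample is reproducible across reruns.
  def bernoulliByKey[K, V](rdd: RDD[(K, V)], fractionByKey: K => Double, seed: Long)
    : RDD[(K, V)] = {
    rdd.mapPartitionsWithIndex({ (idx, iter) =>
      val rng = new Random(seed + idx)          // unique, deterministic seed per partition
      iter.filter { case (k, _) => rng.nextDouble() < fractionByKey(k) }
    }, preservesPartitioning = true)
  }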

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 32 additions & 0 deletions
@@ -141,6 +141,38 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
   }
 
+  test("aggregateWithContext") {
+    val data = Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))
+    val numPartitions = 2
+    val pairs = sc.makeRDD(data, numPartitions)
+    // determine the partitionId for each pair
+    type StringMap = HashMap[String, Int]
+    val partitions = pairs.collectPartitions()
+    val offSets = new StringMap
+    for (i <- 0 to numPartitions - 1) {
+      partitions(i).foreach({ case (k, v) => offSets.put(k, offSets.getOrElse(k, 0) + i) })
+    }
+    val emptyMap = new StringMap {
+      override def default(key: String): Int = 0
+    }
+    val mergeElement: ((TaskContext, StringMap), (String, Int)) => StringMap = (arg1, pair) => {
+      val stringMap = arg1._2
+      val tc = arg1._1
+      stringMap(pair._1) += pair._2 + tc.partitionId
+      stringMap
+    }
+    val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
+      for ((key, value) <- map2) {
+        map1(key) += value
+      }
+      map1
+    }
+    val result = pairs.aggregateWithContext(emptyMap)(mergeElement, mergeMaps)
+    val expected = Set(("a", 6), ("b", 2), ("c", 5))
+      .map({ case (k, v) => k -> (offSets.getOrElse(k, 0) + v) })
+    assert(result.toSet === expected)
+  }
+
   test("basic caching") {
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))
