
Commit f4c21f3

Reviewer comments
Added BernoulliBounds
1 parent a10e68d commit f4c21f3


5 files changed: +152 −107 lines changed


core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 38 additions & 3 deletions
@@ -131,16 +131,35 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
 
   /**
    * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * If `exact` is set to false, create the sample via simple random sampling, with one pass
+   * over the RDD, to produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
+   * the RDD to create a sample size that's exactly equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values.
    */
  def sampleByKey(withReplacement: Boolean,
      fractions: JMap[K, Double],
      exact: Boolean,
      seed: Long): JavaPairRDD[K, V] =
    new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))
 
-
  /**
   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * If `exact` is set to false, create the sample via simple random sampling, with one pass
+   * over the RDD, to produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
+   * the RDD to create a sample size that's exactly equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values.
+   *
+   * Use Utils.random.nextLong as the default seed for the random number generator
   */
  def sampleByKey(withReplacement: Boolean,
      fractions: JMap[K, Double],
@@ -149,17 +168,33 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
 
  /**
   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * Produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
+   * simple random sampling.
   */
  def sampleByKey(withReplacement: Boolean,
      fractions: JMap[K, Double],
      seed: Long): JavaPairRDD[K, V] =
-    sampleByKey(withReplacement, fractions, true, seed)
+    sampleByKey(withReplacement, fractions, false, seed)
 
  /**
   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * Produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
+   * simple random sampling.
+   *
+   * Use Utils.random.nextLong as the default seed for the random number generator
   */
  def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
-    sampleByKey(withReplacement, fractions, true, Utils.random.nextLong)
+    sampleByKey(withReplacement, fractions, false, Utils.random.nextLong)
 
  /**
   * Return the union of this RDD and another one. Any identical elements will appear multiple
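
For orientation, here is a minimal usage sketch of these Java-API overloads, written in Scala. It is not part of the commit: `sc` is assumed to be an existing SparkContext, and the data and sampling rates are made up. After this change, the overloads that take no `exact` argument delegate with `exact = false`, i.e. one-pass approximate sampling.

import scala.collection.JavaConverters._
import org.apache.spark.api.java.JavaPairRDD

// Illustrative pair RDD; `sc` is an existing SparkContext (assumption).
val pairs: JavaPairRDD[String, Int] =
  JavaPairRDD.fromRDD(sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4))))

// Per-key sampling rates, passed as the java.util.Map the Java API expects.
val fractions: java.util.Map[String, Double] = Map("a" -> 0.5, "b" -> 0.25).asJava

// Two-argument overload: now delegates with exact = false (one pass, approximate per-key sizes).
val approx = pairs.sampleByKey(false, fractions)

// Full overload: exact per-key sample sizes, at the cost of extra passes over the RDD.
val exactSample = pairs.sampleByKey(false, fractions, true, 42L)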

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 17 additions & 13 deletions
@@ -43,7 +43,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.Utils
-import org.apache.spark.util.random.StratifiedSampler
+import org.apache.spark.util.random.StratifiedSamplingUtils
 
 /**
  * Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -195,32 +195,36 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
  /**
   * Return a subset of this RDD sampled by key (via stratified sampling).
   *
-   * If exact set to true, we guarantee, with high probability, a sample size =
-   * math.ceil(fraction * S_i), where S_i is the size of the ith stratum (collection of entries
-   * that share the same key). When sampling without replacement, we need one additional pass over
-   * the RDD to guarantee sample size with a 99.99% confidence; when sampling with replacement, we
-   * need two additional passes.
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * If `exact` is set to false, create the sample via simple random sampling, with one pass
+   * over the RDD, to produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values; otherwise, use
+   * additional passes over the RDD to create a sample size that's exactly equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling
+   * without replacement, we need one additional pass over the RDD to guarantee sample size;
+   * when sampling with replacement, we need two additional passes.
   *
   * @param withReplacement whether to sample with or without replacement
   * @param fractions map of specific keys to sampling rates
   * @param seed seed for the random number generator
-   * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per stratum
+   * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key
   * @return RDD containing the sampled subset
   */
  def sampleByKey(withReplacement: Boolean,
      fractions: Map[K, Double],
-      exact: Boolean = true,
+      exact: Boolean = false,
      seed: Long = Utils.random.nextLong): RDD[(K, V)]= {
 
-    require(fractions.forall {case(k, v) => v >= 0.0}, "Invalid sampling rates.")
+    require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
 
    val samplingFunc = if (withReplacement) {
-      val counts = if (exact) Some(this.countByKey()) else None
-      StratifiedSampler.getPoissonSamplingFunction(self, fractions, exact, counts, seed)
+      StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed)
    } else {
-      StratifiedSampler.getBernoulliSamplingFunction(self, fractions, exact, seed)
+      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed)
    }
-    self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning=true)
+    self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
  }
 
  /**
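
A short usage sketch (again not from the diff) of the Scala-side sampleByKey after this change, which is what the Java wrappers above delegate to. It assumes an existing SparkContext `sc` and the pair-RDD implicit conversion in scope; the data and rates are illustrative.

import org.apache.spark.SparkContext._  // brings the pair-RDD implicit conversion into scope

val data = sc.parallelize(Seq(("a", 1), ("a", 2), ("a", 3), ("b", 4), ("b", 5)))
val fractions = Map("a" -> 0.5, "b" -> 0.2)

// New default (exact = false): a single pass, per-key sizes only approximately
// math.ceil(numItems * samplingRate).
val approx = data.sampleByKey(withReplacement = false, fractions = fractions)

// exact = true: per-key sizes hit math.ceil(numItems * samplingRate) with 99.99% confidence,
// paid for with one extra pass (two extra passes when sampling with replacement).
val exactSample = data.sampleByKey(
  withReplacement = false, fractions = fractions, exact = true, seed = 17L)

approx.countByKey()       // roughly ceil(3 * 0.5) = 2 for "a" and ceil(2 * 0.2) = 1 for "b"
exactSample.countByKey()  // exactly those counts, with high probability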

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 19 additions & 4 deletions
@@ -52,10 +52,7 @@ private[spark] object SamplingUtils {
      val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
      math.max(1e-10, fraction + numStDev * math.sqrt(fraction / total))
    } else {
-      val delta = 1e-4
-      val gamma = - math.log(delta) / total
-      math.min(1,
-        math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
+      BernoulliBounds.getLowerBound(1e-4, total, fraction)
    }
  }
 }
@@ -125,3 +122,21 @@ private[spark] object PoissonBounds {
    ub
  }
 }
+
+
+private[spark] object BernoulliBounds {
+
+  val minSamplingRate = 1e-10
+
+  def getUpperBound(delta: Double, n: Long, fraction: Double): Double = {
+    val gamma = - math.log(delta) / n * (2.0 / 3.0)
+    math.max(minSamplingRate,
+      fraction + gamma - math.sqrt(gamma * gamma + 3 * gamma * fraction))
+  }
+
+  def getLowerBound(delta: Double, n: Long, fraction: Double): Double = {
+    val gamma = - math.log(delta) / n
+    math.min(1,
+      math.max(minSamplingRate, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
+  }
+}
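
Reading of the new object (mine, not stated in the commit): getLowerBound(delta, n, fraction) inflates the Bernoulli sampling rate just enough that a sample over n items contains at least math.ceil(fraction * n) elements with probability about 1 - delta. The returned value is the larger root of p^2 - 2*p*(fraction + gamma) + fraction^2 = 0 with gamma = -ln(delta) / n, which is exactly what a standard Chernoff lower-tail bound yields; getUpperBound appears to play the symmetric role for the upper tail, with a slightly different gamma. Below is a self-contained sketch that replicates getLowerBound and sanity-checks it numerically; it mirrors the new code but does not import Spark, and all names and numbers are illustrative.

import scala.util.Random

val minSamplingRate = 1e-10

// Same formula as the new BernoulliBounds.getLowerBound.
def getLowerBound(delta: Double, n: Long, fraction: Double): Double = {
  val gamma = -math.log(delta) / n
  math.min(1,
    math.max(minSamplingRate, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
}

val (delta, n, fraction) = (1e-4, 10000L, 0.1)
val p = getLowerBound(delta, n, fraction)      // inflated rate, slightly above 0.1
val target = math.ceil(fraction * n).toLong    // required sample size: 1000

// Sample n items at rate p many times and count how often we fall short of the target.
val rng = new Random(7)
val shortfalls = (1 to 1000).count { _ =>
  (1L to n).count(_ => rng.nextDouble() < p) < target
}
println(s"inflated rate = $p, trials below target = $shortfalls out of 1000")
// With delta = 1e-4, the shortfall count should essentially always be 0.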
