Skip to content

Commit eaf5771

Browse files
committed
bug fixes.
1 parent 17a381b commit eaf5771

File tree

3 files changed

+19
-23
lines changed

3 files changed

+19
-23
lines changed

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ private[spark] object SamplingUtils {
9191
*/
9292
def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
9393
withReplacement: Boolean): Double = {
94-
val fraction = sampleSizeLowerBound.toDouble / total
9594
if (withReplacement) {
96-
PoissonBounds.getUpperBound(sampleSizeLowerBound)
95+
PoissonBounds.getUpperBound(sampleSizeLowerBound) / total
9796
} else {
98-
BernoulliBounds.getLowerBound(1e-4, total, fraction)
97+
val fraction = sampleSizeLowerBound.toDouble / total
98+
BinomialBounds.getUpperBound(1e-4, total, fraction)
9999
}
100100
}
101101
}
@@ -138,25 +138,25 @@ private[spark] object PoissonBounds {
138138
* Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact
139139
* sample size with high confidence when sampling without replacement.
140140
*/
141-
private[spark] object BernoulliBounds {
141+
private[spark] object BinomialBounds {
142142

143143
val minSamplingRate = 1e-10
144144

145145
/**
146-
* Returns a threshold such that if we apply Bernoulli sampling with that threshold, it is very
147-
* unlikely to sample less than `fraction * n` items out of `n` items.
146+
* Returns a threshold `p` such that if we conduct `n` Bernoulli trials with success rate = `p`,
147+
* it is very unlikely to have more than `fraction * n` successes.
148148
*/
149-
def getUpperBound(delta: Double, n: Long, fraction: Double): Double = {
149+
def getLowerBound(delta: Double, n: Long, fraction: Double): Double = {
150150
val gamma = - math.log(delta) / n * (2.0 / 3.0)
151151
math.max(minSamplingRate,
152152
fraction + gamma - math.sqrt(gamma * gamma + 3 * gamma * fraction))
153153
}
154154

155155
/**
156-
* Returns a threshold such that if we apply Bernoulli sampling with that threshold, it is very
157-
* unlikely to sample more than `fraction * n` items out of `n` items.
156+
* Returns a threshold `p` such that if we conduct `n` Bernoulli trials with success rate = `p`,
157+
* it is very unlikely to have fewer than `fraction * n` successes.
158158
*/
159-
def getLowerBound(delta: Double, n: Long, fraction: Double): Double = {
159+
def getUpperBound(delta: Double, n: Long, fraction: Double): Double = {
160160
val gamma = - math.log(delta) / n
161161
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
162162
}

core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,18 @@
1717

1818
package org.apache.spark.util.random
1919

20-
import cern.jet.random.Poisson
21-
import cern.jet.random.engine.DRand
22-
2320
import scala.collection.Map
2421
import scala.collection.mutable
2522
import scala.collection.mutable.ArrayBuffer
23+
import scala.reflect.ClassTag
24+
25+
import cern.jet.random.Poisson
26+
import cern.jet.random.engine.DRand
2627

2728
import org.apache.spark.Logging
2829
import org.apache.spark.SparkContext._
2930
import org.apache.spark.rdd.RDD
3031

31-
import scala.reflect.ClassTag
32-
3332
/**
3433
* Auxiliary functions and data structures for the sampleByKey method in PairRDDFunctions.
3534
*
@@ -119,9 +118,9 @@ private[spark] object StratifiedSamplingUtils extends Logging {
119118
// using an extra pass over the RDD for computing the count.
120119
// Hence, acceptBound and waitListBound change on every iteration.
121120
acceptResult.acceptBound =
122-
BernoulliBounds.getUpperBound(delta, acceptResult.numItems, fraction)
121+
BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction)
123122
acceptResult.waitListBound =
124-
BernoulliBounds.getLowerBound(delta, acceptResult.numItems, fraction)
123+
BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction)
125124

126125
val x = rng.nextUniform()
127126
if (x < acceptResult.acceptBound) {

core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
168168
val stratifiedData = data.keyBy(stratifier(fractionPositive))
169169

170170
val samplingRate = 0.1
171-
val seed = defaultSeed
172-
checkAllCombos(stratifiedData, samplingRate, seed, n)
171+
checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
173172
}
174173

175174
// vary fractionPositive
@@ -179,8 +178,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
179178
val stratifiedData = data.keyBy(stratifier(fractionPositive))
180179

181180
val samplingRate = 0.1
182-
val seed = defaultSeed
183-
checkAllCombos(stratifiedData, samplingRate, seed, n)
181+
checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
184182
}
185183

186184
// Use the same data for the rest of the tests
@@ -197,8 +195,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
197195

198196
// vary sampling rate
199197
for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) {
200-
val seed = defaultSeed
201-
checkAllCombos(stratifiedData, samplingRate, seed, n)
198+
checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
202199
}
203200
}
204201

0 commit comments

Comments
 (0)