Skip to content

Commit a10e68d

Browse files
committed
style fix
1 parent a2bf756 commit a10e68d

File tree

3 files changed

+37
-28
lines changed

3 files changed

+37
-28
lines changed

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,12 @@ private[spark] object PoissonBounds {
9696
}
9797

9898
def getMinCount(lmbd: Double): Double = {
99-
if (lmbd == 0) return 0
100-
val poisson = new PoissonDistribution(lmbd, epsilon)
101-
poisson.inverseCumulativeProbability(delta)
99+
if (lmbd == 0) {
100+
0
101+
} else {
102+
val poisson = new PoissonDistribution(lmbd, epsilon)
103+
poisson.inverseCumulativeProbability(delta)
104+
}
102105
}
103106

104107
/**

core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,12 @@ private[spark] object StratifiedSampler extends Logging {
116116
// We use the streaming version of the algorithm for sampling without replacement to avoid
117117
// using an extra pass over the RDD for computing the count.
118118
// Hence, acceptBound and waitListBound change on every iteration.
119-
val g1 = - math.log(delta) / stratum.numItems // gamma1
120-
val g2 = (2.0 / 3.0) * g1 // gamma 2
121-
stratum.acceptBound = math.max(0, fraction + g2 - math.sqrt(g2 * g2 + 3 * g2 * fraction))
122-
stratum.waitListBound = math.min(1, fraction + g1 + math.sqrt(g1 * g1 + 2 * g1 * fraction))
119+
val gamma1 = - math.log(delta) / stratum.numItems
120+
val gamma2 = (2.0 / 3.0) * gamma1
121+
stratum.acceptBound = math.max(0,
122+
fraction + gamma2 - math.sqrt(gamma2 * gamma2 + 3 * gamma2 * fraction))
123+
stratum.waitListBound = math.min(1,
124+
fraction + gamma1 + math.sqrt(gamma1 * gamma1 + 2 * gamma1 * fraction))
123125

124126
val x = rng.nextUniform(0.0, 1.0)
125127
if (x < stratum.acceptBound) {
@@ -137,20 +139,20 @@ private[spark] object StratifiedSampler extends Logging {
137139
* Returns the function used to combine results returned by seqOp from different partitions.
138140
*/
139141
def getCombOp[K]: (MMap[K, Stratum], MMap[K, Stratum]) => MMap[K, Stratum] = {
140-
(r1: MMap[K, Stratum], r2: MMap[K, Stratum]) => {
142+
(result1: MMap[K, Stratum], result2: MMap[K, Stratum]) => {
141143
// take union of both key sets in case one partition doesn't contain all keys
142-
for (key <- r1.keySet.union(r2.keySet)) {
143-
// Use r2 to keep the combined result since r1 is usual empty
144-
val entry1 = r1.get(key)
145-
if (r2.contains(key)) {
146-
r2(key).merge(entry1)
144+
for (key <- result1.keySet.union(result2.keySet)) {
145+
// Use result2 to keep the combined result since result1 is usually empty
146+
val entry1 = result1.get(key)
147+
if (result2.contains(key)) {
148+
result2(key).merge(entry1)
147149
} else {
148150
if (entry1.isDefined) {
149-
r2 += (key -> entry1.get)
151+
result2 += (key -> entry1.get)
150152
}
151153
}
152154
}
153-
r2
155+
result2
154156
}
155157
}
156158

@@ -237,10 +239,9 @@ private[spark] object StratifiedSampler extends Logging {
237239
rng.reSeed(seed + idx)
238240
iter.flatMap { item =>
239241
val key = item._1
240-
val q1 = finalResult(key).acceptBound
241-
val q2 = finalResult(key).waitListBound
242-
val copiesAccepted = if (q1 == 0) 0L else rng.nextPoisson(q1)
243-
val copiesWailisted = rng.nextPoisson(q2).toInt
242+
val acceptBound = finalResult(key).acceptBound
243+
val copiesAccepted = if (acceptBound == 0) 0L else rng.nextPoisson(acceptBound)
244+
val copiesWailisted = rng.nextPoisson(finalResult(key).waitListBound).toInt
244245
val copiesInSample = copiesAccepted +
245246
(0 until copiesWailisted).count(i => rng.nextUniform(0.0, 1.0) < thresholdByKey(key))
246247
if (copiesInSample > 0) {

core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,11 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
8888
(x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
8989
}
9090

91-
def checkSize(exact: Boolean, withReplacement: Boolean,
92-
expected: Long, actual: Long, p: Double): Boolean = {
91+
def checkSize(exact: Boolean,
92+
withReplacement: Boolean,
93+
expected: Long,
94+
actual: Long,
95+
p: Double): Boolean = {
9396
if (exact) {
9497
return expected == actual
9598
}
@@ -110,8 +113,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
110113
val sample = stratifiedData.sampleByKey(false, fractions, exact, seed)
111114
val sampleCounts = sample.countByKey()
112115
val takeSample = sample.collect()
113-
assert(sampleCounts.forall({case(k,v) =>
114-
checkSize(exact, false, expectedSampleSize(k), v, samplingRate)}))
116+
assert(sampleCounts.forall {case(k,v) =>
117+
checkSize(exact, false, expectedSampleSize(k), v, samplingRate)})
115118
assert(takeSample.size === takeSample.toSet.size)
116119
assert(takeSample.forall(x => 1 <= x._2 && x._2 <= n), s"elements not in [1, $n]")
117120
}
@@ -128,9 +131,9 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
128131
val sample = stratifiedData.sampleByKey(true, fractions, exact, seed)
129132
val sampleCounts = sample.countByKey()
130133
val takeSample = sample.collect()
131-
assert(sampleCounts.forall({case(k,v) =>
132-
checkSize(exact, true, expectedSampleSize(k), v, samplingRate)}))
133-
val groupedByKey = takeSample.groupBy({case(k, v) => k})
134+
assert(sampleCounts.forall {case(k,v) =>
135+
checkSize(exact, true, expectedSampleSize(k), v, samplingRate)})
136+
val groupedByKey = takeSample.groupBy {case(k, v) => k}
134137
for ((key, v) <- groupedByKey) {
135138
if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) {
136139
// sample large enough for there to be repeats with high likelihood
@@ -146,8 +149,10 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
146149
assert(takeSample.forall(x => 1 <= x._2 && x._2 <= n), s"elements not in [1, $n]")
147150
}
148151

149-
def checkAllCombos(stratifiedData: RDD[(String, Int)], samplingRate: Double,
150-
seed: Long, n: Long) {
152+
def checkAllCombos(stratifiedData: RDD[(String, Int)],
153+
samplingRate: Double,
154+
seed: Long,
155+
n: Long) = {
151156
takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n)
152157
takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n)
153158
takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n)

0 commit comments

Comments
 (0)