Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 1e68483

Browse files
committed
Make RNG usage more efficient for Bernoulli sampling in the congestion strategy (reuse one RNG instead of creating a new one per sampler)
1 parent 834d24d commit 1e68483

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,16 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals
143143
* @tparam T item type
144144
*/
145145
@DeveloperApi
146-
class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
146+
class BernoulliSampler[T: ClassTag](fraction: Double,
147+
private val rng: Random = RandomSampler.newDefaultRNG)
148+
extends RandomSampler[T, T] {
147149

148150
/** epsilon slop to avoid failure from floating point jitter */
149151
require(
150152
fraction >= (0.0 - RandomSampler.roundingEpsilon)
151153
&& fraction <= (1.0 + RandomSampler.roundingEpsilon),
152154
s"Sampling fraction ($fraction) must be on interval [0, 1]")
153155

154-
private val rng: Random = RandomSampler.newDefaultRNG
155-
156156
override def setSeed(seed: Long): Unit = rng.setSeed(seed)
157157

158158
override def sample(items: Iterator[T]): Iterator[T] = {

streaming/src/main/scala/org/apache/spark/streaming/receiver/CongestionStrategyImpl.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.streaming.receiver
2020
import scala.collection.mutable.ArrayBuffer
2121
import java.util.concurrent.atomic.AtomicInteger
2222
import java.util.Random
23-
import org.apache.spark.util.random.BernoulliSampler
23+
import org.apache.spark.util.random.{RandomSampler, BernoulliSampler}
2424

2525
/**
2626
* This class provides a congestion strategy that ignores
@@ -79,6 +79,8 @@ class DropCongestionStrategy extends CongestionStrategy {
7979

8080
class SamplingCongestionStrategy extends CongestionStrategy {
8181

82+
// Single RNG instance handed to every BernoulliSampler this strategy builds,
// so a fresh generator is not constructed on each sampling pass.
// The factory lives on RandomSampler; java.util.Random has no `newDefaultRNG`.
private val rng: Random = RandomSampler.newDefaultRNG
83+
8284
private val latestBound = new AtomicInteger(-1)
8385

8486
override def onBlockBoundUpdate(bound: Int): Unit = latestBound.set(bound)
@@ -89,7 +91,7 @@ class SamplingCongestionStrategy extends CongestionStrategy {
8991
val f = bound.toDouble / currentBuffer.size
9092
val samplees = currentBuffer.to
9193

92-
val sampled = new BernoulliSampler(f).sample(samplees.toIterator)
94+
val sampled = new BernoulliSampler(f, rng).sample(samplees.toIterator)
9395

9496
currentBuffer.clear()
9597
currentBuffer ++= sampled

0 commit comments

Comments
 (0)