@@ -26,6 +26,9 @@ import scala.Some
26
26
import org .apache .spark .rdd .RDD
27
27
28
28
private [spark] object StratifiedSampler extends Logging {
29
+ /**
30
+ * Returns the function used by aggregate to collect sampling statistics for each partition.
31
+ */
29
32
def getSeqOp [K , V ](withReplacement : Boolean ,
30
33
fractionByKey : (K => Double ),
31
34
counts : Option [Map [K , Long ]]): ((TaskContext , Result [K ]),(K , V )) => Result [K ] = {
@@ -43,9 +46,9 @@ private[spark] object StratifiedSampler extends Logging {
43
46
if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
44
47
val n = counts.get(item._1)
45
48
val s = math.ceil(n * fraction).toLong
46
- val lmbd1 = PB .getLambda1 (s)
49
+ val lmbd1 = PB .getLowerBound (s)
47
50
val minCount = PB .getMinCount(lmbd1)
48
- val lmbd2 = if (lmbd1 == 0 ) PB .getLambda2 (s) else PB .getLambda2 (s - minCount)
51
+ val lmbd2 = if (lmbd1 == 0 ) PB .getUpperBound (s) else PB .getUpperBound (s - minCount)
49
52
val q1 = lmbd1 / n
50
53
val q2 = lmbd2 / n
51
54
stratum.q1 = Some (q1)
@@ -60,6 +63,8 @@ private[spark] object StratifiedSampler extends Logging {
60
63
stratum.addToWaitList(ArrayBuffer .fill(x2)(rng.nextUniform(0.0 , 1.0 )))
61
64
}
62
65
} else {
66
+ // We use the streaming version of the algorithm for sampling without replacement.
67
+ // Hence, q1 and q2 change on every iteration.
63
68
val g1 = - math.log(delta) / stratum.numItems
64
69
val g2 = (2.0 / 3.0 ) * g1
65
70
val q1 = math.max(0 , fraction + g2 - math.sqrt((g2 * g2 + 3 * g2 * fraction)))
@@ -79,7 +84,11 @@ private[spark] object StratifiedSampler extends Logging {
79
84
}
80
85
}
81
86
82
- def getCombOp [K ](): (Result [K ], Result [K ]) => Result [K ] = {
87
+ /**
88
+ * Returns the function used by aggregate to combine results from different partitions, as
89
+ * returned by seqOp.
90
+ */
91
+ def getCombOp [K ](): (Result [K ], Result [K ]) => Result [K ] = {
83
92
(r1 : Result [K ], r2 : Result [K ]) => {
84
93
// take union of both key sets in case one partition doesn't contain all keys
85
94
val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)
@@ -100,6 +109,10 @@ private[spark] object StratifiedSampler extends Logging {
100
109
}
101
110
}
102
111
112
+ /**
113
+ * Given the result returned by the aggregate function, we need to determine the threshold used
114
+ * to accept items to generate the exact sample size.
115
+ */
103
116
def computeThresholdByKey [K ](finalResult : Map [K , Stratum ], fractionByKey : (K => Double )):
104
117
(K => Double ) = {
105
118
val thresholdByKey = new mutable.HashMap [K , Double ]()
@@ -122,11 +135,15 @@ private[spark] object StratifiedSampler extends Logging {
122
135
thresholdByKey
123
136
}
124
137
125
- def computeThresholdByKey [K ](finalResult : Map [K , String ]): (K => String ) = {
126
- finalResult
127
- }
128
-
129
- def getBernoulliSamplingFunction [K , V ](rdd: RDD [(K , V )],
138
+ /**
139
+ * Returns the per-partition sampling function used for sampling without replacement.
140
+ *
141
+ * When exact sample size is required, we make an additional pass over the RDD to determine the
142
+ * exact sampling rate that guarantees sample size with high confidence.
143
+ *
144
+ * The sampling function has a unique seed per partition.
145
+ */
146
+ def getBernoulliSamplingFunction [K , V ](rdd : RDD [(K , V )],
130
147
fractionByKey : K => Double ,
131
148
exact : Boolean ,
132
149
seed : Long ): (Int , Iterator [(K , V )]) => Iterator [(K , V )] = {
@@ -146,6 +163,16 @@ private[spark] object StratifiedSampler extends Logging {
146
163
}
147
164
}
148
165
166
+ /**
167
+ * Returns the per-partition sampling function used for sampling with replacement.
168
+ *
169
+ * When exact sample size is required, we make two additional passes over the RDD to determine
170
+ * the exact sampling rate that guarantees sample size with high confidence. The first pass
171
+ * counts the number of items in each stratum (group of items with the same key) in the RDD, and
172
+ * the second pass uses the counts to determine exact sampling rates.
173
+ *
174
+ * The sampling function has a unique seed per partition.
175
+ */
149
176
def getPoissonSamplingFunction [K , V ](rdd: RDD [(K , V )],
150
177
fractionByKey : K => Double ,
151
178
exact : Boolean ,
@@ -191,6 +218,10 @@ private[spark] object StratifiedSampler extends Logging {
191
218
}
192
219
}
193
220
221
+ /**
222
+ * Object used by seqOp to keep track of the number of items accepted and items waitlisted per
223
+ * stratum, as well as the bounds for accepting and waitlisting items.
224
+ */
194
225
private [random] class Stratum (var numItems : Long = 0L , var numAccepted : Long = 0L )
195
226
extends Serializable {
196
227
@@ -205,13 +236,14 @@ private[random] class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0
205
236
/** Appends a single value to this stratum's `waitList` (declared outside this view). */
def addToWaitList (elem : Double ) = waitList += elem
206
237
207
238
/** Appends every value in `elems` to this stratum's `waitList` (declared outside this view). */
def addToWaitList (elems : ArrayBuffer [Double ]) = waitList ++= elems
208
-
209
- override def toString () = {
210
- " numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
211
- " waitListSize:" + waitList.size
212
- }
213
239
}
214
240
241
+ /**
242
+ * Object used by seqOp and combOp to keep track of the sampling statistics for all strata.
243
+ *
244
+ * When used by seqOp for each partition, we also keep track of the partition ID in this object
245
+ * to make sure a single random number generator with a unique seed is used for each partition.
246
+ */
215
247
private [random] class Result [K ](var resultMap : Map [K , Stratum ],
216
248
var cachedPartitionId : Option [Int ] = None ,
217
249
val seed : Long )
0 commit comments