
Commit 254e03c

minor fixes and Java API.
punting on python for now. moved aggregateWithContext out of RDD
1 parent 4ad516b commit 254e03c

File tree

6 files changed, +85 -80 lines


core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 33 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 package org.apache.spark.api.java
 
-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, List => JList, Map => JMap}
 import java.lang.{Iterable => JIterable}
 
 import scala.collection.JavaConversions._
@@ -129,6 +129,38 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
     new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
 
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      exact: Boolean,
+      seed: Long): JavaPairRDD[K, V] =
+    new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))
+
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      exact: Boolean): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      seed: Long): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, true, seed)
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   */
+  def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, true, Utils.random.nextLong)
+
   /**
    * Return the union of this RDD and another one. Any identical elements will appear multiple
    * times (use `.distinct()` to eliminate them).
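The four Java overloads above only differ in which defaults they fill in (exact = true and/or a random seed) before delegating to the Scala rdd.sampleByKey. The snippet below is a minimal usage sketch, not part of this commit: the local master, app name, data, and variable names are assumptions, and the per-key rates are passed as a java.util.Map to match the new JMap[K, Double] parameter.

```scala
import java.util.{Arrays => JArrays, HashMap => JHashMap}
import org.apache.spark.api.java.JavaSparkContext

// Illustrative sketch only: master, app name, and data are assumptions.
object SampleByKeyJavaApiSketch {
  def main(args: Array[String]): Unit = {
    val jsc = new JavaSparkContext("local[2]", "sampleByKey-java-api-sketch")
    val pairs = jsc.parallelizePairs(JArrays.asList(("a", 1), ("a", 2), ("b", 3)))

    // Per-key sampling rates, passed as a java.util.Map (the new JMap parameter).
    val fractions = new JHashMap[String, Double]()
    fractions.put("a", 0.5)
    fractions.put("b", 1.0)

    // Shortest overload: exact = true and a random seed are filled in for us.
    val sampled = pairs.sampleByKey(false, fractions)
    // Fully specified overload: explicit exact flag and seed.
    val sampledApprox = pairs.sampleByKey(false, fractions, false, 42L)

    println(sampled.countByKey())
    println(sampledApprox.countByKey())
    jsc.stop()
  }
}
```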

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 14 additions & 16 deletions
@@ -19,12 +19,10 @@ package org.apache.spark.rdd
 
 import java.nio.ByteBuffer
 import java.text.SimpleDateFormat
-import java.util.Date
-import java.util.{HashMap => JHashMap}
+import java.util.{Date, HashMap => JHashMap}
 
+import scala.collection.{Map, mutable}
 import scala.collection.JavaConversions._
-import scala.collection.Map
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
@@ -34,16 +32,14 @@ import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.CompressionCodec
 import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob,
+import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
   RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
 
 import org.apache.spark._
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.SparkHadoopWriter
 import org.apache.spark.Partitioner.defaultPartitioner
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.Utils
@@ -216,24 +212,26 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * need two additional passes.
    *
    * @param withReplacement whether to sample with or without replacement
-   * @param fractionByKey function mapping key to sampling rate
+   * @param fractions map of specific keys to sampling rates
    * @param seed seed for the random number generator
    * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per stratum
    * @return RDD containing the sampled subset
    */
   def sampleByKey(withReplacement: Boolean,
-      fractionByKey: Map[K, Double],
-      seed: Long = Utils.random.nextLong,
-      exact: Boolean = true): RDD[(K, V)]= {
-    require(fractionByKey.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.")
+      fractions: Map[K, Double],
+      exact: Boolean = true,
+      seed: Long = Utils.random.nextLong): RDD[(K, V)]= {
+
+    require(fractions.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.")
+
     if (withReplacement) {
       val counts = if (exact) Some(this.countByKey()) else None
       val samplingFunc =
-        StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed)
+        StratifiedSampler.getPoissonSamplingFunction(self, fractions, exact, counts, seed)
       self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
     } else {
       val samplingFunc =
-        StratifiedSampler.getBernoulliSamplingFunction(self, fractionByKey, exact, seed)
+        StratifiedSampler.getBernoulliSamplingFunction(self, fractions, exact, seed)
       self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
     }
   }
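The caller-visible change in this file is the signature: the parameter is renamed to fractions and exact now precedes seed, so positional callers must pass (exact, seed) in that order; both still have defaults. A minimal usage sketch follows, not part of this commit: the local master, app name, and data are assumptions.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._  // brings sampleByKey into scope via PairRDDFunctions

// Illustrative sketch only: master, app name, and data are assumptions.
object SampleByKeyScalaSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "sampleByKey-scala-sketch")
    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4)), 2)
    val fractions = Map("a" -> 0.5, "b" -> 1.0)

    // exact now comes before seed, so positional callers pass (exact, seed).
    val approx = pairs.sampleByKey(false, fractions, false, 11L)
    // Both trailing parameters still default (exact = true, random seed).
    val exactSample = pairs.sampleByKey(false, fractions)

    println(approx.countByKey())
    println(exactSample.countByKey())
    sc.stop()
  }
}
```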

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 0 additions & 21 deletions
@@ -875,27 +875,6 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
-  /**
-   * A version of {@link #aggregate()} that passes the TaskContext to the function that does
-   * aggregation for each partition.
-   */
-  def aggregateWithContext[U: ClassTag](zeroValue: U)(seqOp: ((TaskContext, U), T) => U,
-      combOp: (U, U) => U): U = {
-    // Clone the zero value since we will also be serializing it as part of tasks
-    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
-    // pad seqOp and combOp with taskContext to conform to aggregate's signature in TraversableOnce
-    val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
-    val paddedcombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
-      (arg1._1, combOp(arg1._2, arg1._2))
-    val cleanSeqOp = sc.clean(paddedSeqOp)
-    val cleanCombOp = sc.clean(paddedcombOp)
-    val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
-      (it.aggregate(tc, zeroValue)(cleanSeqOp, cleanCombOp))._2
-    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
-    sc.runJob(this, aggregatePartition, mergeResult)
-    jobResult
-  }
-
   /**
    * Return the number of elements in the RDD.
    */

core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala

Lines changed: 34 additions & 6 deletions
@@ -17,15 +17,43 @@
 
 package org.apache.spark.util.random
 
+import scala.collection.{Map, mutable}
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.{mutable, Map}
+import scala.reflect.ClassTag
+
 import org.apache.commons.math3.random.RandomDataGenerator
-import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.util.random.{PoissonBounds => PB}
-import scala.Some
+import org.apache.spark.{Logging, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.{PoissonBounds => PB}
 
 private[spark] object StratifiedSampler extends Logging {
+
+  /**
+   * A version of {@link #aggregate()} that passes the TaskContext to the function that does
+   * aggregation for each partition. This function avoids creating an extra depth in the RDD
+   * lineage, as opposed to using mapPartitionsWithId, which results in slightly improved run time.
+   */
+  def aggregateWithContext[U: ClassTag, T: ClassTag](zeroValue: U)
+      (rdd: RDD[T],
+       seqOp: ((TaskContext, U), T) => U,
+       combOp: (U, U) => U): U = {
+    val sc: SparkContext = rdd.sparkContext
+    // Clone the zero value since we will also be serializing it as part of tasks
+    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
+    // pad seqOp and combOp with taskContext to conform to aggregate's signature in TraversableOnce
+    val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
+    val paddedcombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
+      (arg1._1, combOp(arg1._2, arg1._2))
+    val cleanSeqOp = sc.clean(paddedSeqOp)
+    val cleanCombOp = sc.clean(paddedcombOp)
+    val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
+      (it.aggregate(tc, zeroValue)(cleanSeqOp, cleanCombOp))._2
+    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+    sc.runJob(rdd, aggregatePartition, mergeResult)
+    jobResult
+  }
+
   /**
    * Returns the function used by aggregate to collect sampling statistics for each partition.
    */
@@ -153,7 +181,7 @@ private[spark] object StratifiedSampler extends Logging {
       val seqOp = StratifiedSampler.getSeqOp[K,V](false, fractionByKey, None)
       val combOp = StratifiedSampler.getCombOp[K]()
       val zeroU = new Result[K](Map[K, Stratum](), seed = seed)
-      val finalResult = rdd.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
+      val finalResult = aggregateWithContext(zeroU)(rdd, seqOp, combOp).resultMap
       samplingRateByKey = StratifiedSampler.computeThresholdByKey(finalResult, fractionByKey)
     }
     (idx: Int, iter: Iterator[(K, V)]) => {
@@ -183,7 +211,7 @@ private[spark] object StratifiedSampler extends Logging {
     val seqOp = StratifiedSampler.getSeqOp[K,V](true, fractionByKey, counts)
     val combOp = StratifiedSampler.getCombOp[K]()
     val zeroU = new Result[K](Map[K, Stratum](), seed = seed)
-    val finalResult = rdd.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
+    val finalResult = aggregateWithContext(zeroU)(rdd, seqOp, combOp).resultMap
     val thresholdByKey = StratifiedSampler.computeThresholdByKey(finalResult, fractionByKey)
     (idx: Int, iter: Iterator[(K, V)]) => {
       val random = new RandomDataGenerator()
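The helper moved into this object works by padding the caller's operators so each partition's aggregation carries the TaskContext inside its accumulator, then discards it before partition results are merged on the driver. The sketch below, which is not part of this commit, illustrates only that padding idea on a plain Scala collection: an Int stands in for TaskContext and all names are illustrative.

```scala
// A minimal sketch of the "padding" trick: thread an extra context value
// through a standard aggregate without changing the caller-visible
// seqOp/combOp shapes. Not Spark code; Int stands in for TaskContext.
object PaddedAggregateSketch {
  def main(args: Array[String]): Unit = {
    val partitionId = 7                      // stand-in for TaskContext
    val data = Seq(1, 2, 3, 4)

    val seqOp: ((Int, Long), Int) => Long = { case ((ctx, acc), x) => acc + x + ctx }
    val combOp: (Long, Long) => Long = _ + _

    // Pad both operators so they conform to aggregate's (ctx, acc) accumulator type.
    val paddedSeqOp = (arg: (Int, Long), x: Int) => (arg._1, seqOp(arg, x))
    val paddedCombOp = (a: (Int, Long), b: (Int, Long)) => (a._1, combOp(a._2, b._2))

    // Run the padded aggregation, then drop the context from the result.
    val (_, result) = data.aggregate((partitionId, 0L))(paddedSeqOp, paddedCombOp)
    println(result)  // (1+2+3+4) + 4 * partitionId = 10 + 28 = 38
  }
}
```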

core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala

Lines changed: 4 additions & 4 deletions
@@ -106,8 +106,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
         n: Long) = {
       val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
         math.ceil(count * samplingRate).toInt)
-      val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate)
-      val sample = stratifiedData.sampleByKey(false, fractionByKey, seed, exact)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = stratifiedData.sampleByKey(false, fractions, exact, seed)
       val sampleCounts = sample.countByKey()
       val takeSample = sample.collect()
       assert(sampleCounts.forall({case(k,v) =>
@@ -124,8 +124,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
         n: Long) = {
       val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
         math.ceil(count * samplingRate).toInt)
-      val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate)
-      val sample = stratifiedData.sampleByKey(true, fractionByKey, seed, exact)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = stratifiedData.sampleByKey(true, fractions, exact, seed)
       val sampleCounts = sample.countByKey()
       val takeSample = sample.collect()
       assert(sampleCounts.forall({case(k,v) =>

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 0 additions & 32 deletions
@@ -141,38 +141,6 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
   }
 
-  test("aggregateWithContext") {
-    val data = Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))
-    val numPartitions = 2
-    val pairs = sc.makeRDD(data, numPartitions)
-    //determine the partitionId for each pair
-    type StringMap = HashMap[String, Int]
-    val partitions = pairs.collectPartitions()
-    val offSets = new StringMap
-    for (i <- 0 to numPartitions - 1) {
-      partitions(i).foreach({ case (k, v) => offSets.put(k, offSets.getOrElse(k, 0) + i)})
-    }
-    val emptyMap = new StringMap {
-      override def default(key: String): Int = 0
-    }
-    val mergeElement: ((TaskContext, StringMap), (String, Int)) => StringMap = (arg1, pair) => {
-      val stringMap = arg1._2
-      val tc = arg1._1
-      stringMap(pair._1) += pair._2 + tc.partitionId
-      stringMap
-    }
-    val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
-      for ((key, value) <- map2) {
-        map1(key) += value
-      }
-      map1
-    }
-    val result = pairs.aggregateWithContext(emptyMap)(mergeElement, mergeMaps)
-    val expected = Set(("a", 6), ("b", 2), ("c", 5))
-      .map({ case (k, v) => (k -> (offSets.getOrElse(k, 0) + v))})
-    assert(result.toSet === expected)
-  }
-
   test("basic caching") {
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))
