Skip to content

Commit 304f636

Browse files
committed
Added simpler version of updateStateByKey API with initialRDD and test.
1 parent 9781135 commit 304f636

File tree

5 files changed

+72
-34
lines changed

5 files changed

+72
-34
lines changed

streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -443,23 +443,6 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
443443
scalaFunc
444444
}
445445

446-
private def convertUpdateStateFunctionWithIterator[S]
447-
(in: JFunction2[JList[V], Optional[S], Optional[S]]):
448-
(Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)] = {
449-
val scalaFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)] = (iterator) => {
450-
iterator.flatMap { t =>
451-
val list: JList[V] = t._2
452-
val scalaState: Optional[S] = JavaUtils.optionToOptional(t._3)
453-
val result: Optional[S] = in.apply(list, scalaState)
454-
result.isPresent match {
455-
case true => Some((t._1, result.get()))
456-
case _ => None
457-
}
458-
}
459-
}
460-
scalaFunc
461-
}
462-
463446
/**
464447
* Return a new "state" DStream where the state for each key is updated by applying
465448
* the given function on the previous state of the key and the new values of each key.
@@ -526,8 +509,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
526509
initialRDD: JavaPairRDD[K, S]
527510
): JavaPairDStream[K, S] = {
528511
implicit val cm: ClassTag[S] = fakeClassTag
529-
dstream.updateStateByKey(convertUpdateStateFunctionWithIterator(updateFunc),
530-
partitioner, true, initialRDD)
512+
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner, initialRDD)
531513
}
532514

533515
/**

streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,28 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
416416
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
417417
}
418418

419+
/**
420+
* Return a new "state" DStream where the state for each key is updated by applying
421+
* the given function on the previous state of the key and the new values of the key.
422+
* org.apache.spark.Partitioner is used to control the partitioning of each RDD.
423+
* @param updateFunc State update function. If `this` function returns None, then
424+
* corresponding state key-value pair will be eliminated.
425+
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new
426+
* DStream.
427+
* @param initialRDD initial state value of each key.
428+
* @tparam S State type
429+
*/
430+
def updateStateByKey[S: ClassTag](
431+
updateFunc: (Seq[V], Option[S]) => Option[S],
432+
partitioner: Partitioner,
433+
initialRDD: RDD[(K, S)]
434+
): DStream[(K, S)] = {
435+
val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
436+
iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
437+
}
438+
updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
439+
}
440+
419441
/**
420442
* Return a new "state" DStream where the state for each key is updated by applying
421443
* the given function on the previous state of the key and the new values of each key.

streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
5151
val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
5252
val i = iterator.map(t => {
5353
val itr = t._2._2.iterator
54-
val headOption = itr.hasNext match {
55-
case true => Some(itr.next())
56-
case false => None
57-
}
54+
val headOption = if(itr.hasNext) Some(itr.next) else None
5855
(t._1, t._2._1.toSeq, headOption)
5956
})
6057
updateFuncLocal(i)

streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -806,15 +806,17 @@ public void testUnion() {
806806
* Performs an order-invariant comparison of lists representing two RDD streams. This allows
807807
* us to account for ordering variation within individual RDD's which occurs during windowing.
808808
*/
809-
public static <T extends Comparable<T>> void assertOrderInvariantEquals(
809+
public static <T> void assertOrderInvariantEquals(
810810
List<List<T>> expected, List<List<T>> actual) {
811+
List<Set<T>> expectedSets = new ArrayList<Set<T>>();
811812
for (List<T> list: expected) {
812-
Collections.sort(list);
813+
expectedSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
813814
}
815+
List<Set<T>> actualSets = new ArrayList<Set<T>>();
814816
for (List<T> list: actual) {
815-
Collections.sort(list);
817+
actualSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
816818
}
817-
Assert.assertEquals(expected, actual);
819+
Assert.assertEquals(expectedSets, actualSets);
818820
}
819821

820822

@@ -1252,12 +1254,12 @@ public void testUpdateStateByKeyWithInitial() {
12521254
JavaPairRDD<String, Integer> initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD);
12531255

12541256
List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
1255-
Arrays.asList(new Tuple2<String, Integer>("new york", 7),
1256-
new Tuple2<String, Integer>("california", 5)),
1257-
Arrays.asList(new Tuple2<String, Integer>("new york", 11),
1258-
new Tuple2<String, Integer>("california", 15)),
1259-
Arrays.asList(new Tuple2<String, Integer>("new york", 11),
1260-
new Tuple2<String, Integer>("california", 15)));
1257+
Arrays.asList(new Tuple2<String, Integer>("california", 4),
1258+
new Tuple2<String, Integer>("new york", 5)),
1259+
Arrays.asList(new Tuple2<String, Integer>("california", 14),
1260+
new Tuple2<String, Integer>("new york", 9)),
1261+
Arrays.asList(new Tuple2<String, Integer>("california", 14),
1262+
new Tuple2<String, Integer>("new york", 9)));
12611263

12621264
JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
12631265
JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
@@ -1279,7 +1281,7 @@ public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
12791281
JavaTestUtils.attachTestOutputStream(updated);
12801282
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
12811283

1282-
Assert.assertEquals(expected, result);
1284+
assertOrderInvariantEquals(expected, result);
12831285
}
12841286

12851287
@SuppressWarnings("unchecked")

streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,41 @@ class BasicOperationsSuite extends TestSuiteBase {
351351
testOperation(inputData, updateStateOperation, outputData, true)
352352
}
353353

354+
test("updateStateByKey - simple with initial value RDD") {
355+
val initial = Seq(("a", 1), ("c", 2))
356+
357+
val inputData =
358+
Seq(
359+
Seq("a"),
360+
Seq("a", "b"),
361+
Seq("a", "b", "c"),
362+
Seq("a", "b"),
363+
Seq("a"),
364+
Seq()
365+
)
366+
367+
val outputData =
368+
Seq(
369+
Seq(("a", 2), ("c", 2)),
370+
Seq(("a", 3), ("b", 1), ("c", 2)),
371+
Seq(("a", 4), ("b", 2), ("c", 3)),
372+
Seq(("a", 5), ("b", 3), ("c", 3)),
373+
Seq(("a", 6), ("b", 3), ("c", 3)),
374+
Seq(("a", 6), ("b", 3), ("c", 3))
375+
)
376+
377+
val updateStateOperation = (s: DStream[String]) => {
378+
val initialRDD = s.context.sparkContext.makeRDD(initial)
379+
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
380+
Some(values.sum + state.getOrElse(0))
381+
}
382+
s.map(x => (x, 1)).updateStateByKey[Int](updateFunc,
383+
new HashPartitioner (numInputPartitions), initialRDD)
384+
}
385+
386+
testOperation(inputData, updateStateOperation, outputData, true)
387+
}
388+
354389
test("updateStateByKey - with initial value RDD") {
355390
val initial = Seq(("a", 1), ("c", 2))
356391

0 commit comments

Comments
 (0)