
Commit fdd7db3

Adding support for an initial value for state updates.
SPARK-3660: Initial RDD for updateStateByKey transformation
1 parent 8d22dbb · commit fdd7db3

File tree

3 files changed
+148 -32 lines changed
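
In short, the commit adds an overload of updateStateByKey that takes an initial state RDD and threads it through StateDStream as an Option[RDD[(K, S)]], using it to seed the state computed for the first batch. The new signature, as added to PairDStreamFunctions below:

    def updateStateByKey[S: ClassTag](
        updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
        partitioner: Partitioner,
        rememberPartitioner: Boolean,
        initial: RDD[(K, S)]
      ): DStream[(K, S)]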
examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCountWithInitial.scala

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming
+
+import org.apache.spark.{HashPartitioner, SparkConf}
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.StreamingContext._
+
+/**
+ * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network
+ * every second, starting with an initial value for each word's count.
+ * Usage: StatefulNetworkWordCountWithInitial <hostname> <port>
+ *   <hostname> and <port> describe the TCP server that Spark Streaming would connect to
+ *   receive data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ *    `$ nc -lk 9999`
+ * and then run the example
+ *    `$ bin/run-example
+ *    org.apache.spark.examples.streaming.StatefulNetworkWordCountWithInitial localhost 9999`
+ */
+object StatefulNetworkWordCountWithInitial {
+  def main(args: Array[String]) {
+    if (args.length < 2) {
+      System.err.println("Usage: StatefulNetworkWordCountWithInitial <hostname> <port>")
+      System.exit(1)
+    }
+
+    StreamingExamples.setStreamingLogLevels()
+
+    val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+      val currentCount = values.sum
+      val previousCount = state.getOrElse(0)
+      Some(currentCount + previousCount)
+    }
+
+    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
+      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+    }
+
+    val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCountWithInitial")
+    // Create the context with a 1 second batch size
+    val ssc = new StreamingContext(sparkConf, Seconds(1))
+    ssc.checkpoint(".")
+
+    // Initial RDD input to updateStateByKey
+    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
+
+    // Create a NetworkInputDStream on target ip:port and count the
+    // words in the input stream of '\n' delimited text (e.g. generated by 'nc')
+    val lines = ssc.socketTextStream(args(0), args(1).toInt)
+    val words = lines.flatMap(_.split(" "))
+    val wordDstream = words.map(x => (x, 1))
+
+    // Update the cumulative count using updateStateByKey;
+    // this will give a DStream made of state (the cumulative count of the words)
+    val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
+      new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)
+    stateDstream.print()
+    ssc.start()
+    ssc.awaitTermination()
+  }
+}
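
A note on the seeding behavior of this example: with the initial RDD above, the first time "hello" arrives over the socket the printed state is (hello,2) rather than (hello,1), because the seed already credits "hello" with one count. A seed key that never arrives, such as "world" if it is never typed, keeps its initial count of 1, since the state computation cogroups each batch with the state and emits keys present on either side.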

streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala

Lines changed: 26 additions & 1 deletion
@@ -394,6 +394,31 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
     updateStateByKey(newUpdateFunc, partitioner, true)
   }

+  /**
+   * Return a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of each key.
+   * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+   * @param updateFunc State update function. If this function returns None, then the
+   *                   corresponding state key-value pair will be eliminated. Note that
+   *                   this function may generate a tuple with a different key
+   *                   than the input key. It is up to the developer to decide whether to
+   *                   remember the partitioner despite the key being changed.
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+   *                    DStream.
+   * @param rememberPartitioner Whether to remember the partitioner object in the generated
+   *                            RDDs.
+   * @param initial Initial state value of each key.
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassTag](
+      updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+      partitioner: Partitioner,
+      rememberPartitioner: Boolean,
+      initial: RDD[(K, S)]
+    ): DStream[(K, S)] = {
+    new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
+      rememberPartitioner, Some(initial))
+  }
+
   /**
    * Return a new "state" DStream where the state for each key is updated by applying
    * the given function on the previous state of the key and the new values of each key.
@@ -413,7 +438,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
       partitioner: Partitioner,
       rememberPartitioner: Boolean
     ): DStream[(K, S)] = {
-    new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
+    new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
   }

   /**
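
Because the iterator-based update function is plain Scala, the scaladoc's elimination contract (emit no tuple for a key and its state pair is dropped) can be checked without a cluster. A minimal sketch, with illustrative keys and counts:

    object UpdateFuncContractSketch {
      // Same shape as the updateFunc parameter above: (key, new values, prior state)
      // triples in; zero or one (key, new state) out. Emitting nothing drops the key.
      val updateFunc: Iterator[(String, Seq[Int], Option[Int])] => Iterator[(String, Int)] =
        it => it.flatMap { case (k, vs, st) =>
          val total = vs.sum + st.getOrElse(0)
          if (total == 0) None else Some((k, total)) // eliminate idle keys
        }

      def main(args: Array[String]): Unit = {
        val batch = Iterator(
          ("hello", Seq(1, 1), Some(5)),   // active key: state becomes 7
          ("bye", Seq.empty[Int], Some(0)) // idle key: eliminated
        )
        updateFunc(batch).foreach(println) // prints (hello,7) only
      }
    }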

streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala

Lines changed: 42 additions & 31 deletions
@@ -30,7 +30,8 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
     parent: DStream[(K, V)],
     updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
     partitioner: Partitioner,
-    preservePartitioning: Boolean
+    preservePartitioning: Boolean,
+    initial: Option[RDD[(K, S)]]
   ) extends DStream[(K, S)](parent.ssc) {

   super.persist(StorageLevel.MEMORY_ONLY_SER)
@@ -41,6 +42,28 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](

   override val mustCheckpoint = true

+  private[this] def computeUsingPreviousRDD(
+      parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
+    // Define the function for the mapPartition operation on the cogrouped RDD;
+    // first map the cogrouped tuple to tuples of the required type,
+    // and then apply the update function
+    val updateFuncLocal = updateFunc
+    val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
+      val i = iterator.map(t => {
+        val itr = t._2._2.iterator
+        val headOption = if (itr.hasNext) Some(itr.next()) else None
+        (t._1, t._2._1.toSeq, headOption)
+      })
+      updateFuncLocal(i)
+    }
+    val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
+    val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
+    Some(stateRDD)
+  }
+
   override def compute(validTime: Time): Option[RDD[(K, S)]] = {
@@ -51,25 +74,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
       // Try to get the parent RDD
       parent.getOrCompute(validTime) match {
         case Some(parentRDD) => { // If parent RDD exists, then compute as usual
-
-          // Define the function for the mapPartition operation on cogrouped RDD;
-          // first map the cogrouped tuple to tuples of required type,
-          // and then apply the update function
-          val updateFuncLocal = updateFunc
-          val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
-            val i = iterator.map(t => {
-              val itr = t._2._2.iterator
-              val headOption = itr.hasNext match {
-                case true => Some(itr.next())
-                case false => None
-              }
-              (t._1, t._2._1.toSeq, headOption)
-            })
-            updateFuncLocal(i)
-          }
-          val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
-          val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
-          Some(stateRDD)
+          computeUsingPreviousRDD(parentRDD, prevStateRDD)
         }
         case None => { // If parent RDD does not exist
@@ -90,19 +95,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
         // Try to get the parent RDD
         parent.getOrCompute(validTime) match {
           case Some(parentRDD) => { // If parent RDD exists, then compute as usual
+            initial match {
+              case None => {
+                // Define the function for the mapPartition operation on the grouped RDD;
+                // first map the grouped tuple to tuples of the required type,
+                // and then apply the update function
+                val updateFuncLocal = updateFunc
+                val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
+                  updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+                }

-            // Define the function for the mapPartition operation on grouped RDD;
-            // first map the grouped tuple to tuples of required type,
-            // and then apply the update function
-            val updateFuncLocal = updateFunc
-            val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
-              updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+                val groupedRDD = parentRDD.groupByKey(partitioner)
+                val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
+                // logDebug("Generating state RDD for time " + validTime + " (first)")
+                Some(sessionRDD)
+              }
+              case Some(initialRDD) => {
+                computeUsingPreviousRDD(parentRDD, initialRDD)
+              }
             }
-
-            val groupedRDD = parentRDD.groupByKey(partitioner)
-            val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
-            // logDebug("Generating state RDD for time " + validTime + " (first)")
-            Some(sessionRDD)
           }
           case None => { // If parent RDD does not exist, then nothing to do!
             // logDebug("Not generating state RDD (no previous state, no parent)")