Commit 8466916

support checkpoint
1 parent 9a16bd1 commit 8466916

4 files changed, +101 -23 lines

python/pyspark/streaming/context.py

Lines changed: 5 additions & 2 deletions
@@ -19,11 +19,11 @@
 from py4j.java_gateway import java_import

 from pyspark import RDD
-from pyspark.serializers import UTF8Deserializer
+from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
 from pyspark.context import SparkContext
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.dstream import DStream
-from pyspark.streaming.util import RDDFunction
+from pyspark.streaming.util import RDDFunction, RDDFunctionSerializer

 __all__ = ["StreamingContext"]


@@ -100,6 +100,9 @@ def _initialize_context(self, sc, duration):
         java_import(self._jvm, "org.apache.spark.streaming.*")
         java_import(self._jvm, "org.apache.spark.streaming.api.java.*")
         java_import(self._jvm, "org.apache.spark.streaming.api.python.*")
+        # register serializer for RDDFunction
+        ser = RDDFunctionSerializer(self._sc, CloudPickleSerializer())
+        self._jvm.PythonDStream.registerSerializer(ser)
         return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))

     def _jduration(self, seconds):
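
The added lines are the Python half of the checkpoint handshake: an RDDFunctionSerializer backed by CloudPickleSerializer is registered with the JVM before the JavaStreamingContext exists, so that later Java serialization of the DStream graph can call back into Python. CloudPickle matters here because it pickles a function by value, closure included. A minimal standalone sketch of that property (the adder function is illustrative, not from this patch):

from pyspark.serializers import CloudPickleSerializer

ser = CloudPickleSerializer()

def make_adder(n):
    # a closure, standing in for a user's transform function
    return lambda x: x + n

payload = ser.dumps(make_adder(10))   # plain bytes, safe to checkpoint
restored = ser.loads(payload)
assert restored(5) == 15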

python/pyspark/streaming/util.py

Lines changed: 27 additions & 1 deletion
@@ -16,6 +16,7 @@
 #

 from datetime import datetime
+import traceback

 from pyspark.rdd import RDD


@@ -47,7 +48,6 @@ def call(self, milliseconds, jrdds):
             if r:
                 return r._jrdd
         except Exception:
-            import traceback
             traceback.print_exc()

     def __repr__(self):

@@ -57,6 +57,32 @@ class Java:
         implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']


+class RDDFunctionSerializer(object):
+    def __init__(self, ctx, serializer):
+        self.ctx = ctx
+        self.serializer = serializer
+
+    def dumps(self, id):
+        try:
+            func = self.ctx._gateway.gateway_property.pool[id]
+            return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+        except Exception:
+            traceback.print_exc()
+
+    def loads(self, bytes):
+        try:
+            f, deserializers = self.serializer.loads(str(bytes))
+            return RDDFunction(self.ctx, f, *deserializers)
+        except Exception:
+            traceback.print_exc()
+
+    def __repr__(self):
+        return "RDDFunctionSerializer(%s)" % self.serializer
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunctionSerializer']
+
+
 def rddToFileName(prefix, suffix, time):
     """
     Return string prefix-time(.suffix)
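
Note what dumps() actually pickles: it resolves the py4j id to the live RDDFunction in the gateway's object pool and serializes only (func.func, func.deserializers); the unpicklable pieces (the SparkContext and the gateway itself) are re-attached by loads() when it builds a fresh RDDFunction. A sketch of that payload round trip with the py4j plumbing left out (the deserializer name is a stand-in):

from pyspark.serializers import CloudPickleSerializer

ser = CloudPickleSerializer()
func = lambda t, rdd: rdd                # stands in for a user function
deserializers = ("UTF8Deserializer",)    # stand-in for real deserializers

payload = bytearray(ser.dumps((func, deserializers)))
# the real loads() uses str(bytes) under Python 2; bytes() works here too
f, des = ser.loads(bytes(payload))
assert des == ("UTF8Deserializer",)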

streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala

Lines changed: 1 addition & 1 deletion
@@ -413,7 +413,7 @@ class StreamingContext private[streaming] (
       dstreams: Seq[DStream[_]],
       transformFunc: (Seq[RDD[_]], Time) => RDD[T]
     ): DStream[T] = {
-    new TransformedDStream[T](dstreams, transformFunc)
+    new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc))
   }

   /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
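
sparkContext.clean runs Spark's ClosureCleaner over the transform function, nulling out references to the enclosing scope so the closure can be serialized into the (now checkpointable) DStream graph. The Python side hits the same pitfall through pickling; a small illustration of what closure hygiene buys (the class and names are hypothetical, not from this patch):

from pyspark.serializers import CloudPickleSerializer

ser = CloudPickleSerializer()

class Driver(object):
    def __init__(self):
        self.log = open("/dev/null", "w")   # not serializable
        self.offset = 1

    def bad(self):
        return lambda x: x + self.offset    # captures self, drags in log

    def good(self):
        offset = self.offset                # copy out only the value
        return lambda x: x + offset

d = Driver()
ser.dumps(d.good())                         # fine
try:
    ser.dumps(d.bad())
except Exception as e:
    print("cannot serialize: %s" % e)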

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 68 additions & 19 deletions
@@ -17,10 +17,11 @@

 package org.apache.spark.streaming.api.python

+import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.lang.reflect.Proxy
 import java.util.{ArrayList => JArrayList, List => JList}
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
-import scala.collection.mutable

 import org.apache.spark.api.java._
 import org.apache.spark.api.python._

@@ -35,14 +36,14 @@ import org.apache.spark.streaming.api.java._
  * Interface for Python callback function with three arguments
  */
 private[python] trait PythonRDDFunction {
-  // callback in Python
   def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
 }

 /**
  * Wrapper for PythonRDDFunction
+ * TODO: support checkpoint
  */
-private[python] class RDDFunction(pfunc: PythonRDDFunction)
+private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction)
   extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable {

   def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
@@ -58,30 +59,62 @@ private[python] class RDDFunction(pfunc: PythonRDDFunction)
   def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
     pfunc.call(time.milliseconds, rdds)
   }
-}

+  private def writeObject(out: ObjectOutputStream): Unit = {
+    assert(PythonDStream.serializer != null, "Serializer has not been registered!")
+    val bytes = PythonDStream.serializer.serialize(pfunc)
+    out.writeInt(bytes.length)
+    out.write(bytes)
+  }
+
+  private def readObject(in: ObjectInputStream): Unit = {
+    assert(PythonDStream.serializer != null, "Serializer has not been registered!")
+    val length = in.readInt()
+    val bytes = new Array[Byte](length)
+    in.readFully(bytes)
+    pfunc = PythonDStream.serializer.deserialize(bytes)
+  }
+}

 /**
- * Base class for PythonDStream with some common methods
+ * Interface for a Python serializer to serialize PythonRDDFunction
  */
-private[python]
-abstract class PythonDStream(parent: DStream[_], pfunc: PythonRDDFunction)
-  extends DStream[Array[Byte]] (parent.ssc) {
-
-  val func = new RDDFunction(pfunc)
-
-  override def dependencies = List(parent)
+private[python] trait PythonRDDFunctionSerializer {
+  def dumps(id: String): Array[Byte]
+  def loads(bytes: Array[Byte]): PythonRDDFunction
+}

-  override def slideDuration: Duration = parent.slideDuration
+/**
+ * Wrapper for PythonRDDFunctionSerializer
+ */
+private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) {
+  def serialize(func: PythonRDDFunction): Array[Byte] = {
+    // get the id of the PythonRDDFunction in py4j
+    val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
+    val f = h.getClass().getDeclaredField("id")
+    f.setAccessible(true)
+    val id = f.get(h).asInstanceOf[String]
+    pser.dumps(id)
+  }

-  val asJavaDStream = JavaDStream.fromDStream(this)
+  def deserialize(bytes: Array[Byte]): PythonRDDFunction = {
+    pser.loads(bytes)
+  }
 }

 /**
  * Helper functions
  */
 private[python] object PythonDStream {

+  // A serializer in Python, used to serialize PythonRDDFunction
+  var serializer: RDDFunctionSerializer = _
+
+  // Register a serializer from Python; should be called during initialization
+  def registerSerializer(ser: PythonRDDFunctionSerializer) = {
+    serializer = new RDDFunctionSerializer(ser)
+  }
+
   // convert Option[RDD[_]] to JavaRDD, handle null gracefully
   def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = {
     if (rdd.isDefined) {
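
writeObject/readObject are the standard Java hooks for substituting a field's serialized form: pfunc is a py4j proxy that cannot survive Java serialization, so it is written out as bytes produced by the registered Python serializer and re-resolved on read. The same shape in Python, using __getstate__/__setstate__ with a hypothetical REGISTRY standing in for the py4j gateway pool:

import pickle

REGISTRY = {"add_one": lambda x: x + 1}    # stand-in for the py4j pool

class Wrapper(object):
    def __init__(self, key):
        self.key = key
        self.func = REGISTRY[key]          # live object, not picklable

    def __getstate__(self):                # like writeObject: emit an id
        return self.key

    def __setstate__(self, key):           # like readObject: re-resolve
        self.key = key
        self.func = REGISTRY[key]

w = pickle.loads(pickle.dumps(Wrapper("add_one")))
assert w.func(1) == 2
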
@@ -123,14 +156,30 @@
   }
 }

+/**
+ * Base class for PythonDStream with some common methods
+ */
+private[python]
+abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonRDDFunction)
+  extends DStream[Array[Byte]] (parent.ssc) {
+
+  val func = new RDDFunction(pfunc)
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
 /**
  * Transformed DStream in Python.
  *
  * If `reuse` is true and the result of the `func` is a PythonRDD, then it will cache it
  * as a template for future use, which can reduce the number of Python callbacks.
  */
 private[python]
-class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction,
+class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonRDDFunction,
   var reuse: Boolean = false)
   extends PythonDStream(parent, pfunc) {


@@ -170,7 +219,7 @@ class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction,
  */
 private[python]
 class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
-    pfunc: PythonRDDFunction)
+    @transient pfunc: PythonRDDFunction)
   extends DStream[Array[Byte]] (parent.ssc) {

   val func = new RDDFunction(pfunc)

@@ -190,7 +239,7 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
  * similar to StateDStream
  */
 private[python]
-class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunction)
+class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonRDDFunction)
   extends PythonDStream(parent, reduceFunc) {

   super.persist(StorageLevel.MEMORY_ONLY)

@@ -212,8 +261,8 @@ class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunc
  */
 private[python]
 class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
-    preduceFunc: PythonRDDFunction,
-    pinvReduceFunc: PythonRDDFunction,
+    @transient preduceFunc: PythonRDDFunction,
+    @transient pinvReduceFunc: PythonRDDFunction,
     _windowDuration: Duration,
     _slideDuration: Duration
   ) extends PythonStateDStream(parent, preduceFunc) {
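
With every PythonRDDFunction parameter marked @transient and RDDFunction handling its own serialization, the whole Python DStream graph, stateful and windowed streams included, becomes serializable. A hedged sketch of the kind of stateful job this work is building toward (host, port, and checkpoint path are placeholders; ssc.checkpoint is the API this change targets, not something added in this commit):

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext(appName="checkpoint-demo")
ssc = StreamingContext(sc, 1)                 # 1-second batches
ssc.checkpoint("/tmp/streaming-checkpoint")   # required for stateful ops

def update(new_values, last_state):
    return sum(new_values) + (last_state or 0)

counts = (ssc.socketTextStream("localhost", 9999)
             .flatMap(lambda line: line.split())
             .map(lambda word: (word, 1))
             .updateStateByKey(update))       # backed by PythonStateDStream
counts.pprint()

ssc.start()
ssc.awaitTermination()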
