
Commit 6f0da2f

recover from checkpoint
1 parent fa7261b commit 6f0da2f

9 files changed: +136, -38 lines

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 4 deletions

@@ -42,7 +42,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils

 private[spark] class PythonRDD(
-    parent: RDD[_],
+    @transient parent: RDD[_],
     command: Array[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
@@ -61,9 +61,9 @@ private[spark] class PythonRDD(
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
   val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

-  override def getPartitions = parent.partitions
+  override def getPartitions = firstParent.partitions

-  override val partitioner = if (preservePartitoning) parent.partitioner else None
+  override val partitioner = if (preservePartitoning) firstParent.partitioner else None

   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
         dataOut.writeInt(command.length)
         dataOut.write(command)
         // Data values
-        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
         dataOut.flush()
       } catch {

core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](

 private[spark] class ParallelCollectionRDD[T: ClassTag](
     @transient sc: SparkContext,
-    @transient data: Seq[T],
+    data: Seq[T],
     numSlices: Int,
     locationPrefs: Map[Int, Seq[String]])
   extends RDD[T](sc, Nil) {

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 8 additions & 0 deletions

@@ -82,6 +82,14 @@ abstract class RDD[T: ClassTag](
   def this(@transient oneParent: RDD[_]) =
     this(oneParent.context , List(new OneToOneDependency(oneParent)))

+  // setContext after loading from checkpointing
+  private[spark] def setContext(s: SparkContext) = {
+    if (sc != null && sc != s) {
+      throw new SparkException("Context is already set in " + this + ", cannot set it again")
+    }
+    sc = s
+  }
+
   private[spark] def conf = sc.conf
   // =======================================================================
   // Methods that should be implemented by subclasses of RDD

python/pyspark/context.py

Lines changed: 4 additions & 4 deletions

@@ -68,7 +68,7 @@ class SparkContext(object):

     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                  environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
-                 gateway=None):
+                 gateway=None, jsc=None):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -103,14 +103,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf)
+                          conf, jsc)
         except:
            # If an error occurs, clean up in order to allow future SparkContext creation:
            self.stop()
            raise

     def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf):
+                 conf, jsc):
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
         self._batchSize = batchSize  # -1 represents an unlimited batch size
@@ -151,7 +151,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
                 self.environment[varName] = v

         # Create the Java SparkContext through Py4J
-        self._jsc = self._initialize_context(self._conf._jconf)
+        self._jsc = jsc or self._initialize_context(self._conf._jconf)

         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
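
The new jsc parameter is what lets a Python SparkContext wrap an already-running JavaSparkContext instead of launching one, which checkpoint recovery needs. A minimal sketch of that path, assuming `gw` is the live Py4J gateway (SparkContext._gateway) and `jssc` is a JavaStreamingContext already recovered from a checkpoint (this mirrors the getOrCreate code further down):

from pyspark import SparkConf, SparkContext

# Sketch only: `gw` and `jssc` are assumed to exist already; see
# StreamingContext.getOrCreate below for the real call site.
jsc = jssc.sparkContext()                     # JavaSparkContext owned by the recovered JVM context
conf = SparkConf(_jconf=jsc.getConf())        # mirror its configuration on the Python side
sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)  # reuse the JVM context instead of creating one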

python/pyspark/streaming/context.py

Lines changed: 59 additions & 17 deletions

@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import sys

 from py4j.java_collections import ListConverter
 from py4j.java_gateway import java_import

-from pyspark import RDD
+from pyspark import RDD, SparkConf
 from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
 from pyspark.context import SparkContext
 from pyspark.storagelevel import StorageLevel
@@ -75,41 +77,81 @@ class StreamingContext(object):
     respectively. `context.awaitTransformation()` allows the current thread
     to wait for the termination of the context by `stop()` or by an exception.
     """
+    _transformerSerializer = None

-    def __init__(self, sparkContext, duration):
+    def __init__(self, sparkContext, duration=None, jssc=None):
         """
         Create a new StreamingContext.

         @param sparkContext: L{SparkContext} object.
         @param duration: number of seconds.
         """
+
         self._sc = sparkContext
         self._jvm = self._sc._jvm
-        self._start_callback_server()
-        self._jssc = self._initialize_context(self._sc, duration)
+        self._jssc = jssc or self._initialize_context(self._sc, duration)
+
+    def _initialize_context(self, sc, duration):
+        self._ensure_initialized()
+        return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
+
+    def _jduration(self, seconds):
+        """
+        Create Duration object given number of seconds
+        """
+        return self._jvm.Duration(int(seconds * 1000))

-    def _start_callback_server(self):
-        gw = self._sc._gateway
+    @classmethod
+    def _ensure_initialized(cls):
+        SparkContext._ensure_initialized()
+        gw = SparkContext._gateway
+        # start callback server
         # getattr will fallback to JVM
         if "_callback_server" not in gw.__dict__:
             _daemonize_callback_server()
             gw._start_callback_server(gw._python_proxy_port)
-            gw._python_proxy_port = gw._callback_server.port  # update port with real port

-    def _initialize_context(self, sc, duration):
-        java_import(self._jvm, "org.apache.spark.streaming.*")
-        java_import(self._jvm, "org.apache.spark.streaming.api.java.*")
-        java_import(self._jvm, "org.apache.spark.streaming.api.python.*")
+        java_import(gw.jvm, "org.apache.spark.streaming.*")
+        java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
+        java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
         # register serializer for RDDFunction
-        ser = RDDFunctionSerializer(self._sc, CloudPickleSerializer())
-        self._jvm.PythonDStream.registerSerializer(ser)
-        return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
+        # it happens before creating SparkContext when loading from checkpointing
+        cls._transformerSerializer = RDDFunctionSerializer(SparkContext._active_spark_context,
+                                                           CloudPickleSerializer(), gw)
+        gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer)

-    def _jduration(self, seconds):
+    @classmethod
+    def getOrCreate(cls, path, setupFunc):
         """
-        Create Duration object given number of seconds
+        Get the StreamingContext from checkpoint file at `path`, or setup
+        it by `setupFunc`.
+
+        :param path: directory of checkpoint
+        :param setupFunc: a function used to create StreamingContext and
+            setup DStreams.
+        :return: a StreamingContext
         """
-        return self._jvm.Duration(int(seconds * 1000))
+        if not os.path.exists(path) or not os.path.isdir(path) or not os.listdir(path):
+            ssc = setupFunc()
+            ssc.checkpoint(path)
+            return ssc
+
+        cls._ensure_initialized()
+        gw = SparkContext._gateway
+
+        try:
+            jssc = gw.jvm.JavaStreamingContext(path)
+        except Exception:
+            print >>sys.stderr, "failed to load StreamingContext from checkpoint"
+            raise
+
+        jsc = jssc.sparkContext()
+        conf = SparkConf(_jconf=jsc.getConf())
+        sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)
+        # update ctx in serializer
+        SparkContext._active_spark_context = sc
+        cls._transformerSerializer.ctx = sc
+        return StreamingContext(sc, None, jssc)

     @property
     def sparkContext(self):
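
Together, these changes give PySpark Streaming a driver-recovery entry point. A hedged usage sketch of getOrCreate follows; the checkpoint directory, app name, batch interval, and timeout are illustrative, and the import path follows the module changed above:

import tempfile

from pyspark import SparkConf, SparkContext
from pyspark.streaming.context import StreamingContext

checkpoint_dir = tempfile.mkdtemp("streaming_checkpoint")  # illustrative location

def create_context():
    # Called only when no usable checkpoint exists in checkpoint_dir.
    sc = SparkContext(conf=SparkConf().setAppName("checkpoint-demo"))
    ssc = StreamingContext(sc, 1)    # 1 second batch duration
    # ... define the DStream graph here ...
    ssc.checkpoint(checkpoint_dir)   # getOrCreate also calls this on first setup
    return ssc

# The first run builds the context via create_context(); a restart with the same
# checkpoint_dir reconstructs the StreamingContext (and its SparkContext) from
# the checkpoint files instead of calling create_context() again.
ssc = StreamingContext.getOrCreate(checkpoint_dir, create_context)
ssc.start()
ssc.awaitTermination(30)
ssc.stop()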

python/pyspark/streaming/tests.py

Lines changed: 33 additions & 0 deletions

@@ -493,5 +493,38 @@ def func(rdds):
         self.assertEqual([2, 3, 1], self._take(dstream, 3))


+class TestCheckpoint(PySparkStreamingTestCase):
+
+    def setUp(self):
+        pass
+
+    def tearDown(self):
+        pass
+
+    def test_get_or_create(self):
+        result = [0]
+
+        def setup():
+            conf = SparkConf().set("spark.default.parallelism", 1)
+            sc = SparkContext(conf=conf)
+            ssc = StreamingContext(sc, .2)
+            rdd = sc.parallelize(range(10), 1)
+            dstream = ssc.queueStream([rdd], default=rdd)
+            result[0] = self._collect(dstream.countByWindow(1, .2))
+            return ssc
+        tmpd = tempfile.mkdtemp("test_streaming_cps")
+        ssc = StreamingContext.getOrCreate(tmpd, setup)
+        ssc.start()
+        ssc.awaitTermination(4)
+        ssc.stop()
+        expected = [[i * 10 + 10] for i in range(5)] + [[50]] * 5
+        self.assertEqual(expected, result[0][:10])
+
+        ssc = StreamingContext.getOrCreate(tmpd, setup)
+        ssc.start()
+        ssc.awaitTermination(2)
+        ssc.stop()
+
+
 if __name__ == "__main__":
     unittest.main()

python/pyspark/streaming/util.py

Lines changed: 16 additions & 8 deletions

@@ -18,24 +18,31 @@
 from datetime import datetime
 import traceback

-from pyspark.rdd import RDD
+from pyspark import SparkContext, RDD


 class RDDFunction(object):
     """
     This class is for py4j callback.
     """
+    _emptyRDD = None
+
     def __init__(self, ctx, func, *deserializers):
         self.ctx = ctx
         self.func = func
         self.deserializers = deserializers
-        emptyRDD = getattr(self.ctx, "_emptyRDD", None)
-        if emptyRDD is None:
-            self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
-        self.emptyRDD = emptyRDD
+
+    @property
+    def emptyRDD(self):
+        if self._emptyRDD is None and self.ctx:
+            self._emptyRDD = self.ctx.parallelize([]).cache()
+        return self._emptyRDD

     def call(self, milliseconds, jrdds):
         try:
+            if self.ctx is None:
+                self.ctx = SparkContext._active_spark_context
+
             # extend deserializers with the first one
             sers = self.deserializers
             if len(sers) < len(jrdds):
@@ -51,20 +58,21 @@ def call(self, milliseconds, jrdds):
             traceback.print_exc()

     def __repr__(self):
-        return "RDDFunction(%s)" % (str(self.func))
+        return "RDDFunction(%s)" % self.func

     class Java:
         implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']


 class RDDFunctionSerializer(object):
-    def __init__(self, ctx, serializer):
+    def __init__(self, ctx, serializer, gateway=None):
         self.ctx = ctx
         self.serializer = serializer
+        self.gateway = gateway or self.ctx._gateway

     def dumps(self, id):
         try:
-            func = self.ctx._gateway.gateway_property.pool[id]
+            func = self.gateway.gateway_property.pool[id]
             return bytearray(self.serializer.dumps((func.func, func.deserializers)))
         except Exception:
             traceback.print_exc()
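
The recovery trick in util.py is that a callback restored without a live SparkContext simply rebinds to SparkContext._active_spark_context the first time it is invoked. A standalone sketch of that pattern (the class below is illustrative, not part of the patch):

from pyspark import SparkContext

class LazyContextCallback(object):
    """Illustrative only: bind to a SparkContext lazily, as RDDFunction.call does."""

    def __init__(self, func, ctx=None):
        self.ctx = ctx    # may still be None right after checkpoint recovery
        self.func = func

    def __call__(self, *args):
        if self.ctx is None:
            # pick up whichever SparkContext has been (re)created since
            self.ctx = SparkContext._active_spark_context
        return self.func(self.ctx, *args)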

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 4 additions & 4 deletions

@@ -77,7 +77,7 @@ private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction)
 }

 /**
- * Inferface for Python Serializer to serialize PythonRDDFunction
+ * Interface for Python Serializer to serialize PythonRDDFunction
  */
 private[python] trait PythonRDDFunctionSerializer {
   def dumps(id: String): Array[Byte]  //
@@ -91,9 +91,9 @@ private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) {
   def serialize(func: PythonRDDFunction): Array[Byte] = {
     // get the id of PythonRDDFunction in py4j
     val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
-    val f = h.getClass().getDeclaredField("id");
-    f.setAccessible(true);
-    val id = f.get(h).asInstanceOf[String];
+    val f = h.getClass().getDeclaredField("id")
+    f.setAccessible(true)
+    val id = f.get(h).asInstanceOf[String]
     pser.dumps(id)
   }

streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala

Lines changed: 7 additions & 0 deletions

@@ -17,6 +17,7 @@

 package org.apache.spark.streaming.dstream

+import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.UnionRDD
 import scala.collection.mutable.Queue
@@ -32,6 +33,12 @@ class QueueInputDStream[T: ClassTag](
     defaultRDD: RDD[T]
   ) extends InputDStream[T](ssc) {

+  private[streaming] override def setContext(s: StreamingContext) {
+    super.setContext(s)
+    queue.map(_.setContext(s.sparkContext))
+    defaultRDD.setContext(s.sparkContext)
+  }
+
   override def start() { }

   override def stop() { }
