
Commit 6cfe341

zsxwing authored and davies committed
[SPARK-12511] [PYSPARK] [STREAMING] Make sure PythonDStream.registerSerializer is called only once
There is an issue where Py4J's PythonProxyHandler.finalize blocks forever (py4j/py4j#184). Py4J creates a PythonProxyHandler in Java for "transformer_serializer" when "registerSerializer" is called. If we call "registerSerializer" twice, the second PythonProxyHandler overrides the first one; the first one is then GCed and triggers "PythonProxyHandler.finalize". To avoid that, we should not call "registerSerializer" more than once, so that the "PythonProxyHandler" on the Java side is never GCed.

Author: Shixiong Zhu <[email protected]>

Closes #10514 from zsxwing/SPARK-12511.
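To make the shape of the fix concrete, here is a minimal, self-contained sketch of the register-once guard that _ensure_initialized now applies. The names below (FakeGateway, Serializer, Context, ensure_initialized) are stand-ins invented for this example; in the real patch the cached object is a TransformFunctionSerializer and registration goes through the py4j gateway via PythonDStream.registerSerializer.

class FakeGateway(object):
    """Stand-in for the py4j gateway; it only records registrations.

    In Spark, registering hands the Python object to Java, which wraps it
    in a PythonProxyHandler. Registering twice orphans the first proxy,
    and its finalize() can then block forever (py4j/py4j#184).
    """
    def __init__(self):
        self.registered = []

    def registerSerializer(self, serializer):
        self.registered.append(serializer)


class Serializer(object):
    def init(self, gateway):
        # Re-initializable state lives in init() instead of __init__ so the
        # same long-lived instance can be re-pointed at a new gateway when
        # recovering from a checkpoint, without creating a second Java proxy.
        self.gateway = gateway


class Context(object):
    _serializer = None  # class-level cache, like StreamingContext._transformerSerializer

    @classmethod
    def ensure_initialized(cls, gateway):
        if cls._serializer is None:
            serializer = Serializer()
            serializer.init(gateway)
            gateway.registerSerializer(serializer)  # happens at most once per process
            cls._serializer = serializer
        else:
            # Reuse the already-registered instance; only refresh its state.
            cls._serializer.init(gateway)


gw = FakeGateway()
Context.ensure_initialized(gw)
Context.ensure_initialized(gw)
assert len(gw.registered) == 1  # the Java-side proxy is created exactly once

Splitting the re-initializable state out of __init__ into init() (the util.py change below) is what makes the cached instance reusable across checkpoint recoveries.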
1 parent c26d174 commit 6cfe341

File tree: 3 files changed, +38 -10 lines changed

python/pyspark/streaming/context.py (25 additions, 8 deletions)
@@ -98,8 +98,28 @@ def _ensure_initialized(cls):
 
         # register serializer for TransformFunction
         # it happens before creating SparkContext when loading from checkpointing
-        cls._transformerSerializer = TransformFunctionSerializer(
-            SparkContext._active_spark_context, CloudPickleSerializer(), gw)
+        if cls._transformerSerializer is None:
+            transformer_serializer = TransformFunctionSerializer()
+            transformer_serializer.init(
+                SparkContext._active_spark_context, CloudPickleSerializer(), gw)
+            # SPARK-12511 streaming driver with checkpointing unable to finalize leading to OOM
+            # There is an issue that Py4J's PythonProxyHandler.finalize blocks forever.
+            # (https://github.com/bartdag/py4j/pull/184)
+            #
+            # Py4j will create a PythonProxyHandler in Java for "transformer_serializer" when
+            # calling "registerSerializer". If we call "registerSerializer" twice, the second
+            # PythonProxyHandler will override the first one, then the first one will be GCed and
+            # trigger "PythonProxyHandler.finalize". To avoid that, we should not call
+            # "registerSerializer" more than once, so that "PythonProxyHandler" in Java side won't
+            # be GCed.
+            #
+            # TODO Once Py4J fixes this issue, we should upgrade Py4j to the latest version.
+            transformer_serializer.gateway.jvm.PythonDStream.registerSerializer(
+                transformer_serializer)
+            cls._transformerSerializer = transformer_serializer
+        else:
+            cls._transformerSerializer.init(
+                SparkContext._active_spark_context, CloudPickleSerializer(), gw)
 
     @classmethod
     def getOrCreate(cls, checkpointPath, setupFunc):
@@ -116,16 +136,13 @@ def getOrCreate(cls, checkpointPath, setupFunc):
         gw = SparkContext._gateway
 
         # Check whether valid checkpoint information exists in the given path
-        if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty():
+        ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath)
+        if ssc_option.isEmpty():
             ssc = setupFunc()
             ssc.checkpoint(checkpointPath)
             return ssc
 
-        try:
-            jssc = gw.jvm.JavaStreamingContext(checkpointPath)
-        except Exception:
-            print("failed to load StreamingContext from checkpoint", file=sys.stderr)
-            raise
+        jssc = gw.jvm.JavaStreamingContext(ssc_option.get())
 
         # If there is already an active instance of Python SparkContext use it, or create a new one
         if not SparkContext._active_spark_context:

python/pyspark/streaming/util.py (1 addition, 2 deletions)
@@ -89,11 +89,10 @@ class TransformFunctionSerializer(object):
     it uses this class to invoke Python, which returns the serialized function
     as a byte array.
     """
-    def __init__(self, ctx, serializer, gateway=None):
+    def init(self, ctx, serializer, gateway=None):
         self.ctx = ctx
         self.serializer = serializer
         self.gateway = gateway or self.ctx._gateway
-        self.gateway.jvm.PythonDStream.registerSerializer(self)
         self.failure = None
 
     def dumps(self, id):

streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala (12 additions, 0 deletions)
@@ -902,3 +902,15 @@ object StreamingContext extends Logging {
     result
   }
 }
+
+private class StreamingContextPythonHelper {
+
+  /**
+   * This is a private method only for Python to implement `getOrCreate`.
+   */
+  def tryRecoverFromCheckpoint(checkpointPath: String): Option[StreamingContext] = {
+    val checkpointOption = CheckpointReader.read(
+      checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, false)
+    checkpointOption.map(new StreamingContext(null, _, null))
+  }
+}
