
Commit f0a3121

Author: Davies Liu
Commit message: track life cycle of broadcast
1 parent: f99b2e1

File tree: 2 files changed (+13, -7 lines)


python/pyspark/rdd.py (3 additions, 3 deletions)

@@ -2247,12 +2247,12 @@ def _jrdd(self):

         command = (self.func, profiler, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
-        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self.ctx, command)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-                                             bytearray(pickled_command),
+                                             bytearray(pickled_cmd),
                                              env, includes, self.preservesPartitioning,
                                              self.ctx.pythonExec,
-                                             broadcast_vars, self.ctx._javaAccumulator)
+                                             bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()

         if profiler:
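
Not shown in this diff is what _prepare_for_python_RDD does with its new third argument. Below is a minimal sketch of one plausible reading, assuming the helper broadcasts a large pickled command and records that Broadcast on the caller so the caller can unpersist it later; the CloudPickleSerializer call, the ~1 MB threshold, and the empty env/includes/broadcast_vars are assumptions, not taken from this commit.

    # Sketch only, not the Spark source: let `obj` track the command broadcast.
    from pyspark.serializers import CloudPickleSerializer

    def _prepare_for_python_RDD(sc, command, obj=None):
        ser = CloudPickleSerializer()
        pickled_command = ser.dumps(command)
        if len(pickled_command) > (1 << 20):         # assumed ~1 MB cutoff
            broadcast = sc.broadcast(pickled_command)
            pickled_command = ser.dumps(broadcast)   # workers fetch the command via the broadcast
            if obj is not None:
                obj._broadcast = broadcast           # unpersisted later by the owner (e.g. in __del__)
        env, includes, broadcast_vars = {}, [], []   # real handling elided in this sketch
        return pickled_command, broadcast_vars, env, includes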

python/pyspark/sql.py (10 additions, 4 deletions)

@@ -1276,13 +1276,13 @@ def registerFunction(self, name, f, returnType=StringType()):
         func = lambda _, it: imap(lambda x: f(*x), it)
         ser = AutoBatchedSerializer(PickleSerializer())
         command = (func, None, ser, ser)
-        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self._sc, command)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
         self._ssql_ctx.udf().registerPython(name,
-                                            bytearray(pickled_command),
+                                            bytearray(pickled_cmd),
                                             env,
                                             includes,
                                             self._sc.pythonExec,
-                                            broadcast_vars,
+                                            bvars,
                                             self._sc._javaAccumulator,
                                             returnType.json())

@@ -2540,6 +2540,7 @@ class UserDefinedFunction(object):
     def __init__(self, func, returnType):
         self.func = func
         self.returnType = returnType
+        self._broadcast = None
         self._judf = self._create_judf()

     def _create_judf(self):
@@ -2548,13 +2549,18 @@ def _create_judf(self):
         ser = AutoBatchedSerializer(PickleSerializer())
         command = (func, None, ser, ser)
         sc = SparkContext._active_spark_context
-        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
         judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes,
                                      sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt)
         return judf

+    def __del__(self):
+        if self._broadcast is not None:
+            self._broadcast.unpersist()
+            self._broadcast = None
+
     def __call__(self, *cols):
         sc = SparkContext._active_spark_context
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
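
The __del__ hook above ties the command broadcast's life cycle to the Python object that owns it. As a self-contained illustration of that pattern, here is a sketch that uses only the public SparkContext.broadcast and Broadcast.unpersist APIs; the CommandOwner class, the app name, and the payload are made up for the example.

    # Illustration of the owner-tracks-broadcast pattern, not Spark source code.
    from pyspark import SparkContext

    class CommandOwner(object):
        """Hypothetical owner that holds the broadcast of a pickled command."""
        def __init__(self, sc, payload):
            self._broadcast = sc.broadcast(payload)

        def __del__(self):
            # Drop the executor-side copies once the owner is garbage collected.
            if self._broadcast is not None:
                self._broadcast.unpersist()
                self._broadcast = None

    if __name__ == "__main__":
        sc = SparkContext("local", "broadcast-lifecycle-demo")
        owner = CommandOwner(sc, b"pickled command bytes")
        del owner   # the broadcast is unpersisted along with its owner
        sc.stop()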
