@@ -1276,13 +1276,13 @@ def registerFunction(self, name, f, returnType=StringType()):
1276
1276
func = lambda _ , it : imap (lambda x : f (* x ), it )
1277
1277
ser = AutoBatchedSerializer (PickleSerializer ())
1278
1278
command = (func , None , ser , ser )
1279
- pickled_command , broadcast_vars , env , includes = _prepare_for_python_RDD (self ._sc , command )
1279
+ pickled_cmd , bvars , env , includes = _prepare_for_python_RDD (self ._sc , command , self )
1280
1280
self ._ssql_ctx .udf ().registerPython (name ,
1281
- bytearray (pickled_command ),
1281
+ bytearray (pickled_cmd ),
1282
1282
env ,
1283
1283
includes ,
1284
1284
self ._sc .pythonExec ,
1285
- broadcast_vars ,
1285
+ bvars ,
1286
1286
self ._sc ._javaAccumulator ,
1287
1287
returnType .json ())
1288
1288
@@ -2540,6 +2540,7 @@ class UserDefinedFunction(object):
2540
2540
def __init__(self, func, returnType):
    """Wrap the Python callable *func* as a SQL UDF producing *returnType*.

    The Java-side UDF handle is built eagerly in the constructor.
    """
    self.func = func
    self.returnType = returnType
    # NOTE: must be initialized before _create_judf() runs — the command
    # pickling path may stash a Broadcast of the serialized function here
    # (it is unpersisted again in __del__).
    self._broadcast = None
    self._judf = self._create_judf()
2544
2545
2545
2546
def _create_judf (self ):
@@ -2548,13 +2549,18 @@ def _create_judf(self):
2548
2549
ser = AutoBatchedSerializer (PickleSerializer ())
2549
2550
command = (func , None , ser , ser )
2550
2551
sc = SparkContext ._active_spark_context
2551
- pickled_command , broadcast_vars , env , includes = _prepare_for_python_RDD (sc , command )
2552
+ pickled_command , broadcast_vars , env , includes = _prepare_for_python_RDD (sc , command , self )
2552
2553
ssql_ctx = sc ._jvm .SQLContext (sc ._jsc .sc ())
2553
2554
jdt = ssql_ctx .parseDataType (self .returnType .json ())
2554
2555
judf = sc ._jvm .Dsl .pythonUDF (f .__name__ , bytearray (pickled_command ), env , includes ,
2555
2556
sc .pythonExec , broadcast_vars , sc ._javaAccumulator , jdt )
2556
2557
return judf
2557
2558
2559
+ def __del__ (self ):
2560
+ if self ._broadcast is not None :
2561
+ self ._broadcast .unpersist ()
2562
+ self ._broadcast = None
2563
+
2558
2564
def __call__ (self , * cols ):
2559
2565
sc = SparkContext ._active_spark_context
2560
2566
jcols = ListConverter ().convert ([_to_java_column (c ) for c in cols ],
0 commit comments