
Commit f99b2e1

Author: Davies Liu
Commit message: address comments
1 parent 462b334 commit f99b2e1

File tree: 2 files changed (+36, -59 lines)


python/pyspark/rdd.py

Lines changed: 20 additions & 14 deletions
@@ -2162,6 +2162,25 @@ def toLocalIterator(self):
                 yield row
 
 
+def _prepare_for_python_RDD(sc, command, obj=None):
+    # the serialized command will be compressed by broadcast
+    ser = CloudPickleSerializer()
+    pickled_command = ser.dumps(command)
+    if len(pickled_command) > (1 << 20):  # 1M
+        broadcast = sc.broadcast(pickled_command)
+        pickled_command = ser.dumps(broadcast)
+        # tracking the life cycle by obj
+        if obj is not None:
+            obj._broadcast = broadcast
+    broadcast_vars = ListConverter().convert(
+        [x._jbroadcast for x in sc._pickled_broadcast_vars],
+        sc._gateway._gateway_client)
+    sc._pickled_broadcast_vars.clear()
+    env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
+    includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+    return pickled_command, broadcast_vars, env, includes
+
+
 class PipelinedRDD(RDD):
 
     """
@@ -2228,20 +2247,7 @@ def _jrdd(self):
 
         command = (self.func, profiler, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
-        # the serialized command will be compressed by broadcast
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            self._broadcast = self.ctx.broadcast(pickled_command)
-            pickled_command = ser.dumps(self._broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
-            self.ctx._gateway._gateway_client)
-        self.ctx._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self.ctx.environment,
-                                     self.ctx._gateway._gateway_client)
-        includes = ListConverter().convert(self.ctx._python_includes,
-                                           self.ctx._gateway._gateway_client)
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self.ctx, command)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
                                              bytearray(pickled_command),
                                              env, includes, self.preservesPartitioning,
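
Note: the new `_prepare_for_python_RDD` helper centralizes the serialization logic previously duplicated in `PipelinedRDD._jrdd`, `registerFunction`, and `UserDefinedFunction._create_judf`: the command is pickled with `CloudPickleSerializer`, and when the payload exceeds 1 MB it is shipped as a broadcast variable, with only the pickled broadcast handle sent inline. Below is a minimal, self-contained sketch of that threshold pattern; plain `pickle` and the stand-in `broadcast_fn` are used purely for illustration and are not part of this commit.

    import pickle

    ONE_MB = 1 << 20  # same threshold as the `(1 << 20)  # 1M` check above

    def prepare_command(command, broadcast_fn):
        """Pickle `command`; if the payload tops 1 MB, hand it to
        `broadcast_fn` and ship only the pickled handle instead."""
        payload = pickle.dumps(command)
        if len(payload) > ONE_MB:
            handle = broadcast_fn(payload)   # in PySpark this would be sc.broadcast(...)
            payload = pickle.dumps(handle)   # workers unpickle the handle, then fetch the data
        return payload

    small = prepare_command(("func", None), broadcast_fn=id)
    big = prepare_command(("func", b"x" * (2 * ONE_MB)), broadcast_fn=id)
    print(len(small), len(big))  # the oversized command is replaced by a tiny handle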

python/pyspark/sql.py

Lines changed: 16 additions & 45 deletions
@@ -51,7 +51,7 @@
 from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
     CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
@@ -1274,22 +1274,9 @@ def registerFunction(self, name, f, returnType=StringType()):
         [Row(c0=4)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = self._sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
-            self._sc._gateway._gateway_client)
-        self._sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self._sc.environment,
-                                     self._sc._gateway._gateway_client)
-        includes = ListConverter().convert(self._sc._python_includes,
-                                           self._sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self._sc, command)
         self._ssql_ctx.udf().registerPython(name,
                                             bytearray(pickled_command),
                                             env,
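
For reference, the public behavior of `registerFunction` is unchanged by this refactor. A hedged usage sketch follows; the UDF name, `sc`, and `sqlCtx` are illustrative rather than part of this commit, and the expected result mirrors the `[Row(c0=4)]` doctest above.

    from pyspark.sql import SQLContext, IntegerType

    sqlCtx = SQLContext(sc)  # assumes an existing SparkContext named `sc`
    sqlCtx.registerFunction("strLen", lambda s: len(s), IntegerType())
    sqlCtx.sql("SELECT strLen('test')").collect()  # -> [Row(c0=4)]
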
@@ -2187,7 +2174,7 @@ def select(self, *cols):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.select('name', 'age').collect()
         [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
-        >>> df.select(df.name, (df.age + 10).As('age')).collect()
+        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
         if not cols:
@@ -2268,7 +2255,7 @@ def addColumn(self, colName, col):
         >>> df.addColumn('age2', df.age + 2).collect()
         [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
         """
-        return self.select('*', col.As(colName))
+        return self.select('*', col.alias(colName))
 
 
 # Having SchemaRDD for backward compatibility (for docs)
@@ -2509,24 +2496,20 @@ def substr(self, startPos, length):
     isNull = _unary_op("isNull", "True if the current expression is null.")
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
-    # `as` is keyword
     def alias(self, alias):
         """Return a alias for this column
 
-        >>> df.age.As("age2").collect()
-        [Row(age2=2), Row(age2=5)]
         >>> df.age.alias("age2").collect()
         [Row(age2=2), Row(age2=5)]
         """
         return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
-    As = alias
 
     def cast(self, dataType):
         """ Convert the column into type `dataType`
 
-        >>> df.select(df.age.cast("string").As('ages')).collect()
+        >>> df.select(df.age.cast("string").alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
-        >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
         """
         if self.sql_ctx is None:
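
Since `Column.As` is removed in this hunk, column aliasing now goes through `alias` only. A short sketch against a hypothetical DataFrame `df` with `name` and `age` columns, mirroring the doctests above:

    df.select(df.name, (df.age + 10).alias('age_plus_10')).collect()
    df.select(df.age.cast("string").alias('age_str')).collect()
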
@@ -2560,24 +2543,12 @@ def __init__(self, func, returnType):
         self._judf = self._create_judf()
 
     def _create_judf(self):
-        f = self.func
-        sc = SparkContext._active_spark_context
-        # TODO(davies): refactor
+        f = self.func  # put it in closure `func`
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in sc._pickled_broadcast_vars],
-            sc._gateway._gateway_client)
-        sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
-        includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        sc = SparkContext._active_spark_context
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
         judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes,
@@ -2625,7 +2596,7 @@ def countDistinct(col, *cols):
         """ Return a new Column for distinct count of (col, *cols)
 
         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
+        >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2640,7 +2611,7 @@ def approxCountDistinct(col, rsd=None):
         """ Return a new Column for approxiate distinct count of (col, *cols)
 
         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
+        >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2655,7 +2626,7 @@ def udf(f, returnType=StringType()):
         """Create a user defined function (UDF)
 
         >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
-        >>> df.select(slen(df.name).As('slen')).collect()
+        >>> df.select(slen(df.name).alias('slen')).collect()
         [Row(slen=5), Row(slen=3)]
         """
         return UserDefinedFunction(f, returnType)
