 from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
     CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
@@ -1274,22 +1274,9 @@ def registerFunction(self, name, f, returnType=StringType()):
         [Row(c0=4)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = self._sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
-            self._sc._gateway._gateway_client)
-        self._sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self._sc.environment,
-                                     self._sc._gateway._gateway_client)
-        includes = ListConverter().convert(self._sc._python_includes,
-                                           self._sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self._sc, command)
         self._ssql_ctx.udf().registerPython(name,
                                             bytearray(pickled_command),
                                             env,
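Note: the inline pickling logic deleted above is not lost; it moves into `_prepare_for_python_RDD` in `pyspark/rdd.py` (now imported at the top of this file) so the SQL and RDD code paths can share it. A rough sketch of what that helper is expected to do, reconstructed from the lines removed here (the actual implementation in rdd.py may differ in detail):

    def _prepare_for_python_RDD(sc, command):
        # serialize the command tuple with cloudpickle
        ser = CloudPickleSerializer()
        pickled_command = ser.dumps(command)
        if len(pickled_command) > (1 << 20):  # ship closures over 1M as a broadcast
            broadcast = sc.broadcast(pickled_command)
            pickled_command = ser.dumps(broadcast)
        # hand broadcast vars, env and includes to the JVM via py4j converters
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in sc._pickled_broadcast_vars],
            sc._gateway._gateway_client)
        sc._pickled_broadcast_vars.clear()
        env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
        includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
        return pickled_command, broadcast_vars, env, includes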
@@ -2187,7 +2174,7 @@ def select(self, *cols):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.select('name', 'age').collect()
         [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
-        >>> df.select(df.name, (df.age + 10).As('age')).collect()
+        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
         if not cols:
@@ -2268,7 +2255,7 @@ def addColumn(self, colName, col):
         >>> df.addColumn('age2', df.age + 2).collect()
         [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
         """
-        return self.select('*', col.As(colName))
+        return self.select('*', col.alias(colName))
 
 
 # Having SchemaRDD for backward compatibility (for docs)
@@ -2509,24 +2496,20 @@ def substr(self, startPos, length):
     isNull = _unary_op("isNull", "True if the current expression is null.")
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
-    # `as` is keyword
     def alias(self, alias):
         """Return a alias for this column
 
-        >>> df.age.As("age2").collect()
-        [Row(age2=2), Row(age2=5)]
         >>> df.age.alias("age2").collect()
         [Row(age2=2), Row(age2=5)]
         """
         return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
-    As = alias
 
     def cast(self, dataType):
         """ Convert the column into type `dataType`
 
-        >>> df.select(df.age.cast("string").As('ages')).collect()
+        >>> df.select(df.age.cast("string").alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
-        >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
         """
         if self.sql_ctx is None:
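With `As = alias` removed, `alias()` is the only remaining spelling on `Column`; any downstream code using the old name needs the mechanical rename already shown in the doctests above, e.g.:

    # before this change
    df.select(df.name, (df.age + 10).As('age')).collect()
    # after this change
    df.select(df.name, (df.age + 10).alias('age')).collect()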
@@ -2560,24 +2543,12 @@ def __init__(self, func, returnType):
         self._judf = self._create_judf()
 
     def _create_judf(self):
-        f = self.func
-        sc = SparkContext._active_spark_context
-        # TODO(davies): refactor
+        f = self.func  # put it in closure `func`
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in sc._pickled_broadcast_vars],
-            sc._gateway._gateway_client)
-        sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
-        includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        sc = SparkContext._active_spark_context
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
         judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes,
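As in `registerFunction` above, `_create_judf` now only builds the `(func, None, ser, ser)` command tuple and lets `_prepare_for_python_RDD` handle pickling, oversized-closure broadcast, and the env/includes bookkeeping. The `func` wrapper itself is unchanged: it ignores the partition index and maps the user's function over each tuple of column values. A small plain-Python illustration of that shape (hypothetical data, no Spark involved):

    from itertools import imap  # Python 2, as used in this file

    f = lambda s: len(s)                             # the user's UDF, as in the Dsl.udf doctest
    func = lambda _, it: imap(lambda x: f(*x), it)   # the wrapper built in _create_judf

    list(func(0, iter([("Alice",), ("Bob",)])))      # -> [5, 3]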
@@ -2625,7 +2596,7 @@ def countDistinct(col, *cols):
         """ Return a new Column for distinct count of (col, *cols)
 
         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
+        >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2640,7 +2611,7 @@ def approxCountDistinct(col, rsd=None):
         """ Return a new Column for approxiate distinct count of (col, *cols)
 
         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
+        >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2655,7 +2626,7 @@ def udf(f, returnType=StringType()):
         """Create a user defined function (UDF)
 
         >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
-        >>> df.select(slen(df.name).As('slen')).collect()
+        >>> df.select(slen(df.name).alias('slen')).collect()
         [Row(slen=5), Row(slen=3)]
         """
         return UserDefinedFunction(f, returnType)