     "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
     "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
     "ShortType", "ArrayType", "MapType", "StructField", "StructType",
-    "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row",
+    "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", "Dsl",
     "SchemaRDD"]
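Note: with `Dsl` added to `__all__`, the helpers it defines become part of the module's public surface. A minimal usage sketch (assuming this file is importable as `pyspark.sql` and a DataFrame `df` with an `age` column already exists):

    from pyspark.sql import Dsl
    # Dsl.max is one of the functions generated from the DSLS table later in this patch
    df.select(Dsl.max(df.age)).collect()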
@@ -2121,6 +2121,8 @@ def sort(self, *cols):
 
         >>> df.sort(df.age.desc()).collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> df.sortBy(df.age.desc()).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         """
         if not cols:
             raise ValueError("should sort by at least one column")
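Note: the new doctest shows `sortBy` producing the same result as `sort`; presumably it is defined as a plain alias of `sort` elsewhere in the file, along the lines of this hypothetical one-liner:

    sortBy = sort  # hypothetical alias; the doctest above implies identical behavior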
@@ -2427,34 +2429,34 @@ def _scalaMethod(name):
     return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
 
 
-def _unary_op(name):
+def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):
         jc = getattr(self._jc, _scalaMethod(name))()
         return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
     return _
 
 
-def _bin_op(name):
+def _bin_op(name, doc="binary operator"):
     """ Create a method for given binary operator
-
-    Keyword arguments:
-    pass_literal_through -- whether to pass literal value directly through to the JVM.
     """
     def _(self, other):
         jc = other._jc if isinstance(other, Column) else other
         njc = getattr(self._jc, _scalaMethod(name))(jc)
         return Column(njc, self.sql_ctx)
+    _.__doc__ = doc
     return _
 
 
-def _reverse_op(name):
+def _reverse_op(name, doc="binary operator"):
     """ Create a method for binary operator (this object is on right side)
     """
     def _(self, other):
         jother = _create_column_from_literal(other)
         jc = getattr(jother, _scalaMethod(name))(self._jc)
         return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
     return _
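Note: the point of threading `doc` through these factories is that every generated operator method now carries a real docstring, so `help()` and documentation tools see it. A self-contained sketch of the same closure pattern (all names here are illustrative, not from the patch):

    def _make_op(method, doc="binary operator"):
        def _(self, other):
            return getattr(self.value, method)(other.value)  # dispatch by attribute name
        _.__doc__ = doc  # attach the description to the generated method
        return _

    class Num(object):
        def __init__(self, value):
            self.value = value
        __add__ = _make_op("__add__", "Return the element-wise sum.")

    print(Num.__add__.__doc__)  # prints: Return the element-wise sum.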
@@ -2491,8 +2493,6 @@ def __init__(self, jc, sql_ctx=None):
     __rdiv__ = _reverse_op("/")
     __rmod__ = _reverse_op("%")
     __abs__ = _unary_op("abs")
-    abs = _unary_op("abs")
-    sqrt = _unary_op("sqrt")
 
     # logical operators
     __eq__ = _bin_op("===")
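Note: dropping the named `abs` and `sqrt` attributes removes duplication; `abs(col)` keeps working through `__abs__`, and the function-style forms move to the `Dsl` class below. A hedged sketch (assuming a numeric `age` column):

    abs(df.age - 3)   # still supported via Column.__abs__
    Dsl.sqrt(df.age)  # sqrt now lives on Dsl, generated from the DSLS table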
@@ -2501,36 +2501,25 @@ def __init__(self, jc, sql_ctx=None):
     __le__ = _bin_op("<=")
     __ge__ = _bin_op(">=")
     __gt__ = _bin_op(">")
-    # `and`, `or`, `not` cannot be overloaded in Python
-    And = _bin_op('&&')
-    Or = _bin_op('||')
-    Not = _unary_op('unary_!')
-
-    # bitwise operators
-    __and__ = _bin_op("&")
-    __or__ = _bin_op("|")
-    __invert__ = _unary_op("unary_~")
-    __xor__ = _bin_op("^")
-    # __lshift__ = _bin_op("<<")
-    # __rshift__ = _bin_op(">>")
-    __rand__ = _bin_op("&")
-    __ror__ = _bin_op("|")
-    __rxor__ = _bin_op("^")
-    # __rlshift__ = _reverse_op("<<")
-    # __rrshift__ = _reverse_op(">>")
+
+    # `and`, `or`, `not` cannot be overloaded in Python,
+    # so use bitwise operators as boolean operators
+    __and__ = _bin_op('&&')
+    __or__ = _bin_op('||')
+    __invert__ = _unary_op('unary_!')
+    __rand__ = _bin_op("&&")
+    __ror__ = _bin_op("||")
 
     # container operators
     __contains__ = _bin_op("contains")
     __getitem__ = _bin_op("getItem")
-    # __getattr__ = _bin_op("getField")
+    getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
 
     # string methods
     rlike = _bin_op("rlike")
     like = _bin_op("like")
     startswith = _bin_op("startsWith")
     endswith = _bin_op("endsWith")
-    upper = _unary_op("upper")
-    lower = _unary_op("lower")
 
     def substr(self, startPos, length):
         """
@@ -2558,12 +2547,20 @@ def substr(self, startPos, length):
     asc = _unary_op("asc")
     desc = _unary_op("desc")
 
-    isNull = _unary_op("isNull")
-    isNotNull = _unary_op("isNotNull")
+    isNull = _unary_op("isNull", "True if the current expression is null.")
+    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
     # `as` is keyword
-    def As(self, alias):
+    def alias(self, alias):
+        """Return an alias for this column
+
+        >>> df.age.As("age2").collect()
+        [Row(age2=2), Row(age2=5)]
+        >>> df.age.alias("age2").collect()
+        [Row(age2=2), Row(age2=5)]
+        """
         return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
+    As = alias
 
     def cast(self, dataType):
         """ Convert the column into type `dataType`
@@ -2580,27 +2577,44 @@ def cast(self, dataType):
         return Column(self._jc.cast(jdt), self.sql_ctx)
 
 
-def _aggregate_func(name):
+def _aggregate_func(name, doc=""):
     """ Create a function for aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
         jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
         return Column(jc)
-
+    _.__name__ = name
+    _.__doc__ = doc
     return staticmethod(_)
 
 
 class Dsl(object):
     """
     A collection of builtin aggregators
     """
-    AGGS = [
-        'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
-        'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
-    ]
-    for _name in AGGS:
-        locals()[_name] = _aggregate_func(_name)
-    del _name
+    DSLS = {
+        'lit': 'Creates a [[Column]] of literal value.',
+        'col': 'Returns a [[Column]] based on the given column name.',
+        'column': 'Returns a [[Column]] based on the given column name.',
+        'upper': 'Converts a string expression to upper case.',
+        'lower': 'Converts a string expression to lower case.',
+        'sqrt': 'Computes the square root of the specified float value.',
+        'abs': 'Computes the absolute value.',
+
+        'max': 'Aggregate function: returns the maximum value of the expression in a group.',
+        'min': 'Aggregate function: returns the minimum value of the expression in a group.',
+        'first': 'Aggregate function: returns the first value in a group.',
+        'last': 'Aggregate function: returns the last value in a group.',
+        'count': 'Aggregate function: returns the number of items in a group.',
+        'sum': 'Aggregate function: returns the sum of all values in the expression.',
+        'avg': 'Aggregate function: returns the average of the values in a group.',
+        'mean': 'Aggregate function: returns the average of the values in a group.',
+        'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+    }
+
+    for _name, _doc in DSLS.items():
+        locals()[_name] = _aggregate_func(_name, _doc)
+    del _name, _doc
 
     @staticmethod
     def countDistinct(col, *cols):
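Note: because the `for` loop runs inside the class body, `locals()` there is the class namespace, so each `DSLS` entry becomes a static method whose docstring is the table's description. A self-contained sketch of the same pattern (illustrative names, not from the patch):

    def _make_func(name, doc):
        def _(x):
            return '%s(%s)' % (name, x)  # stand-in for the JVM call
        _.__name__ = name
        _.__doc__ = doc
        return staticmethod(_)

    class Funcs(object):
        TABLE = {'upper': 'Upper-cases a string.'}
        for _name, _doc in TABLE.items():
            locals()[_name] = _make_func(_name, _doc)
        del _name, _doc

    print(Funcs.upper('a'))     # prints: upper(a)
    print(Funcs.upper.__doc__)  # prints: Upper-cases a string.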