Skip to content

Commit 83c92fe

Browse files
author
Davies Liu
committed
address comments
1 parent c052f6f commit 83c92fe

File tree

2 files changed

+59
-75
lines changed

2 files changed

+59
-75
lines changed

python/pyspark/sql.py

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
"StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
6363
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
6464
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
65-
"SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row",
65+
"SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", "Dsl",
6666
"SchemaRDD"]
6767

6868

@@ -2121,6 +2121,8 @@ def sort(self, *cols):
21212121
21222122
>>> df.sort(df.age.desc()).collect()
21232123
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
2124+
>>> df.sortBy(df.age.desc()).collect()
2125+
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
21242126
"""
21252127
if not cols:
21262128
raise ValueError("should sort by at least one column")
@@ -2427,34 +2429,34 @@ def _scalaMethod(name):
24272429
return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
24282430

24292431

2430-
def _unary_op(name):
2432+
def _unary_op(name, doc="unary operator"):
24312433
""" Create a method for given unary operator """
24322434
def _(self):
24332435
jc = getattr(self._jc, _scalaMethod(name))()
24342436
return Column(jc, self.sql_ctx)
2437+
_.__doc__ = doc
24352438
return _
24362439

24372440

2438-
def _bin_op(name):
2441+
def _bin_op(name, doc="binary operator"):
24392442
""" Create a method for given binary operator
2440-
2441-
Keyword arguments:
2442-
pass_literal_through -- whether to pass literal value directly through to the JVM.
24432443
"""
24442444
def _(self, other):
24452445
jc = other._jc if isinstance(other, Column) else other
24462446
njc = getattr(self._jc, _scalaMethod(name))(jc)
24472447
return Column(njc, self.sql_ctx)
2448+
_.__doc__ = doc
24482449
return _
24492450

24502451

2451-
def _reverse_op(name):
2452+
def _reverse_op(name, doc="binary operator"):
24522453
""" Create a method for binary operator (this object is on right side)
24532454
"""
24542455
def _(self, other):
24552456
jother = _create_column_from_literal(other)
24562457
jc = getattr(jother, _scalaMethod(name))(self._jc)
24572458
return Column(jc, self.sql_ctx)
2459+
_.__doc__ = doc
24582460
return _
24592461

24602462

@@ -2491,8 +2493,6 @@ def __init__(self, jc, sql_ctx=None):
24912493
__rdiv__ = _reverse_op("/")
24922494
__rmod__ = _reverse_op("%")
24932495
__abs__ = _unary_op("abs")
2494-
abs = _unary_op("abs")
2495-
sqrt = _unary_op("sqrt")
24962496

24972497
# logistic operators
24982498
__eq__ = _bin_op("===")
@@ -2501,36 +2501,25 @@ def __init__(self, jc, sql_ctx=None):
25012501
__le__ = _bin_op("<=")
25022502
__ge__ = _bin_op(">=")
25032503
__gt__ = _bin_op(">")
2504-
# `and`, `or`, `not` cannot be overloaded in Python
2505-
And = _bin_op('&&')
2506-
Or = _bin_op('||')
2507-
Not = _unary_op('unary_!')
2508-
2509-
# bitwise operators
2510-
__and__ = _bin_op("&")
2511-
__or__ = _bin_op("|")
2512-
__invert__ = _unary_op("unary_~")
2513-
__xor__ = _bin_op("^")
2514-
# __lshift__ = _bin_op("<<")
2515-
# __rshift__ = _bin_op(">>")
2516-
__rand__ = _bin_op("&")
2517-
__ror__ = _bin_op("|")
2518-
__rxor__ = _bin_op("^")
2519-
# __rlshift__ = _reverse_op("<<")
2520-
# __rrshift__ = _reverse_op(">>")
2504+
2505+
# `and`, `or`, `not` cannot be overloaded in Python,
2506+
# so use bitwise operators as boolean operators
2507+
__and__ = _bin_op('&&')
2508+
__or__ = _bin_op('||')
2509+
__invert__ = _unary_op('unary_!')
2510+
__rand__ = _bin_op("&&")
2511+
__ror__ = _bin_op("||")
25212512

25222513
# container operators
25232514
__contains__ = _bin_op("contains")
25242515
__getitem__ = _bin_op("getItem")
2525-
# __getattr__ = _bin_op("getField")
2516+
getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
25262517

25272518
# string methods
25282519
rlike = _bin_op("rlike")
25292520
like = _bin_op("like")
25302521
startswith = _bin_op("startsWith")
25312522
endswith = _bin_op("endsWith")
2532-
upper = _unary_op("upper")
2533-
lower = _unary_op("lower")
25342523

25352524
def substr(self, startPos, length):
25362525
"""
@@ -2558,12 +2547,20 @@ def substr(self, startPos, length):
25582547
asc = _unary_op("asc")
25592548
desc = _unary_op("desc")
25602549

2561-
isNull = _unary_op("isNull")
2562-
isNotNull = _unary_op("isNotNull")
2550+
isNull = _unary_op("isNull", "True if the current expression is null.")
2551+
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
25632552

25642553
# `as` is keyword
2565-
def As(self, alias):
2554+
def alias(self, alias):
2555+
"""Return an alias for this column
2556+
2557+
>>> df.age.As("age2").collect()
2558+
[Row(age2=2), Row(age2=5)]
2559+
>>> df.age.alias("age2").collect()
2560+
[Row(age2=2), Row(age2=5)]
2561+
"""
25662562
return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
2563+
As = alias
25672564

25682565
def cast(self, dataType):
25692566
""" Convert the column into type `dataType`
@@ -2580,27 +2577,44 @@ def cast(self, dataType):
25802577
return Column(self._jc.cast(jdt), self.sql_ctx)
25812578

25822579

2583-
def _aggregate_func(name):
2580+
def _aggregate_func(name, doc=""):
25842581
""" Create a function for aggregator by name"""
25852582
def _(col):
25862583
sc = SparkContext._active_spark_context
25872584
jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
25882585
return Column(jc)
2589-
2586+
_.__name__ = name
2587+
_.__doc__ = doc
25902588
return staticmethod(_)
25912589

25922590

25932591
class Dsl(object):
25942592
"""
25952593
A collections of builtin aggregators
25962594
"""
2597-
AGGS = [
2598-
'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
2599-
'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
2600-
]
2601-
for _name in AGGS:
2602-
locals()[_name] = _aggregate_func(_name)
2603-
del _name
2595+
DSLS = {
2596+
'lit': 'Creates a [[Column]] of literal value.',
2597+
'col': 'Returns a [[Column]] based on the given column name.',
2598+
'column': 'Returns a [[Column]] based on the given column name.',
2599+
'upper': 'Converts a string expression to upper case.',
2600+
'lower': 'Converts a string expression to lower case.',
2601+
'sqrt': 'Computes the square root of the specified float value.',
2602+
'abs': 'Computes the absolute value.',
2603+
2604+
'max': 'Aggregate function: returns the maximum value of the expression in a group.',
2605+
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
2606+
'first': 'Aggregate function: returns the first value in a group.',
2607+
'last': 'Aggregate function: returns the last value in a group.',
2608+
'count': 'Aggregate function: returns the number of items in a group.',
2609+
'sum': 'Aggregate function: returns the sum of all values in the expression.',
2610+
'avg': 'Aggregate function: returns the average of the values in a group.',
2611+
'mean': 'Aggregate function: returns the average of the values in a group.',
2612+
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
2613+
}
2614+
2615+
for _name, _doc in DSLS.items():
2616+
locals()[_name] = _aggregate_func(_name, _doc)
2617+
del _name, _doc
26042618

26052619
@staticmethod
26062620
def countDistinct(col, *cols):

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,6 @@ trait Column extends DataFrame {
115115
*/
116116
def unary_- : Column = exprToColumn(UnaryMinus(expr))
117117

118-
/**
119-
* Bitwise NOT.
120-
* {{{
121-
* // Scala: select the flags column and negate every bit.
122-
* df.select( ~df("flags") )
123-
* }}}
124-
*/
125-
def unary_~ : Column = exprToColumn(BitwiseNot(expr))
126-
127118
/**
128119
* Inversion of boolean expression, i.e. NOT.
129120
* {{
@@ -362,27 +353,6 @@ trait Column extends DataFrame {
362353
*/
363354
def and(other: Column): Column = this && other
364355

365-
/**
366-
* Bitwise AND.
367-
*/
368-
def & (other: Any): Column = constructColumn(other) { o =>
369-
BitwiseAnd(expr, o.expr)
370-
}
371-
372-
/**
373-
* Bitwise OR with an expression.
374-
*/
375-
def | (other: Any): Column = constructColumn(other) { o =>
376-
BitwiseOr(expr, o.expr)
377-
}
378-
379-
/**
380-
* Bitwise XOR with an expression.
381-
*/
382-
def ^ (other: Any): Column = constructColumn(other) { o =>
383-
BitwiseXor(expr, o.expr)
384-
}
385-
386356
/**
387357
* Sum of this expression and another expression.
388358
* {{{
@@ -527,16 +497,16 @@ trait Column extends DataFrame {
527497
* @param startPos expression for the starting position.
528498
* @param len expression for the length of the substring.
529499
*/
530-
def substr(startPos: Column, len: Column): Column = constructColumn(null) {
531-
Substring(expr, startPos.expr, len.expr)
532-
}
500+
def substr(startPos: Column, len: Column): Column =
501+
exprToColumn(Substring(expr, startPos.expr, len.expr), computable = false)
533502

534503
/**
535504
* An expression that returns a substring.
536505
* @param startPos starting position.
537506
* @param len length of the substring.
538507
*/
539-
def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len))
508+
def substr(startPos: Int, len: Int): Column =
509+
exprToColumn(Substring(expr, lit(startPos).expr, lit(len).expr))
540510

541511
def contains(other: Any): Column = constructColumn(other) { o =>
542512
Contains(expr, o.expr)

0 commit comments

Comments
 (0)