Skip to content

Commit c34fc19

Browse files
0x0FFFdavies
authored and committed
[SPARK-9014] [SQL] Allow Python spark API to use built-in exponential operator
This PR addresses [SPARK-9014](https://issues.apache.org/jira/browse/SPARK-9014). Added functionality: the `Column` object in Python now supports the exponential operator `**`. Example: ``` from pyspark.sql import * df = sqlContext.createDataFrame([Row(a=2)]) df.select(3**df.a,df.a**3,df.a**df.a).collect() ``` Outputs: ``` [Row(POWER(3.0, a)=9.0, POWER(a, 3.0)=8.0, POWER(a, a)=4.0)] ``` Author: 0x0FFF <[email protected]>. Closes #8658 from 0x0FFF/SPARK-9014.
1 parent d74c6a1 commit c34fc19

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

python/pyspark/sql/column.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,17 @@ def _(self):
9191
return _
9292

9393

94+
def _bin_func_op(name, reverse=False, doc="binary function"):
95+
def _(self, other):
96+
sc = SparkContext._active_spark_context
97+
fn = getattr(sc._jvm.functions, name)
98+
jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
99+
njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
100+
return Column(njc)
101+
_.__doc__ = doc
102+
return _
103+
104+
94105
def _bin_op(name, doc="binary operator"):
95106
""" Create a method for given binary operator
96107
"""
@@ -151,6 +162,8 @@ def __init__(self, jc):
151162
__rdiv__ = _reverse_op("divide")
152163
__rtruediv__ = _reverse_op("divide")
153164
__rmod__ = _reverse_op("mod")
165+
__pow__ = _bin_func_op("pow")
166+
__rpow__ = _bin_func_op("pow", reverse=True)
154167

155168
# logistic operators
156169
__eq__ = _bin_op("equalTo")

python/pyspark/sql/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def test_column_operators(self):
568568
cs = self.df.value
569569
c = ci == cs
570570
self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
571-
rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
571+
rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
572572
self.assertTrue(all(isinstance(c, Column) for c in rcc))
573573
cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
574574
self.assertTrue(all(isinstance(c, Column) for c in cb))

0 commit comments

Comments
 (0)