Skip to content

Commit 97dee31

Browse files
committed
[SPARK-7321][SQL] Add Column expression for conditional statements (when/otherwise)
This builds on apache#5932 and should close apache#5932 as well. As an example: ```python df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect() ``` Author: Reynold Xin <[email protected]> Author: kaka1992 <[email protected]> Closes apache#6072 from rxin/when-expr and squashes the following commits: 8f49201 [Reynold Xin] Throw exception if otherwise is applied twice. 0455eda [Reynold Xin] Reset run-tests. bfb9d9f [Reynold Xin] Updated documentation and test cases. 762f6a5 [Reynold Xin] Merge pull request apache#5932 from kaka1992/IFCASE 95724c6 [kaka1992] Update 8218d0a [kaka1992] Update 801009e [kaka1992] Update 76d6346 [kaka1992] [SPARK-7321][SQL] Add Column expression for conditional statements (if, case)
1 parent 8fd5535 commit 97dee31

File tree

6 files changed

+163
-2
lines changed

6 files changed

+163
-2
lines changed

python/pyspark/sql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
Aggregation methods, returned by :func:`DataFrame.groupBy`.
3333
- L{DataFrameNaFunctions}
3434
Methods for handling missing data (null values).
35+
- L{DataFrameStatFunctions}
36+
Methods for statistics functionality.
3537
- L{functions}
3638
List of built-in functions available for :class:`DataFrame`.
3739
- L{types}

python/pyspark/sql/dataframe.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,6 +1546,37 @@ def between(self, lowerBound, upperBound):
15461546
"""
15471547
return (self >= lowerBound) & (self <= upperBound)
15481548

1549+
@ignore_unicode_prefix
def when(self, condition, value):
    """Evaluates a list of conditions and returns one of multiple possible result expressions.
    If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

    See :func:`pyspark.sql.functions.when` for example usage.

    :param condition: a boolean :class:`Column` expression.
    :param value: a literal value, or a :class:`Column` expression.

    """
    # Only Column conditions are meaningful; fail fast on anything else.
    if not isinstance(condition, Column):
        raise TypeError("condition should be a Column")
    sc = SparkContext._active_spark_context
    # Unwrap a Column value to its Java counterpart; literals pass through.
    jvalue = value._jc if isinstance(value, Column) else value
    return Column(sc._jvm.functions.when(condition._jc, jvalue))
1566+
1567+
@ignore_unicode_prefix
def otherwise(self, value):
    """Evaluates a list of conditions and returns one of multiple possible result expressions.
    If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

    See :func:`pyspark.sql.functions.when` for example usage.

    :param value: a literal value, or a :class:`Column` expression.
    """
    # Unwrap a Column to its underlying Java column; literals pass through as-is.
    v = value._jc if isinstance(value, Column) else value
    # BUG FIX: the original computed `v` but then passed the raw `value` to the
    # JVM, so a Column-valued argument was handed over as an opaque Python
    # object instead of its Java column. Pass the unwrapped `v`.
    jc = self._jc.otherwise(v)
    return Column(jc)
1579+
15491580
def __repr__(self):
15501581
return 'Column<%s>' % self._jc.toString().encode('utf8')
15511582

python/pyspark/sql/functions.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@
3232

3333
__all__ = [
3434
'approxCountDistinct',
35+
'coalesce',
3536
'countDistinct',
3637
'monotonicallyIncreasingId',
3738
'rand',
3839
'randn',
3940
'sparkPartitionId',
40-
'coalesce',
41-
'udf']
41+
'udf',
42+
'when']
4243

4344

4445
def _create_function(name, doc=""):
@@ -291,6 +292,27 @@ def struct(*cols):
291292
return Column(jc)
292293

293294

295+
def when(condition, value):
    """Evaluates a list of conditions and returns one of multiple possible result expressions.
    If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

    :param condition: a boolean :class:`Column` expression.
    :param value: a literal value, or a :class:`Column` expression.

    >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
    [Row(age=3), Row(age=4)]

    >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
    [Row(age=3), Row(age=None)]
    """
    # Reject non-Column conditions up front with a clear error.
    if not isinstance(condition, Column):
        raise TypeError("condition should be a Column")
    sc = SparkContext._active_spark_context
    # Columns are translated to their Java columns; plain literals pass through.
    jvalue = value._jc if isinstance(value, Column) else value
    return Column(sc._jvm.functions.when(condition._jc, jvalue))
314+
315+
294316
class UserDefinedFunction(object):
295317
"""
296318
User defined function in Python

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,67 @@ class Column(protected[sql] val expr: Expression) extends Logging {
327327
*/
328328
def eqNullSafe(other: Any): Column = this <=> other
329329

330+
/**
 * Evaluates a list of conditions and returns one of multiple possible result expressions.
 * If otherwise is not defined at the end, null is returned for unmatched conditions.
 *
 * Can only be called on a [[Column]] previously produced by the `when` function; chained
 * `when` calls append additional (condition, value) branches. Must not be called after
 * `otherwise` has been applied.
 *
 * {{{
 *   // Example: encoding gender string column into integer.
 *
 *   // Scala:
 *   people.select(when(people("gender") === "male", 0)
 *     .when(people("gender") === "female", 1)
 *     .otherwise(2))
 *
 *   // Java:
 *   people.select(when(col("gender").equalTo("male"), 0)
 *     .when(col("gender").equalTo("female"), 1)
 *     .otherwise(2))
 * }}}
 *
 * @group expr_ops
 */
def when(condition: Column, value: Any): Column = this.expr match {
  case CaseWhen(branches: Seq[Expression]) =>
    // An even branch count means no `otherwise` (else value) has been set yet;
    // an odd count means the trailing element is the else value and appending
    // further branches after it would build a malformed CaseWhen.
    if (branches.size % 2 == 0) {
      CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr))
    } else {
      throw new IllegalArgumentException(
        "when() cannot be applied once otherwise() is applied")
    }
  case _ =>
    throw new IllegalArgumentException(
      "when() can only be applied on a Column previously generated by when() function")
}
357+
358+
/**
 * Defines the default (else) result expression for a chain of `when` branches.
 * If otherwise is not defined at the end, null is returned for unmatched conditions.
 *
 * Can only be applied once, on a [[Column]] previously produced by the `when` function.
 *
 * {{{
 *   // Example: encoding gender string column into integer.
 *
 *   // Scala:
 *   people.select(when(people("gender") === "male", 0)
 *     .when(people("gender") === "female", 1)
 *     .otherwise(2))
 *
 *   // Java:
 *   people.select(when(col("gender").equalTo("male"), 0)
 *     .when(col("gender").equalTo("female"), 1)
 *     .otherwise(2))
 * }}}
 *
 * @group expr_ops
 */
def otherwise(value: Any): Column = this.expr match {
  case CaseWhen(branches: Seq[Expression]) =>
    // An even number of branch expressions means no else value is present yet;
    // appending one makes it the CaseWhen's default. An odd count means
    // otherwise() was already applied once.
    if (branches.size % 2 == 0) {
      CaseWhen(branches :+ lit(value).expr)
    } else {
      throw new IllegalArgumentException(
        "otherwise() can only be applied once on a Column previously generated by when()")
    }
  case _ =>
    throw new IllegalArgumentException(
      "otherwise() can only be applied on a Column previously generated by when()")
}
390+
330391
/**
331392
* True if the current column is between the lower bound and upper bound, inclusive.
332393
*

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,30 @@ object functions {
419419
*/
420420
def not(e: Column): Column = !e
421421

422+
/**
 * Evaluates a list of conditions and returns one of multiple possible result expressions.
 * If otherwise is not defined at the end, null is returned for unmatched conditions.
 *
 * {{{
 *   // Example: encoding gender string column into integer.
 *
 *   // Scala:
 *   people.select(when(people("gender") === "male", 0)
 *     .when(people("gender") === "female", 1)
 *     .otherwise(2))
 *
 *   // Java:
 *   people.select(when(col("gender").equalTo("male"), 0)
 *     .when(col("gender").equalTo("female"), 1)
 *     .otherwise(2))
 * }}}
 *
 * @group normal_funcs
 */
def when(condition: Column, value: Any): Column =
  // Seed a CaseWhen with the first (condition, value) branch pair; further
  // branches and the else value are appended via Column.when / Column.otherwise.
  CaseWhen(condition.expr :: lit(value).expr :: Nil)
445+
422446
/**
423447
* Generate a random column with i.i.d. samples from U[0.0, 1.0].
424448
*

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,27 @@ class ColumnExpressionSuite extends QueryTest {
255255
Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
256256
}
257257

258+
test("SPARK-7321 when conditional statements") {
  val df = (1 to 3).map(i => (i, i.toString)).toDF("key", "value")

  // Full chain: two when branches plus an otherwise default.
  checkAnswer(
    df.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)),
    Row(-1) :: Row(-2) :: Row(0) :: Nil)

  // Without the ending otherwise, return null for unmatched conditions.
  // Also test putting a non-literal value in the expression.
  checkAnswer(
    df.select(when($"key" === 1, lit(0) - $"key").when($"key" === 2, -2)),
    Row(-1) :: Row(-2) :: Row(null) :: Nil)

  // Test error handling for invalid expressions.
  intercept[IllegalArgumentException] { $"key".when($"key" === 1, -1) }
  intercept[IllegalArgumentException] { $"key".otherwise(-1) }
  intercept[IllegalArgumentException] { when($"key" === 1, -1).otherwise(-1).otherwise(-1) }
}
278+
258279
test("sqrt") {
259280
checkAnswer(
260281
testData.select(sqrt('key)).orderBy('key.asc),

0 commit comments

Comments
 (0)