Skip to content

Commit 3ce54e1

Browse files
committed
add CaseKeyWhen
1 parent 4f87e95 commit 3ce54e1

File tree

8 files changed

+159
-85
lines changed

8 files changed

+159
-85
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,13 +296,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
296296
| LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
297297
| IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
298298
{ case c ~ t ~ f => If(c, t, f) }
299-
| CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
299+
| CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~
300300
(ELSE ~> expression).? <~ END ^^ {
301301
case casePart ~ altPart ~ elsePart =>
302-
val altExprs = altPart.flatMap { case whenExpr ~ thenExpr =>
303-
Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr)
304-
}
305-
CaseWhen(altExprs ++ elsePart.toList)
302+
val branches = altPart.flatMap { case whenExpr ~ thenExpr =>
303+
Seq(whenExpr, thenExpr)
304+
} ++ elsePart
305+
casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
306306
}
307307
| (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
308308
{ case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -631,31 +631,24 @@ trait HiveTypeCoercion {
631631
import HiveTypeCoercion._
632632

633633
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
634-
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
635-
val valueTypes = branches.sliding(2, 2).map {
636-
case Seq(_, value) => value.dataType
637-
case Seq(elseVal) => elseVal.dataType
638-
}.toSeq
639-
640-
logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")
641-
642-
if (valueTypes.distinct.size > 1) {
643-
val commonType = valueTypes.reduce { (v1, v2) =>
644-
findTightestCommonType(v1, v2)
645-
.getOrElse(sys.error(
646-
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
647-
}
648-
val transformedBranches = branches.sliding(2, 2).map {
649-
case Seq(cond, value) if value.dataType != commonType =>
650-
Seq(cond, Cast(value, commonType))
651-
case Seq(elseVal) if elseVal.dataType != commonType =>
652-
Seq(Cast(elseVal, commonType))
653-
case s => s
654-
}.reduce(_ ++ _)
655-
CaseWhen(transformedBranches)
656-
} else {
657-
// Types match up. Hopefully some other rule fixes whatever is wrong with resolution.
658-
cw
634+
case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
635+
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
636+
val commonType = cw.valueTypes.reduce { (v1, v2) =>
637+
findTightestCommonType(v1, v2).getOrElse(sys.error(
638+
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
639+
}
640+
val transformedBranches = cw.branches.sliding(2, 2).map {
641+
case Seq(when, value) if value.dataType != commonType =>
642+
Seq(when, Cast(value, commonType))
643+
case Seq(elseVal) if elseVal.dataType != commonType =>
644+
Seq(Cast(elseVal, commonType))
645+
case s => s
646+
}.reduce(_ ++ _)
647+
cw match {
648+
case _: CaseWhen =>
649+
CaseWhen(transformedBranches)
650+
case CaseKeyWhen(key, _) =>
651+
CaseKeyWhen(key, transformedBranches)
659652
}
660653
}
661654
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ abstract class Expression extends TreeNode[Expression] {
6464
* Returns true if all the children of this expression have been resolved to a specific schema
6565
* and false if any still contains any unresolved placeholders.
6666
*/
67-
def childrenResolved: Boolean = !children.exists(!_.resolved)
67+
def childrenResolved: Boolean = children.forall(_.resolved)
6868

6969
/**
7070
* Returns a string representation of this expression that does not have developer centric

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 95 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -353,79 +353,134 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
353353
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
354354
}
355355

356+
trait CaseWhenLike extends Expression {
357+
self: Product =>
358+
359+
type EvaluatedType = Any
360+
361+
// Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
362+
// element is the value for the default catch-all case (if provided).
363+
// Hence, `branches` consists of at least two elements, and can have an odd or even length.
364+
def branches: Seq[Expression]
365+
366+
@transient lazy val whenList =
367+
branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
368+
@transient lazy val thenList =
369+
branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
370+
val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
371+
372+
// both then and else val should be considered.
373+
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
374+
def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
375+
376+
override def dataType: DataType = {
377+
if (!resolved) {
378+
throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
379+
}
380+
valueTypes.head
381+
}
382+
383+
override def nullable: Boolean = {
384+
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
385+
thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
386+
}
387+
}
388+
356389
// scalastyle:off
357390
/**
358391
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
359392
* Refer to this link for the corresponding semantics:
360393
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
361-
*
362-
* The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
363-
* translated to this form at parsing time. Namely, such a statement gets translated to
364-
* "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
365-
*
366-
* Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
367-
* element is the value for the default catch-all case (if provided). Hence, `branches` consists of
368-
* at least two elements, and can have an odd or even length.
369394
*/
370395
// scalastyle:on
371-
case class CaseWhen(branches: Seq[Expression]) extends Expression {
372-
type EvaluatedType = Any
396+
case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
397+
398+
// Use private[this] Array to speed up evaluation.
399+
@transient private[this] lazy val branchesArr = branches.toArray
373400

374401
override def children: Seq[Expression] = branches
375402

376-
override def dataType: DataType = {
377-
if (!resolved) {
378-
throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
403+
override lazy val resolved: Boolean =
404+
childrenResolved &&
405+
whenList.forall(_.dataType == BooleanType) &&
406+
valueTypesEqual
407+
408+
/** Written in imperative fashion for performance considerations. */
409+
override def eval(input: Row): Any = {
410+
val len = branchesArr.length
411+
var i = 0
412+
// If all branches fail and an elseVal is not provided, the whole statement
413+
// defaults to null, according to Hive's semantics.
414+
while (i < len - 1) {
415+
if (branchesArr(i).eval(input) == true) {
416+
return branchesArr(i + 1).eval(input)
417+
}
418+
i += 2
419+
}
420+
var res: Any = null
421+
if (i == len - 1) {
422+
res = branchesArr(i).eval(input)
379423
}
380-
branches(1).dataType
424+
return res
381425
}
382426

427+
override def toString: String = {
428+
"CASE" + branches.sliding(2, 2).map {
429+
case Seq(cond, value) => s" WHEN $cond THEN $value"
430+
case Seq(elseValue) => s" ELSE $elseValue"
431+
}.mkString
432+
}
433+
}
434+
435+
// scalastyle:off
436+
/**
437+
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
438+
* Refer to this link for the corresponding semantics:
439+
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
440+
*/
441+
// scalastyle:on
442+
case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike {
443+
444+
// Use private[this] Array to speed up evaluation.
383445
@transient private[this] lazy val branchesArr = branches.toArray
384-
@transient private[this] lazy val predicates =
385-
branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
386-
@transient private[this] lazy val values =
387-
branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
388-
@transient private[this] lazy val elseValue =
389-
if (branches.length % 2 == 0) None else Option(branches.last)
390446

391-
override def nullable: Boolean = {
392-
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
393-
values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
394-
}
447+
override def children: Seq[Expression] = key +: branches
395448

396-
override lazy val resolved: Boolean = {
397-
if (!childrenResolved) {
398-
false
399-
} else {
400-
val allCondBooleans = predicates.forall(_.dataType == BooleanType)
401-
// both then and else val should be considered.
402-
val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1
403-
allCondBooleans && dataTypesEqual
404-
}
405-
}
449+
override lazy val resolved: Boolean =
450+
childrenResolved && valueTypesEqual
406451

407452
/** Written in imperative fashion for performance considerations. */
408453
override def eval(input: Row): Any = {
454+
val evaluatedKey = key.eval(input)
409455
val len = branchesArr.length
410456
var i = 0
411457
// If all branches fail and an elseVal is not provided, the whole statement
412458
// defaults to null, according to Hive's semantics.
413-
var res: Any = null
414459
while (i < len - 1) {
415-
if (branchesArr(i).eval(input) == true) {
416-
res = branchesArr(i + 1).eval(input)
417-
return res
460+
if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
461+
return branchesArr(i + 1).eval(input)
418462
}
419463
i += 2
420464
}
465+
var res: Any = null
421466
if (i == len - 1) {
422467
res = branchesArr(i).eval(input)
423468
}
424-
res
469+
return res
470+
}
471+
472+
private def equalNullSafe(l: Any, r: Any) = {
473+
if (l == null && r == null) {
474+
true
475+
} else if (l == null || r == null) {
476+
false
477+
} else {
478+
l == r
479+
}
425480
}
426481

427482
override def toString: String = {
428-
"CASE" + branches.sliding(2, 2).map {
483+
s"CASE $key" + branches.sliding(2, 2).map {
429484
case Seq(cond, value) => s" WHEN $cond THEN $value"
430485
case Seq(elseValue) => s" ELSE $elseValue"
431486
}.mkString

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,32 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
850850
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
851851
}
852852

853+
test("case key when") {
854+
val row = create_row(null, 1, 2, "a", "b", "c")
855+
val c1 = 'a.int.at(0)
856+
val c2 = 'a.int.at(1)
857+
val c3 = 'a.int.at(2)
858+
val c4 = 'a.string.at(3)
859+
val c5 = 'a.string.at(4)
860+
val c6 = 'a.string.at(5)
861+
862+
val literalNull = Literal.create(null, BooleanType)
863+
val literalInt = Literal(1)
864+
val literalString = Literal("a")
865+
866+
checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row)
867+
checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row)
868+
checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
869+
checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
870+
checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
871+
checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
872+
873+
checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
874+
checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
875+
checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
876+
checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
877+
}
878+
853879
test("complex type") {
854880
val row = create_row(
855881
"^Ba*n", // 0

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
357357
* TODO: This can be optimized to use broadcast join when replacementMap is large.
358358
*/
359359
private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
360-
val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) =>
361-
df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr ::
362-
lit(target).cast(col.dataType).expr :: Nil
360+
val keyExpr = df.col(col.name).expr
361+
def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
362+
val branches = replacementMap.flatMap { case (source, target) =>
363+
Seq(buildExpr(source), buildExpr(target))
363364
}.toSeq
364-
new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name)
365+
new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
365366
}
366367

367368
private def convertToDouble(v: Any): Double = v match {

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,16 +1249,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
12491249
case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
12501250
CaseWhen(branches.map(nodeToExpr))
12511251
case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
1252-
val transformed = branches.drop(1).sliding(2, 2).map {
1253-
case Seq(condVal, value) =>
1254-
// FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval().
1255-
// Hence effectful / non-deterministic key expressions are *not* supported at the moment.
1256-
// We should consider adding new Expressions to get around this.
1257-
Seq(EqualTo(nodeToExpr(branches(0)), nodeToExpr(condVal)),
1258-
nodeToExpr(value))
1259-
case Seq(elseVal) => Seq(nodeToExpr(elseVal))
1260-
}.toSeq.reduce(_ ++ _)
1261-
CaseWhen(transformed)
1252+
val keyExpr = nodeToExpr(branches.head)
1253+
CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
12621254

12631255
/* Complex datatype manipulation */
12641256
case Token("[", child :: ordinal :: Nil) =>

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,4 +751,11 @@ class SQLQuerySuite extends QueryTest {
751751
(6, "c", 0, 6)
752752
).map(i => Row(i._1, i._2, i._3, i._4)))
753753
}
754+
755+
test("test case key when") {
756+
(1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t")
757+
checkAnswer(
758+
sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"),
759+
Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil)
760+
}
754761
}

0 commit comments

Comments (0)