@@ -353,79 +353,134 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
353
353
override def toString : String = s " if ( $predicate) $trueValue else $falseValue"
354
354
}
355
355
356
/**
 * Behaviour shared by the CASE ... WHEN ... expressions in this file.
 *
 * Implementations supply `branches`; everything else here is derived from it.
 */
trait CaseWhenLike extends Expression {
  self: Product =>

  type EvaluatedType = Any

  // `branches` is read two elements at a time as (condition, value). When its length is odd, the
  // trailing element is the value of the default catch-all case. It is therefore expected to
  // hold at least two elements, with either an odd or an even length.
  def branches: Seq[Expression]

  /** The first element of every complete pair in `branches`. */
  @transient lazy val whenList: Seq[Expression] =
    branches.grouped(2).collect { case Seq(condition, _) => condition }.toSeq

  /** The second element of every complete pair in `branches`. */
  @transient lazy val thenList: Seq[Expression] =
    branches.grouped(2).collect { case Seq(_, result) => result }.toSeq

  /** The trailing catch-all value; present only when `branches` has an odd length. */
  val elseValue: Option[Expression] =
    if (branches.length % 2 != 0) Option(branches.last) else None

  // The else value counts as a result value too, so it is included alongside the then values.
  def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(value => value.dataType)

  /** True when the result values share a single data type (or there are none). */
  def valueTypesEqual: Boolean = valueTypes.distinct.size < 2

  /**
   * The data type of the first result value.
   *
   * Throws `UnresolvedException` while this expression is not `resolved`.
   */
  override def dataType: DataType = {
    if (resolved) {
      valueTypes.head
    } else {
      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
    }
  }

  override def nullable: Boolean = {
    // With no else value the whole statement defaults to null, so it is nullable even if none
    // of the then values is.
    thenList.exists(_.nullable) || (elseValue match {
      case Some(default) => default.nullable
      case None => true
    })
  }
}
388
+
356
389
// scalastyle:off
357
390
/**
358
391
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
359
392
* Refer to this link for the corresponding semantics:
360
393
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
361
- *
362
- * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
363
- * translated to this form at parsing time. Namely, such a statement gets translated to
364
- * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
365
- *
366
- * Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
367
- * element is the value for the default catch-all case (if provided). Hence, `branches` consists of
368
- * at least two elements, and can have an odd or even length.
369
394
*/
370
395
// scalastyle:on
371
- case class CaseWhen (branches : Seq [Expression ]) extends Expression {
372
- type EvaluatedType = Any
396
case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {

  // Array view of `branches`, kept private[this] to speed up evaluation.
  @transient private[this] lazy val branchesArr = branches.toArray

  override def children: Seq[Expression] = branches

  // Resolved once the children are resolved, every WHEN condition is boolean-typed and all
  // result values share one type. The checks run in that order and stop at the first failure.
  override lazy val resolved: Boolean = {
    def conditionsAreBoolean = whenList.forall(condition => condition.dataType == BooleanType)
    childrenResolved && conditionsAreBoolean && valueTypesEqual
  }

  /**
   * Evaluates the conditions in order and returns the value paired with the first one that
   * evaluates to true.
   *
   * Kept as an imperative loop for performance reasons.
   */
  override def eval(input: Row): Any = {
    val n = branchesArr.length
    var idx = 0
    while (idx < n - 1) {
      if (branchesArr(idx).eval(input) == true) {
        return branchesArr(idx + 1).eval(input)
      }
      idx += 2
    }
    // No condition matched. An unpaired trailing element is the else value; without one the
    // whole statement defaults to null, according to Hive's semantics.
    if (idx == n - 1) branchesArr(idx).eval(input) else null
  }

  override def toString: String = {
    val clauses = branches.sliding(2, 2).map {
      case Seq(condition, result) => s" WHEN $condition THEN $result"
      case Seq(default) => s" ELSE $default"
    }
    clauses.mkString("CASE", "", "")
  }
}
434
+
435
+ // scalastyle:off
436
+ /**
437
+ * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
438
+ * Refer to this link for the corresponding semantics:
439
+ * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
440
+ */
441
+ // scalastyle:on
442
+ case class CaseKeyWhen (key : Expression , branches : Seq [Expression ]) extends CaseWhenLike {
443
+
444
+ // Use private[this] Array to speed up evaluation.
383
445
@ transient private [this ] lazy val branchesArr = branches.toArray
384
- @ transient private [this ] lazy val predicates =
385
- branches.sliding(2 , 2 ).collect { case Seq (cond, _) => cond }.toSeq
386
- @ transient private [this ] lazy val values =
387
- branches.sliding(2 , 2 ).collect { case Seq (_, value) => value }.toSeq
388
- @ transient private [this ] lazy val elseValue =
389
- if (branches.length % 2 == 0 ) None else Option (branches.last)
390
446
391
- override def nullable : Boolean = {
392
- // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
393
- values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true ))
394
- }
447
+ override def children : Seq [Expression ] = key +: branches
395
448
396
- override lazy val resolved : Boolean = {
397
- if (! childrenResolved) {
398
- false
399
- } else {
400
- val allCondBooleans = predicates.forall(_.dataType == BooleanType )
401
- // both then and else val should be considered.
402
- val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1
403
- allCondBooleans && dataTypesEqual
404
- }
405
- }
449
+ override lazy val resolved : Boolean =
450
+ childrenResolved && valueTypesEqual
406
451
407
452
/** Written in imperative fashion for performance considerations. */
408
453
override def eval (input : Row ): Any = {
454
+ val evaluatedKey = key.eval(input)
409
455
val len = branchesArr.length
410
456
var i = 0
411
457
// If all branches fail and an elseVal is not provided, the whole statement
412
458
// defaults to null, according to Hive's semantics.
413
- var res : Any = null
414
459
while (i < len - 1 ) {
415
- if (branchesArr(i).eval(input) == true ) {
416
- res = branchesArr(i + 1 ).eval(input)
417
- return res
460
+ if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
461
+ return branchesArr(i + 1 ).eval(input)
418
462
}
419
463
i += 2
420
464
}
465
+ var res : Any = null
421
466
if (i == len - 1 ) {
422
467
res = branchesArr(i).eval(input)
423
468
}
424
- res
469
+ return res
470
+ }
471
+
472
+ private def equalNullSafe (l : Any , r : Any ) = {
473
+ if (l == null && r == null ) {
474
+ true
475
+ } else if (l == null || r == null ) {
476
+ false
477
+ } else {
478
+ l == r
479
+ }
425
480
}
426
481
427
482
override def toString : String = {
428
- " CASE" + branches.sliding(2 , 2 ).map {
483
+ s " CASE $key " + branches.sliding(2 , 2 ).map {
429
484
case Seq (cond, value) => s " WHEN $cond THEN $value"
430
485
case Seq (elseValue) => s " ELSE $elseValue"
431
486
}.mkString
0 commit comments