Skip to content

Commit 3249528

Browse files
ueshinrxin
authored andcommitted
[SPARK-2196] [SQL] Fix nullability of CaseWhen.
`CaseWhen` should use `branches.length` to check if `elseValue` is provided or not. Author: Takuya UESHIN <[email protected]> Closes #1133 from ueshin/issues/SPARK-2196 and squashes the following commits: 510f12d [Takuya UESHIN] Add some tests. dc25e8d [Takuya UESHIN] Fix nullable of CaseWhen to be nullable if the elseValue is nullable. 4f049cc [Takuya UESHIN] Fix nullability of CaseWhen.
1 parent f46e02f commit 3249528

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
233233
branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
234234
@transient private[this] lazy val values =
235235
branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
236+
@transient private[this] lazy val elseValue =
237+
if (branches.length % 2 == 0) None else Option(branches.last)
236238

237239
override def nullable = {
238240
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
239-
values.exists(_.nullable) || (values.length % 2 == 0)
241+
values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
240242
}
241243

242244
override lazy val resolved = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,49 @@ class ExpressionEvaluationSuite extends FunSuite {
333333
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
334334
}
335335

336+
test("case when") {
337+
val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c"))
338+
val c1 = 'a.boolean.at(0)
339+
val c2 = 'a.boolean.at(1)
340+
val c3 = 'a.boolean.at(2)
341+
val c4 = 'a.string.at(3)
342+
val c5 = 'a.string.at(4)
343+
val c6 = 'a.string.at(5)
344+
345+
checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row)
346+
checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row)
347+
checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row)
348+
checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row)
349+
checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row)
350+
checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row)
351+
352+
checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row)
353+
checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row)
354+
checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row)
355+
checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row)
356+
357+
assert(CaseWhen(Seq(c2, c4, c6)).nullable === true)
358+
assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true)
359+
assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true)
360+
361+
val c4_notNull = 'a.boolean.notNull.at(3)
362+
val c5_notNull = 'a.boolean.notNull.at(4)
363+
val c6_notNull = 'a.boolean.notNull.at(5)
364+
365+
assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false)
366+
assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true)
367+
assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true)
368+
369+
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false)
370+
assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true)
371+
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true)
372+
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true)
373+
374+
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true)
375+
assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true)
376+
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
377+
}
378+
336379
test("complex type") {
337380
val row = new GenericRow(Array[Any](
338381
"^Ba*n", // 0

0 commit comments

Comments
 (0)