Skip to content

[SPARK-14338][SQL] Improve SimplifyConditionals rule to handle null in IF/CASEWHEN #12122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
def nonNullLiteral(e: Expression): Boolean = e match {
private def nonNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => false
case _ => true
}
Expand Down Expand Up @@ -773,17 +773,24 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
* Simplifies conditional expressions (if / case).
*/
object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
private def falseOrNullLiteral(e: Expression): Boolean = e match {
case FalseLiteral => true
case Literal(null, _) => true
case _ => false
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
case If(Literal(null, _), _, falseValue) => falseValue

case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
// Note that these two are handled together here in a single case statement because
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
val newBranches = branches.filter(_._1 != FalseLiteral)
val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
if (newBranches.isEmpty) {
elseValue.getOrElse(Literal.create(null, e.dataType))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.{IntegerType, NullType}


class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
Expand All @@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
private val trueBranch = (TrueLiteral, Literal(5))
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
private val unreachableBranch = (FalseLiteral, Literal(20))
private val nullBranch = (Literal(null, NullType), Literal(30))

test("simplify if") {
assertEquivalent(
Expand All @@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
assertEquivalent(
If(FalseLiteral, Literal(10), Literal(20)),
Literal(20))

assertEquivalent(
If(Literal(null, NullType), Literal(10), Literal(20)),
Literal(20))
}

test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
CaseWhen(normalBranch :: Nil, None))
}

test("remove entire CaseWhen if only the else branch is reachable") {
assertEquivalent(
CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
Literal(30))

assertEquivalent(
Expand All @@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

test("remove entire CaseWhen if the first branch is always true") {
assertEquivalent(
CaseWhen(trueBranch :: normalBranch :: Nil, None),
CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
Literal(5))

// Test branch elimination and simplification in combination
assertEquivalent(
CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
:: Nil, None),
Literal(5))

// Make sure this doesn't trigger if there is a non-foldable branch before the true branch
Expand Down