Skip to content

Commit 16b77bf

Browse files
committed
Fixed pruning predication conjunctions and disjunctions
1 parent 16195c5 commit 16b77bf

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,12 @@ private[sql] case class InMemoryColumnarTableScan(
113113
import org.apache.spark.sql.catalyst.expressions._
114114

115115
val buildFilter: PartialFunction[Expression, Expression] = {
116-
case And(lhs: Expression, rhs: Expression) =>
116+
case And(lhs: Expression, rhs: Expression)
117+
if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
117118
buildFilter(lhs) && buildFilter(rhs)
118119

119-
case Or(lhs: Expression, rhs: Expression) =>
120+
case Or(lhs: Expression, rhs: Expression)
121+
if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
120122
buildFilter(lhs) || buildFilter(rhs)
121123

122124
case EqualTo(a: AttributeReference, l: Literal) =>

sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,27 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
4646
uncacheTable("intData")
4747
}
4848

49+
// Comparisons
4950
checkBatchPruning("i = 1", Seq(1), 1, 1)
5051
checkBatchPruning("1 = i", Seq(1), 1, 1)
51-
5252
checkBatchPruning("i < 12", 1 to 11, 1, 2)
5353
checkBatchPruning("i <= 11", 1 to 11, 1, 2)
5454
checkBatchPruning("i > 88", 89 to 100, 1, 2)
5555
checkBatchPruning("i >= 89", 89 to 100, 1, 2)
56-
5756
checkBatchPruning("12 > i", 1 to 11, 1, 2)
5857
checkBatchPruning("11 >= i", 1 to 11, 1, 2)
5958
checkBatchPruning("88 < i", 89 to 100, 1, 2)
6059
checkBatchPruning("89 <= i", 89 to 100, 1, 2)
6160

61+
// Conjunction and disjunction
6262
checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
6363
checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
6464
checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
6565

66+
// With unsupported predicate
67+
checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
68+
checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10)
69+
6670
def checkBatchPruning(
6771
filter: String,
6872
expectedQueryResult: Seq[Int],

0 commit comments

Comments
 (0)