Commit 271c85e

Merge pull request apache#41 from palantir/pw/parquetInRewritw
SPARK-17091: ParquetFilters rewrite IN to OR of Eq
2 parents: 4e358e9 + edab4ad
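
Before this change, `ParquetFilters.createFilter` had no case for `sources.In`, so IN predicates fell through to `case _ => None` and were never pushed down to the Parquet reader. Parquet's `FilterApi` (at least in the version Spark depended on here) exposes no IN predicate of its own, so the fix rewrites `a IN (v1, ..., vn)` into the equivalent disjunction `a = v1 OR ... OR a = vn`, which Parquet can already evaluate with its existing equality and OR filtering machinery.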

2 files changed (+45, -17 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
(5 additions, 0 deletions)

@@ -234,6 +234,11 @@ private[parquet] object ParquetFilters {
       case sources.Not(pred) =>
         createFilter(schema, pred).map(FilterApi.not)
 
+      case sources.In(name, values) if dataTypeOf.contains(name) =>
+        values.flatMap { v =>
+          makeEq.lift(dataTypeOf(name)).map(_(name, v))
+        }.reduceLeftOption(FilterApi.or)
+
       case _ => None
     }
   }
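
For illustration only (not part of the commit): a minimal sketch of the predicate the new case builds, assuming parquet-column on the classpath. The column name `b`, the values, and the val names are hypothetical.

import org.apache.parquet.filter2.predicate.FilterApi.intColumn
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}

// b IN (0, 2) becomes eq(b, 0) OR eq(b, 2), folded left to right,
// mirroring the flatMap + reduceLeftOption in the case clause above.
val eqs: Seq[FilterPredicate] =
  Seq(0, 2).map(v => FilterApi.eq(intColumn("b"), Int.box(v)))

// reduceLeftOption yields None for an empty sequence, so an IN list with
// no convertible values pushes no filter instead of throwing.
val pushed: Option[FilterPredicate] = eqs.reduceLeftOption(FilterApi.or)

Note that `makeEq.lift(dataTypeOf(name))` depends only on the column's data type, which is the same for every value of one IN filter, so the conversion either succeeds for all values or for none; the fold never silently drops part of the list.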

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
(40 additions, 17 deletions)

@@ -40,8 +40,8 @@ import org.apache.spark.util.{AccumulatorContext, LongAccumulator}
  * NOTE:
  *
  * 1. `!(a cmp b)` is always transformed to its negated form `a cmp' b` by the
- *    `BooleanSimplification` optimization rule whenever possible. As a result, predicate `!(a < 1)`
- *    results in a `GtEq` filter predicate rather than a `Not`.
+ *    `BooleanSimplification` optimization rule whenever possible. As a result, predicate
+ *    `!(a < 1)` results in a `GtEq` filter predicate rather than a `Not`.
  *
  * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
  *    data type is nullable.
@@ -369,7 +369,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext
 
   test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") {
     import testImplicits._
-    Seq("true", "false").map { vectorized =>
+    Seq("true", "false").foreach { vectorized =>
       withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
         SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true",
         SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
@@ -535,25 +535,48 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext
     import testImplicits._
 
     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
-      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
       withTempPath { dir =>
         val path = s"${dir.getCanonicalPath}/table"
         (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path)
 
-        Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) =>
-          withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
-            val accu = new LongAccumulator
-            accu.register(sparkContext, Some("numRowGroups"))
-
-            val df = spark.read.parquet(path).filter("a < 100")
-            df.foreachPartition(_.foreach(v => accu.add(0)))
-            df.collect
-
-            val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
-            assert(numRowGroups.isDefined)
-            assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
-            AccumulatorContext.remove(accu.id)
+        Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0))
+          .foreach { case (push, func) =>
+            withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
+              val accu = new LongAccumulator
+              accu.register(sparkContext, Some("numRowGroups"))
+
+              val df = spark.read.parquet(path).filter("a < 100")
+              df.foreachPartition(_.foreach(v => accu.add(0)))
+              df.collect
+
+              val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
+              assert(numRowGroups.isDefined)
+              assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
+              AccumulatorContext.remove(accu.id)
+            }
           }
+      }
+    }
+  }
+
+  test("In filters are pushed down") {
+    import testImplicits._
+    withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        withTempPath { dir =>
+          val path = s"${dir.getCanonicalPath}/table1"
+          (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path)
+          val df = spark.read.parquet(path).where("b in (0,2)")
+          assert(stripSparkFilter(df).count == 3)
+          val df1 = spark.read.parquet(path).where("not (b in (1))")
+          assert(stripSparkFilter(df1).count == 3)
+          val df2 = spark.read.parquet(path).where("not (b in (1,3) or a <= 2)")
+          assert(stripSparkFilter(df2).count == 2)
+          val df3 = spark.read.parquet(path).where("not (b in (1,3) and a <= 2)")
+          assert(stripSparkFilter(df3).count == 4)
+          val df4 = spark.read.parquet(path).where("not (a <= 2)")
+          assert(stripSparkFilter(df4).count == 3)
         }
       }
     }
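
For reference, the expected counts in the new test follow from its five-row dataset: for i = 1 to 5, (a, b) = (i.toFloat, i % 3), i.e. (1.0, 1), (2.0, 2), (3.0, 0), (4.0, 1), (5.0, 2).

- `b in (0,2)` keeps i = 2, 3, 5: 3 rows
- `not (b in (1))` keeps every row with b != 1, i.e. i = 2, 3, 5: 3 rows
- `not (b in (1,3) or a <= 2)` keeps rows with b not in {1, 3} and a > 2, i.e. i = 3, 5: 2 rows
- `not (b in (1,3) and a <= 2)` excludes only i = 1: 4 rows
- `not (a <= 2)` keeps i = 3, 4, 5: 3 rows

`stripSparkFilter` (a helper from Spark's SQL test utilities) removes Spark's own Filter operator from the executed plan, so each count reflects only the rows the Parquet reader actually produced. With the vectorized reader disabled, parquet-mr applies the pushed predicate at record level, which is what makes these exact counts a meaningful assertion that the filters were pushed down.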
