
Commit a1a17b5

Fix inset for large queries. Disable record level filtering (apache#74)
1 parent 784072a commit a1a17b5

File tree: 5 files changed, +85 -17 lines

pom.xml

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@
     <!-- Version used for internal directory structure -->
     <hive.version.short>1.2.1</hive.version.short>
     <derby.version>10.12.1.1</derby.version>
-    <parquet.version>1.9.0-palantir3</parquet.version>
+    <parquet.version>1.9.0-palantir4</parquet.version>
     <jetty.version>9.2.16.v20160414</jetty.version>
     <javaxservlet.version>3.1.0</javaxservlet.version>
     <chill.version>0.8.0</chill.version>

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java

Lines changed: 8 additions & 1 deletion
@@ -22,6 +22,7 @@
 import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER;
 import static org.apache.parquet.format.converter.ParquetMetadataConverter.range;
 import static org.apache.parquet.hadoop.ParquetFileReader.readFooter;
+import static org.apache.parquet.hadoop.ParquetInputFormat.DICTIONARY_FILTERING_ENABLED;
 import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter;

 import com.google.common.collect.ImmutableList;

@@ -107,8 +108,14 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont
     footer = readFooter(configuration, file, range(split.getStart(), split.getEnd()));
     FilterCompat.Filter filter = getFilter(configuration);
     this.reader = ParquetFileReader.open(configuration, file, footer);
+    List<RowGroupFilter.FilterLevel> filterLevels =
+        ImmutableList.of(RowGroupFilter.FilterLevel.STATISTICS);
+    if (configuration.getBoolean(DICTIONARY_FILTERING_ENABLED, false)) {
+      filterLevels = ImmutableList.of(RowGroupFilter.FilterLevel.STATISTICS,
+          RowGroupFilter.FilterLevel.DICTIONARY);
+    }
     blocks = filterRowGroups(
-        ImmutableList.of(RowGroupFilter.FilterLevel.STATISTICS, RowGroupFilter.FilterLevel.DICTIONARY),
+        filterLevels,
         filter,
         footer.getBlocks(),
         reader);
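
With this change, the vectorized reader applies dictionary-based row-group filtering only when it is explicitly enabled; by default only statistics-based filtering runs. A minimal sketch of opting back in from a Spark session, assuming (as the new test below does via withSQLConf) that session conf entries are copied into the Hadoop configuration this reader sees:

    // Sketch: turn dictionary-based row-group filtering back on for Parquet scans.
    // Assumes the session conf is propagated into the reader's Hadoop configuration.
    import org.apache.parquet.hadoop.ParquetInputFormat

    spark.conf.set(ParquetInputFormat.DICTIONARY_FILTERING_ENABLED, "true")
    spark.read.parquet("/tmp/table1").where("value in (1, 2, 3)").count()  // hypothetical path and query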

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala

Lines changed: 5 additions & 0 deletions
@@ -394,6 +394,11 @@ class ParquetFileFormat
       SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
       int96AsTimestamp)

+    // By default, disable record level filtering.
+    if (hadoopConf.get(ParquetInputFormat.RECORD_FILTERING_ENABLED) == null) {
+      hadoopConf.setBoolean(ParquetInputFormat.RECORD_FILTERING_ENABLED, false)
+    }
+
     // Try to push down filters when filter push-down is enabled.
     val pushed =
       if (sparkSession.sessionState.conf.parquetFilterPushDown) {
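
The scan now defaults to skipping Parquet record-level filtering unless the key is already set, so pushed predicates only prune row groups and pages, and Spark re-evaluates the filter on the rows that come back. A minimal sketch of restoring the previous behaviour, assuming the Hadoop configuration below feeds the hadoopConf that ParquetFileFormat consults:

    // Sketch: re-enable Parquet record-level filtering (the pre-commit default).
    // Hedged: the exact propagation path depends on how the session builds its Hadoop conf.
    import org.apache.parquet.hadoop.ParquetInputFormat

    spark.sparkContext.hadoopConfiguration
      .setBoolean(ParquetInputFormat.RECORD_FILTERING_ENABLED, true)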

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala

Lines changed: 48 additions & 4 deletions
@@ -30,6 +30,53 @@ import org.apache.spark.sql.types._
  */
 private[parquet] object ParquetFilters {

+  case class SetInFilter[T <: Comparable[T]](valueSet: Set[T])
+    extends UserDefinedPredicate[T] with Serializable {
+
+    override def keep(value: T): Boolean = {
+      value != null && valueSet.contains(value)
+    }
+
+    // Drop when no value in the set is within the statistics range.
+    override def canDrop(statistics: Statistics[T]): Boolean = {
+      val statMax = statistics.getMax
+      val statMin = statistics.getMin
+      val statRange = com.google.common.collect.Range.closed(statMin, statMax)
+      !valueSet.exists(value => statRange.contains(value))
+    }
+
+    // Can only drop not(in(set)) when we know that every element in the block is in valueSet.
+    // From the statistics, we can only be assured of this when min == max.
+    override def inverseCanDrop(statistics: Statistics[T]): Boolean = {
+      val statMax = statistics.getMax
+      val statMin = statistics.getMin
+      statMin == statMax && valueSet.contains(statMin)
+    }
+  }
+
+  private val makeInSet: PartialFunction[DataType, (String, Set[Any]) => FilterPredicate] = {
+    case IntegerType =>
+      (n: String, v: Set[Any]) =>
+        FilterApi.userDefined(intColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Integer]]))
+    case LongType =>
+      (n: String, v: Set[Any]) =>
+        FilterApi.userDefined(longColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Long]]))
+    case FloatType =>
+      (n: String, v: Set[Any]) =>
+        FilterApi.userDefined(floatColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Float]]))
+    case DoubleType =>
+      (n: String, v: Set[Any]) =>
+        FilterApi.userDefined(doubleColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Double]]))
+    case StringType =>
+      (n: String, v: Set[Any]) =>
+        FilterApi.userDefined(binaryColumn(n),
+          SetInFilter(v.map(s => Binary.fromString(s.asInstanceOf[String]))))
+    case BinaryType =>
+      (n: String, v: Set[Any]) =>
+        FilterApi.userDefined(binaryColumn(n),
+          SetInFilter(v.map(e => Binary.fromReusedByteArray(e.asInstanceOf[Array[Byte]]))))
+  }
+
   private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
     case BooleanType =>
       (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])

@@ -299,10 +346,7 @@ private[parquet] object ParquetFilters {
         .map(LogicalInverseRewriter.rewrite)

       case sources.In(name, values) if dataTypeOf.contains(name) =>
-        val eq = makeEq.lift(dataTypeOf(name))
-        values.flatMap { v =>
-          eq.map(_(name, v))
-        }.reduceLeftOption(FilterApi.or)
+        makeInSet.lift(dataTypeOf(name)).map(_(name, values.toSet))

       case _ => None
     }
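
With makeInSet in place, a sources.In filter is translated into a single user-defined predicate instead of a chain of OR'ed equality predicates, which is what keeps very wide IN lists cheap to push down. A rough sketch of the predicate produced for an IntegerType column (it only compiles where SetInFilter is in scope, i.e. inside ParquetFilters; the column name and values are made up):

    // Sketch: what sources.In("b", Array(0, 2)) now compiles to for an integer column.
    import org.apache.parquet.filter2.predicate.FilterApi
    import org.apache.parquet.filter2.predicate.FilterApi.intColumn

    val values: Set[Any] = Set(0, 2)
    val predicate = FilterApi.userDefined(
      intColumn("b"),
      SetInFilter(values.asInstanceOf[Set[java.lang.Integer]]))
    // Before this change, the same filter became FilterApi.or(eq(intColumn("b"), 0), eq(intColumn("b"), 2)).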

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala

Lines changed: 23 additions & 11 deletions
@@ -24,6 +24,7 @@ import java.time.{LocalDate, ZoneId}
 import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
+import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat}

 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.dsl.expressions._

@@ -111,7 +112,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
         // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`)
         maybeFilter.exists(_.getClass === filterClass)
       }
-      checker(stripSparkFilter(query), expected)
+      checker(query, expected)
     }
   }
 }

@@ -557,11 +558,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
       val df = spark.read.parquet(path).filter("a = 2")

-      // The result should be single row.
-      // When a filter is pushed to Parquet, Parquet can apply it to every row.
-      // So, we can check the number of rows returned from the Parquet
-      // to make sure our filter pushdown work.
-      assert(stripSparkFilter(df).count == 1)
+      assert(df.count == 1)
     }
   }
 }

@@ -676,20 +673,35 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
        val path = s"${dir.getCanonicalPath}/table1"
        (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path)
        val df = spark.read.parquet(path).where("b in (0,2)")
-       assert(stripSparkFilter(df).count == 3)
+       assert(df.count == 3)
        val df1 = spark.read.parquet(path).where("not (b in (1))")
-       assert(stripSparkFilter(df1).count == 3)
+       assert(df1.count == 3)
        val df2 = spark.read.parquet(path).where("not (b in (1,3) or a <= 2)")
-       assert(stripSparkFilter(df2).count == 2)
+       assert(df2.count == 2)
        val df3 = spark.read.parquet(path).where("not (b in (1,3) and a <= 2)")
-       assert(stripSparkFilter(df3).count == 4)
+       assert(df3.count == 4)
        val df4 = spark.read.parquet(path).where("not (a <= 2)")
-       assert(stripSparkFilter(df4).count == 3)
+       assert(df4.count == 3)
       }
     }
   }
 }

+  test("Large In filters work with UDP") {
+    import testImplicits._
+    withSQLConf(ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL",
+      ParquetInputFormat.DICTIONARY_FILTERING_ENABLED -> "true") {
+      withTempPath { dir =>
+        val path = s"${dir.getCanonicalPath}/table1"
+        (1 to 1000).toDF().write.parquet(path)
+        val df = spark.read.parquet(path)
+        val filter = (1 to 499).map(i => i.toString).mkString(",")
+        assert(df.where(s"value in (${filter})").count() == 499)
+        assert(df.where(s"value not in (${filter})").count() == 501)
+      }
+    }
+  }
+
   test("Do not create Timestamp filters when interpreting from INT96") {
     val baseMillis = System.currentTimeMillis()
     def base(): Timestamp = new Timestamp(baseMillis)
