Skip to content

Commit de7de28

Browse files
committed
Adds stricter rules for Parquet filters with null
1 parent 397d3aa commit de7de28

File tree

3 files changed

+173
-22
lines changed

3 files changed

+173
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ object Literal {
4141
}
4242
}
4343

44+
/**
 * Extractor that matches only [[Literal]]s carrying a non-null value, yielding
 * the value together with its data type.
 */
object NonNullLiteral {
  def unapply(literal: Literal): Option[(Any, DataType)] =
    if (literal.value == null) None else Some((literal.value, literal.dataType))
}
49+
4450
/**
4551
* Extractor for retrieving Int literals.
4652
*/

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,34 @@ private[sql] object ParquetFilters {
5151
case DoubleType =>
5252
(n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
5353
case StringType =>
54-
(n: String, v: Any) =>
55-
FilterApi.eq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
54+
(n: String, v: Any) => FilterApi.eq(
55+
binaryColumn(n),
56+
Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
5657
case BinaryType =>
57-
(n: String, v: Any) =>
58-
FilterApi.eq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
58+
(n: String, v: Any) => FilterApi.eq(
59+
binaryColumn(n),
60+
Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull)
61+
}
62+
63+
// Builds a Parquet `notEq` filter for a column name and comparison value, per
// supported Catalyst type. A null value produces `notEq(col, null)`, which the
// IsNotNull case of the predicate conversion uses to express IS NOT NULL.
val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
  case BooleanType =>
    (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
  case IntegerType =>
    (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer])
  case LongType =>
    (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long])
  case FloatType =>
    (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
  case DoubleType =>
    (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
  case StringType =>
    (n: String, v: Any) => FilterApi.notEq(
      binaryColumn(n),
      Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
  case BinaryType =>
    (n: String, v: Any) => FilterApi.notEq(
      binaryColumn(n),
      // Fix: read the bound `b` instead of re-referencing the outer `v`; the old
      // form left the lambda parameter unused (same result, but misleading).
      Option(v).map(b => Binary.fromByteArray(b.asInstanceOf[Array[Byte]])).orNull)
}
6083

6184
val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -126,30 +149,45 @@ private[sql] object ParquetFilters {
126149
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
127150
}
128151

152+
// NOTE:
153+
//
154+
// For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
155+
// which can be casted to `false` implicitly. Please refer to the `eval` method of these
156+
// operators and the `SimplifyFilters` rule for details.
129157
predicate match {
130-
case EqualTo(NamedExpression(name, _), Literal(value, dataType)) if dataType != NullType =>
158+
case IsNull(NamedExpression(name, dataType)) =>
159+
makeEq.lift(dataType).map(_(name, null))
160+
case IsNotNull(NamedExpression(name, dataType)) =>
161+
makeNotEq.lift(dataType).map(_(name, null))
162+
163+
case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
131164
makeEq.lift(dataType).map(_(name, value))
132-
case EqualTo(Literal(value, dataType), NamedExpression(name, _)) if dataType != NullType =>
165+
case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
133166
makeEq.lift(dataType).map(_(name, value))
134167

135-
case LessThan(NamedExpression(name, _), Literal(value, dataType)) =>
168+
case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) =>
169+
makeNotEq.lift(dataType).map(_(name, value))
170+
case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) =>
171+
makeNotEq.lift(dataType).map(_(name, value))
172+
173+
case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
136174
makeLt.lift(dataType).map(_(name, value))
137-
case LessThan(Literal(value, dataType), NamedExpression(name, _)) =>
175+
case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
138176
makeGt.lift(dataType).map(_(name, value))
139177

140-
case LessThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
178+
case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
141179
makeLtEq.lift(dataType).map(_(name, value))
142-
case LessThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
180+
case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
143181
makeGtEq.lift(dataType).map(_(name, value))
144182

145-
case GreaterThan(NamedExpression(name, _), Literal(value, dataType)) =>
183+
case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
146184
makeGt.lift(dataType).map(_(name, value))
147-
case GreaterThan(Literal(value, dataType), NamedExpression(name, _)) =>
185+
case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
148186
makeLt.lift(dataType).map(_(name, value))
149187

150-
case GreaterThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
188+
case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
151189
makeGtEq.lift(dataType).map(_(name, value))
152-
case GreaterThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
190+
case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
153191
makeLtEq.lift(dataType).map(_(name, value))
154192

155193
case And(lhs, rhs) =>

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 115 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
package org.apache.spark.sql.parquet
1919

20-
import _root_.parquet.filter2.predicate.{FilterPredicate, Operators}
2120
import org.apache.hadoop.fs.{FileSystem, Path}
2221
import org.apache.hadoop.mapreduce.Job
2322
import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
23+
import parquet.filter2.predicate.{FilterPredicate, Operators}
2424
import parquet.hadoop.ParquetFileWriter
2525
import parquet.hadoop.util.ContextUtil
26+
import parquet.io.api.Binary
2627

2728
import org.apache.spark.sql._
2829
import org.apache.spark.sql.catalyst.expressions._
@@ -85,6 +86,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
8586
TestData // Load test data tables.
8687

8788
var testRDD: SchemaRDD = null
89+
var originalParquetFilterPushdownEnabled = TestSQLContext.parquetFilterPushDown
8890

8991
override def beforeAll() {
9092
ParquetTestData.writeFile()
@@ -109,13 +111,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
109111
Utils.deleteRecursively(ParquetTestData.testNestedDir3)
110112
Utils.deleteRecursively(ParquetTestData.testNestedDir4)
111113
// here we should also unregister the table??
114+
115+
setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, originalParquetFilterPushdownEnabled.toString)
112116
}
113117

114118
test("Read/Write All Types") {
115119
val tempDir = getTempFilePath("parquetTest").getCanonicalPath
116120
val range = (0 to 255)
117-
val data = sparkContext.parallelize(range)
118-
.map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0))
121+
val data = sparkContext.parallelize(range).map { x =>
122+
parquet.AllDataTypes(
123+
s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)
124+
}
119125

120126
data.saveAsParquetFile(tempDir)
121127

@@ -260,14 +266,15 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
260266
test("Read/Write All Types with non-primitive type") {
261267
val tempDir = getTempFilePath("parquetTest").getCanonicalPath
262268
val range = (0 to 255)
263-
val data = sparkContext.parallelize(range)
264-
.map(x => AllDataTypesWithNonPrimitiveType(
269+
val data = sparkContext.parallelize(range).map { x =>
270+
parquet.AllDataTypesWithNonPrimitiveType(
265271
s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
266272
(0 until x),
267273
(0 until x).map(Option(_).filter(_ % 3 == 0)),
268274
(0 until x).map(i => i -> i.toLong).toMap,
269275
(0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None),
270-
Data((0 until x), Nested(x, s"$x"))))
276+
parquet.Data((0 until x), parquet.Nested(x, s"$x")))
277+
}
271278
data.saveAsParquetFile(tempDir)
272279

273280
checkAnswer(
@@ -420,7 +427,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
420427
}
421428

422429
test("save and load case class RDD with nulls as parquet") {
423-
val data = NullReflectData(null, null, null, null, null)
430+
val data = parquet.NullReflectData(null, null, null, null, null)
424431
val rdd = sparkContext.parallelize(data :: Nil)
425432

426433
val file = getTempFilePath("parquet")
@@ -435,7 +442,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
435442
}
436443

437444
test("save and load case class RDD with Nones as parquet") {
438-
val data = OptionalReflectData(None, None, None, None, None)
445+
val data = parquet.OptionalReflectData(None, None, None, None, None)
439446
val rdd = sparkContext.parallelize(data :: Nil)
440447

441448
val file = getTempFilePath("parquet")
@@ -938,4 +945,104 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
938945
checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
939946
}
940947
}
948+
949+
// Asserts that `predicate` is convertible to a Parquet filter of exactly `filterClass`.
def checkFilter(predicate: Predicate, filterClass: Class[_ <: FilterPredicate]): Unit = {
  ParquetFilters.createFilter(predicate) match {
    case Some(converted) => assert(converted.getClass == filterClass)
    case None => assert(false)
  }
}
954+
955+
// IS NULL on each supported primitive type should convert to a Parquet `eq(col, null)`.
test("Pushdown IsNull predicate") {
  Seq[(Predicate, Class[_ <: FilterPredicate])](
    ('a.int.isNull)    -> classOf[Operators.Eq[Integer]],
    ('a.long.isNull)   -> classOf[Operators.Eq[java.lang.Long]],
    ('a.float.isNull)  -> classOf[Operators.Eq[java.lang.Float]],
    ('a.double.isNull) -> classOf[Operators.Eq[java.lang.Double]],
    ('a.string.isNull) -> classOf[Operators.Eq[Binary]],
    ('a.binary.isNull) -> classOf[Operators.Eq[Binary]]
  ).foreach { case (predicate, expected) => checkFilter(predicate, expected) }
}
963+
964+
// IS NOT NULL on each supported primitive type should convert to a Parquet `notEq(col, null)`.
test("Pushdown IsNotNull predicate") {
  Seq[(Predicate, Class[_ <: FilterPredicate])](
    ('a.int.isNotNull)    -> classOf[Operators.NotEq[Integer]],
    ('a.long.isNotNull)   -> classOf[Operators.NotEq[java.lang.Long]],
    ('a.float.isNotNull)  -> classOf[Operators.NotEq[java.lang.Float]],
    ('a.double.isNotNull) -> classOf[Operators.NotEq[java.lang.Double]],
    ('a.string.isNotNull) -> classOf[Operators.NotEq[Binary]],
    ('a.binary.isNotNull) -> classOf[Operators.NotEq[Binary]]
  ).foreach { case (predicate, expected) => checkFilter(predicate, expected) }
}
972+
973+
// Equality against a non-null literal should convert to a Parquet `eq` filter.
test("Pushdown EqualTo predicate") {
  Seq[(Predicate, Class[_ <: FilterPredicate])](
    ('a.int === 0)                 -> classOf[Operators.Eq[Integer]],
    ('a.long === 0.toLong)         -> classOf[Operators.Eq[java.lang.Long]],
    ('a.float === 0.toFloat)       -> classOf[Operators.Eq[java.lang.Float]],
    ('a.double === 0.toDouble)     -> classOf[Operators.Eq[java.lang.Double]],
    ('a.string === "foo")          -> classOf[Operators.Eq[Binary]],
    ('a.binary === "foo".getBytes) -> classOf[Operators.Eq[Binary]]
  ).foreach { case (predicate, expected) => checkFilter(predicate, expected) }
}
981+
982+
// Negated equality against a non-null literal should convert to a Parquet `notEq` filter.
test("Pushdown Not(EqualTo) predicate") {
  Seq[(Predicate, Class[_ <: FilterPredicate])](
    (!('a.int === 0))                 -> classOf[Operators.NotEq[Integer]],
    (!('a.long === 0.toLong))         -> classOf[Operators.NotEq[java.lang.Long]],
    (!('a.float === 0.toFloat))       -> classOf[Operators.NotEq[java.lang.Float]],
    (!('a.double === 0.toDouble))     -> classOf[Operators.NotEq[java.lang.Double]],
    (!('a.string === "foo"))          -> classOf[Operators.NotEq[Binary]],
    (!('a.binary === "foo".getBytes)) -> classOf[Operators.NotEq[Binary]]
  ).foreach { case (predicate, expected) => checkFilter(predicate, expected) }
}
990+
991+
// `<` against a non-null literal should convert to a Parquet `lt` filter.
test("Pushdown LessThan predicate") {
  Seq[(Predicate, Class[_ <: FilterPredicate])](
    ('a.int < 0)                 -> classOf[Operators.Lt[Integer]],
    ('a.long < 0.toLong)         -> classOf[Operators.Lt[java.lang.Long]],
    ('a.float < 0.toFloat)       -> classOf[Operators.Lt[java.lang.Float]],
    ('a.double < 0.toDouble)     -> classOf[Operators.Lt[java.lang.Double]],
    ('a.string < "foo")          -> classOf[Operators.Lt[Binary]],
    ('a.binary < "foo".getBytes) -> classOf[Operators.Lt[Binary]]
  ).foreach { case (predicate, expected) => checkFilter(predicate, expected) }
}
999+
1000+
// `<=` against a non-null literal should convert to a Parquet `ltEq` filter.
test("Pushdown LessThanOrEqual predicate") {
  Seq[(Predicate, Class[_ <: FilterPredicate])](
    ('a.int <= 0)                 -> classOf[Operators.LtEq[Integer]],
    ('a.long <= 0.toLong)         -> classOf[Operators.LtEq[java.lang.Long]],
    ('a.float <= 0.toFloat)       -> classOf[Operators.LtEq[java.lang.Float]],
    ('a.double <= 0.toDouble)     -> classOf[Operators.LtEq[java.lang.Double]],
    ('a.string <= "foo")          -> classOf[Operators.LtEq[Binary]],
    ('a.binary <= "foo".getBytes) -> classOf[Operators.LtEq[Binary]]
  ).foreach { case (predicate, expected) => checkFilter(predicate, expected) }
}
1008+
1009+
// `>` against a non-null literal should convert to a Parquet `gt` filter.
test("Pushdown GreaterThan predicate") {
  Seq[(Predicate, Class[_ <: FilterPredicate])](
    ('a.int > 0)                 -> classOf[Operators.Gt[Integer]],
    ('a.long > 0.toLong)         -> classOf[Operators.Gt[java.lang.Long]],
    ('a.float > 0.toFloat)       -> classOf[Operators.Gt[java.lang.Float]],
    ('a.double > 0.toDouble)     -> classOf[Operators.Gt[java.lang.Double]],
    ('a.string > "foo")          -> classOf[Operators.Gt[Binary]],
    ('a.binary > "foo".getBytes) -> classOf[Operators.Gt[Binary]]
  ).foreach { case (predicate, expected) => checkFilter(predicate, expected) }
}
1017+
1018+
// `>=` against a non-null literal should convert to a Parquet `gtEq` filter.
test("Pushdown GreaterThanOrEqual predicate") {
  Seq[(Predicate, Class[_ <: FilterPredicate])](
    ('a.int >= 0)                 -> classOf[Operators.GtEq[Integer]],
    ('a.long >= 0.toLong)         -> classOf[Operators.GtEq[java.lang.Long]],
    ('a.float >= 0.toFloat)       -> classOf[Operators.GtEq[java.lang.Float]],
    ('a.double >= 0.toDouble)     -> classOf[Operators.GtEq[java.lang.Double]],
    ('a.string >= "foo")          -> classOf[Operators.GtEq[Binary]],
    ('a.binary >= "foo".getBytes) -> classOf[Operators.GtEq[Binary]]
  ).foreach { case (predicate, expected) => checkFilter(predicate, expected) }
}
1026+
1027+
// A comparison against a null literal evaluates to NULL (filtered out by
// `SimplifyFilters`), so none of these predicates may produce a Parquet filter.
test("Comparison with null should not be pushed down") {
  val nonConvertible = Seq(
    'a.int === null,
    !('a.int === null),

    Literal(null) === 'a.int,
    !(Literal(null) === 'a.int),

    'a.int < null,
    'a.int <= null,
    'a.int > null,
    'a.int >= null,

    Literal(null) < 'a.int,
    Literal(null) <= 'a.int,
    Literal(null) > 'a.int,
    Literal(null) >= 'a.int
  )

  for (predicate <- nonConvertible) {
    assert(ParquetFilters.createFilter(predicate).isEmpty)
  }
}
9411048
}

0 commit comments

Comments
 (0)