Skip to content

[SPARK-4493][SQL] Don't pushdown Eq, NotEq, Lt, LtEq, Gt and GtEq predicates with nulls for Parquet #3367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ object Literal {
}
}

/**
 * An extractor that matches only literals whose value is non-null.
 *
 * Useful in pattern matches where a null literal must be treated differently
 * (e.g. Parquet filter pushdown, where comparisons with null must not be pushed down).
 */
object NonNullLiteral {
  /**
   * Returns `Some((value, dataType))` when `literal.value` is non-null, `None` otherwise.
   *
   * `Option(null)` is `None`, so null-valued literals never match this extractor.
   */
  def unapply(literal: Literal): Option[(Any, DataType)] = {
    Option(literal.value).map(_ => (literal.value, literal.dataType))
  }
}

/**
* Extractor for retrieving Int literals.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,37 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float])
case DoubleType =>
(n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double])

// Binary.fromString and Binary.fromByteArray don't accept null values
case StringType =>
(n: String, v: Any) =>
FilterApi.eq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
case BinaryType =>
(n: String, v: Any) =>
FilterApi.eq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Binary.fromString and Binary.fromByteArray don't accept null.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add this as a comment.

}

// Builds a Parquet `notEq` predicate factory for each supported Catalyst data type.
// The produced function takes a column name and a (possibly null) comparison value.
val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
  case BooleanType =>
    (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
  case IntegerType =>
    (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer])
  case LongType =>
    (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long])
  case FloatType =>
    (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
  case DoubleType =>
    (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
  // Binary.fromString and Binary.fromByteArray don't accept null, so a null value
  // is mapped through Option and passed to Parquet as a plain null Binary.
  case StringType =>
    (n: String, v: Any) => FilterApi.notEq(
      binaryColumn(n),
      Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
  case BinaryType =>
    (n: String, v: Any) => FilterApi.notEq(
      binaryColumn(n),
      // Fix: cast the lambda parameter `b`, not the outer `v` — the original bound
      // `b` but never used it, which only worked because Some(v) implies v != null.
      Option(v).map(b => Binary.fromByteArray(b.asInstanceOf[Array[Byte]])).orNull)
}

val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
Expand Down Expand Up @@ -126,30 +151,45 @@ private[sql] object ParquetFilters {
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}

// NOTE:
//
// For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
// which can be casted to `false` implicitly. Please refer to the `eval` method of these
// operators and the `SimplifyFilters` rule for details.
predicate match {
case EqualTo(NamedExpression(name, _), Literal(value, dataType)) if dataType != NullType =>
case IsNull(NamedExpression(name, dataType)) =>
makeEq.lift(dataType).map(_(name, null))
case IsNotNull(NamedExpression(name, dataType)) =>
makeNotEq.lift(dataType).map(_(name, null))

case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
makeEq.lift(dataType).map(_(name, value))
case EqualTo(Literal(value, dataType), NamedExpression(name, _)) if dataType != NullType =>
case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
makeEq.lift(dataType).map(_(name, value))

case LessThan(NamedExpression(name, _), Literal(value, dataType)) =>
case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) =>
makeNotEq.lift(dataType).map(_(name, value))
case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) =>
makeNotEq.lift(dataType).map(_(name, value))

case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
makeLt.lift(dataType).map(_(name, value))
case LessThan(Literal(value, dataType), NamedExpression(name, _)) =>
case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
makeGt.lift(dataType).map(_(name, value))

case LessThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
makeLtEq.lift(dataType).map(_(name, value))
case LessThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
makeGtEq.lift(dataType).map(_(name, value))

case GreaterThan(NamedExpression(name, _), Literal(value, dataType)) =>
case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
makeGt.lift(dataType).map(_(name, value))
case GreaterThan(Literal(value, dataType), NamedExpression(name, _)) =>
case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
makeLt.lift(dataType).map(_(name, value))

case GreaterThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
makeGtEq.lift(dataType).map(_(name, value))
case GreaterThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
makeLtEq.lift(dataType).map(_(name, value))

case And(lhs, rhs) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@

package org.apache.spark.sql.parquet

import _root_.parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
import parquet.filter2.predicate.{FilterPredicate, Operators}
import parquet.hadoop.ParquetFileWriter
import parquet.hadoop.util.ContextUtil
import parquet.io.api.Binary

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -84,7 +85,8 @@ case class NumericData(i: Int, d: Double)
class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
TestData // Load test data tables.

var testRDD: SchemaRDD = null
private var testRDD: SchemaRDD = null
private val originalParquetFilterPushdownEnabled = TestSQLContext.parquetFilterPushDown

override def beforeAll() {
ParquetTestData.writeFile()
Expand All @@ -109,13 +111,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
Utils.deleteRecursively(ParquetTestData.testNestedDir3)
Utils.deleteRecursively(ParquetTestData.testNestedDir4)
// here we should also unregister the table??

setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, originalParquetFilterPushdownEnabled.toString)
}

test("Read/Write All Types") {
val tempDir = getTempFilePath("parquetTest").getCanonicalPath
val range = (0 to 255)
val data = sparkContext.parallelize(range)
.map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0))
val data = sparkContext.parallelize(range).map { x =>
parquet.AllDataTypes(
s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)
}

data.saveAsParquetFile(tempDir)

Expand Down Expand Up @@ -260,14 +266,15 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
test("Read/Write All Types with non-primitive type") {
val tempDir = getTempFilePath("parquetTest").getCanonicalPath
val range = (0 to 255)
val data = sparkContext.parallelize(range)
.map(x => AllDataTypesWithNonPrimitiveType(
val data = sparkContext.parallelize(range).map { x =>
parquet.AllDataTypesWithNonPrimitiveType(
s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
(0 until x),
(0 until x).map(Option(_).filter(_ % 3 == 0)),
(0 until x).map(i => i -> i.toLong).toMap,
(0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None),
Data((0 until x), Nested(x, s"$x"))))
parquet.Data((0 until x), parquet.Nested(x, s"$x")))
}
data.saveAsParquetFile(tempDir)

checkAnswer(
Expand Down Expand Up @@ -420,7 +427,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}

test("save and load case class RDD with nulls as parquet") {
val data = NullReflectData(null, null, null, null, null)
val data = parquet.NullReflectData(null, null, null, null, null)
val rdd = sparkContext.parallelize(data :: Nil)

val file = getTempFilePath("parquet")
Expand All @@ -435,7 +442,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}

test("save and load case class RDD with Nones as parquet") {
val data = OptionalReflectData(None, None, None, None, None)
val data = parquet.OptionalReflectData(None, None, None, None, None)
val rdd = sparkContext.parallelize(data :: Nil)

val file = getTempFilePath("parquet")
Expand Down Expand Up @@ -938,4 +945,108 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
}
}

/**
 * Asserts that `predicate` is convertible to a Parquet filter and that the
 * resulting filter is an instance of the expected `filterClass`.
 */
def checkFilter(predicate: Predicate, filterClass: Class[_ <: FilterPredicate]): Unit = {
  val maybeFilter = ParquetFilters.createFilter(predicate)
  assert(maybeFilter.isDefined)
  assert(maybeFilter.get.getClass == filterClass)
}

// `IsNull` is translated to Parquet `Eq(column, null)`; both String and Binary
// columns are backed by Parquet `Binary`, hence the shared expected class.
test("Pushdown IsNull predicate") {
checkFilter('a.int.isNull, classOf[Operators.Eq[Integer]])
checkFilter('a.long.isNull, classOf[Operators.Eq[java.lang.Long]])
checkFilter('a.float.isNull, classOf[Operators.Eq[java.lang.Float]])
checkFilter('a.double.isNull, classOf[Operators.Eq[java.lang.Double]])
checkFilter('a.string.isNull, classOf[Operators.Eq[Binary]])
checkFilter('a.binary.isNull, classOf[Operators.Eq[Binary]])
}

// Symmetric case: `IsNotNull` becomes Parquet `NotEq(column, null)`.
test("Pushdown IsNotNull predicate") {
checkFilter('a.int.isNotNull, classOf[Operators.NotEq[Integer]])
checkFilter('a.long.isNotNull, classOf[Operators.NotEq[java.lang.Long]])
checkFilter('a.float.isNotNull, classOf[Operators.NotEq[java.lang.Float]])
checkFilter('a.double.isNotNull, classOf[Operators.NotEq[java.lang.Double]])
checkFilter('a.string.isNotNull, classOf[Operators.NotEq[Binary]])
checkFilter('a.binary.isNotNull, classOf[Operators.NotEq[Binary]])
}

// Each comparison operator with a non-null literal should be pushed down to the
// corresponding Parquet filter class for every supported primitive/binary type.
test("Pushdown EqualTo predicate") {
checkFilter('a.int === 0, classOf[Operators.Eq[Integer]])
checkFilter('a.long === 0.toLong, classOf[Operators.Eq[java.lang.Long]])
checkFilter('a.float === 0.toFloat, classOf[Operators.Eq[java.lang.Float]])
checkFilter('a.double === 0.toDouble, classOf[Operators.Eq[java.lang.Double]])
checkFilter('a.string === "foo", classOf[Operators.Eq[Binary]])
checkFilter('a.binary === "foo".getBytes, classOf[Operators.Eq[Binary]])
}

// `Not(EqualTo)` maps to Parquet `NotEq` rather than a negated `Eq`.
test("Pushdown Not(EqualTo) predicate") {
checkFilter(!('a.int === 0), classOf[Operators.NotEq[Integer]])
checkFilter(!('a.long === 0.toLong), classOf[Operators.NotEq[java.lang.Long]])
checkFilter(!('a.float === 0.toFloat), classOf[Operators.NotEq[java.lang.Float]])
checkFilter(!('a.double === 0.toDouble), classOf[Operators.NotEq[java.lang.Double]])
checkFilter(!('a.string === "foo"), classOf[Operators.NotEq[Binary]])
checkFilter(!('a.binary === "foo".getBytes), classOf[Operators.NotEq[Binary]])
}

test("Pushdown LessThan predicate") {
checkFilter('a.int < 0, classOf[Operators.Lt[Integer]])
checkFilter('a.long < 0.toLong, classOf[Operators.Lt[java.lang.Long]])
checkFilter('a.float < 0.toFloat, classOf[Operators.Lt[java.lang.Float]])
checkFilter('a.double < 0.toDouble, classOf[Operators.Lt[java.lang.Double]])
checkFilter('a.string < "foo", classOf[Operators.Lt[Binary]])
checkFilter('a.binary < "foo".getBytes, classOf[Operators.Lt[Binary]])
}

test("Pushdown LessThanOrEqual predicate") {
checkFilter('a.int <= 0, classOf[Operators.LtEq[Integer]])
checkFilter('a.long <= 0.toLong, classOf[Operators.LtEq[java.lang.Long]])
checkFilter('a.float <= 0.toFloat, classOf[Operators.LtEq[java.lang.Float]])
checkFilter('a.double <= 0.toDouble, classOf[Operators.LtEq[java.lang.Double]])
checkFilter('a.string <= "foo", classOf[Operators.LtEq[Binary]])
checkFilter('a.binary <= "foo".getBytes, classOf[Operators.LtEq[Binary]])
}

test("Pushdown GreaterThan predicate") {
checkFilter('a.int > 0, classOf[Operators.Gt[Integer]])
checkFilter('a.long > 0.toLong, classOf[Operators.Gt[java.lang.Long]])
checkFilter('a.float > 0.toFloat, classOf[Operators.Gt[java.lang.Float]])
checkFilter('a.double > 0.toDouble, classOf[Operators.Gt[java.lang.Double]])
checkFilter('a.string > "foo", classOf[Operators.Gt[Binary]])
checkFilter('a.binary > "foo".getBytes, classOf[Operators.Gt[Binary]])
}

test("Pushdown GreaterThanOrEqual predicate") {
checkFilter('a.int >= 0, classOf[Operators.GtEq[Integer]])
checkFilter('a.long >= 0.toLong, classOf[Operators.GtEq[java.lang.Long]])
checkFilter('a.float >= 0.toFloat, classOf[Operators.GtEq[java.lang.Float]])
checkFilter('a.double >= 0.toDouble, classOf[Operators.GtEq[java.lang.Double]])
checkFilter('a.string >= "foo", classOf[Operators.GtEq[Binary]])
checkFilter('a.binary >= "foo".getBytes, classOf[Operators.GtEq[Binary]])
}

// Regression test for SPARK-4493: `a cmp NULL` evaluates to NULL (treated as false),
// so such predicates must NOT be converted into Parquet filters — createFilter
// should return None for every null-literal comparison, on either side.
test("Comparison with null should not be pushed down") {
val predicates = Seq(
'a.int === null,
!('a.int === null),

Literal(null) === 'a.int,
!(Literal(null) === 'a.int),

'a.int < null,
'a.int <= null,
'a.int > null,
'a.int >= null,

Literal(null) < 'a.int,
Literal(null) <= 'a.int,
Literal(null) > 'a.int,
Literal(null) >= 'a.int
)

predicates.foreach { p =>
assert(
ParquetFilters.createFilter(p).isEmpty,
"Comparison predicate with null shouldn't be pushed down")
}
}
}