Skip to content

Commit 27ea3d7

Browse files
Fix Constant Folding Bugs & Add More Unittests
1 parent b28e03a commit 27ea3d7

File tree

5 files changed

+182
-52
lines changed

5 files changed

+182
-52
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,37 +114,37 @@ package object dsl {
114114
def attr = analysis.UnresolvedAttribute(s)
115115

116116
/** Creates a new AttributeReference of type boolean */
117-
def boolean = AttributeReference(s, BooleanType, nullable = false)()
117+
def boolean = AttributeReference(s, BooleanType, nullable = true)()
118118

119119
/** Creates a new AttributeReference of type byte */
120-
def byte = AttributeReference(s, ByteType, nullable = false)()
120+
def byte = AttributeReference(s, ByteType, nullable = true)()
121121

122122
/** Creates a new AttributeReference of type short */
123-
def short = AttributeReference(s, ShortType, nullable = false)()
123+
def short = AttributeReference(s, ShortType, nullable = true)()
124124

125125
/** Creates a new AttributeReference of type int */
126-
def int = AttributeReference(s, IntegerType, nullable = false)()
126+
def int = AttributeReference(s, IntegerType, nullable = true)()
127127

128128
/** Creates a new AttributeReference of type long */
129-
def long = AttributeReference(s, LongType, nullable = false)()
129+
def long = AttributeReference(s, LongType, nullable = true)()
130130

131131
/** Creates a new AttributeReference of type float */
132-
def float = AttributeReference(s, FloatType, nullable = false)()
132+
def float = AttributeReference(s, FloatType, nullable = true)()
133133

134134
/** Creates a new AttributeReference of type double */
135-
def double = AttributeReference(s, DoubleType, nullable = false)()
135+
def double = AttributeReference(s, DoubleType, nullable = true)()
136136

137137
/** Creates a new AttributeReference of type string */
138-
def string = AttributeReference(s, StringType, nullable = false)()
138+
def string = AttributeReference(s, StringType, nullable = true)()
139139

140140
/** Creates a new AttributeReference of type decimal */
141-
def decimal = AttributeReference(s, DecimalType, nullable = false)()
141+
def decimal = AttributeReference(s, DecimalType, nullable = true)()
142142

143143
/** Creates a new AttributeReference of type timestamp */
144-
def timestamp = AttributeReference(s, TimestampType, nullable = false)()
144+
def timestamp = AttributeReference(s, TimestampType, nullable = true)()
145145

146146
/** Creates a new AttributeReference of type binary */
147-
def binary = AttributeReference(s, BinaryType, nullable = false)()
147+
def binary = AttributeReference(s, BinaryType, nullable = true)()
148148
}
149149

150150
implicit class DslAttribute(a: AttributeReference) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,6 @@ abstract class Expression extends TreeNode[Expression] {
222222
}
223223
}
224224

225-
/**
226-
* Root class for rewritten 2 operands UDF expression. By default, we assume it produces Null if
227-
* either one of its operands is null. Exceptional case requires to update the optimization rule
228-
* at [[optimizer.ConstantFolding ConstantFolding]]
229-
*/
230225
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
231226
self: Product =>
232227

@@ -243,11 +238,6 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
243238
self: Product =>
244239
}
245240

246-
/**
247-
* Root class for rewritten single operand UDF expression. By default, we assume it produces Null
248-
* if its operand is null. Exceptional case requires to update the optimization rule
249-
* at [[optimizer.ConstantFolding ConstantFolding]]
250-
*/
251241
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
252242
self: Product =>
253243

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,33 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
4141
override def toString = s"$child[$ordinal]"
4242

4343
override def eval(input: Row): Any = {
44-
if (child.dataType.isInstanceOf[ArrayType]) {
45-
val baseValue = child.eval(input).asInstanceOf[Seq[_]]
46-
val o = ordinal.eval(input).asInstanceOf[Int]
47-
if (baseValue == null) {
48-
null
49-
} else if (o >= baseValue.size || o < 0) {
50-
null
51-
} else {
52-
baseValue(o)
53-
}
44+
val value = child.eval(input)
45+
if(value == null) {
46+
null
5447
} else {
55-
val baseValue = child.eval(input).asInstanceOf[Map[Any, _]]
5648
val key = ordinal.eval(input)
57-
if (baseValue == null) {
49+
if(key == null) {
5850
null
5951
} else {
60-
baseValue.get(key).orNull
52+
if (child.dataType.isInstanceOf[ArrayType]) {
53+
val baseValue = value.asInstanceOf[Seq[_]]
54+
val o = key.asInstanceOf[Int]
55+
if (baseValue == null) {
56+
null
57+
} else if (o >= baseValue.size || o < 0) {
58+
null
59+
} else {
60+
baseValue(o)
61+
}
62+
} else {
63+
val baseValue = value.asInstanceOf[Map[Any, _]]
64+
val key = ordinal.eval(input)
65+
if (baseValue == null) {
66+
null
67+
} else {
68+
baseValue.get(key).orNull
69+
}
70+
}
6171
}
6272
}
6373
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,11 @@ object NullPropagation extends Rule[LogicalPlan] {
9696
case q: LogicalPlan => q transformExpressionsUp {
9797
// Skip redundant folding of literals.
9898
case l: Literal => l
99-
case e @ Count(Literal(null, _)) => Literal(null, e.dataType)
99+
case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
100100
case e @ Sum(Literal(null, _)) => Literal(null, e.dataType)
101101
case e @ Average(Literal(null, _)) => Literal(null, e.dataType)
102-
case e @ IsNull(c @ Rand) => Literal(false, BooleanType)
103-
case e @ IsNotNull(c @ Rand) => Literal(true, BooleanType)
102+
case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
103+
case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)
104104
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
105105
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
106106
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
@@ -122,13 +122,32 @@ object NullPropagation extends Rule[LogicalPlan] {
122122
case Literal(candidate, _) if(candidate == v) => true
123123
case _ => false
124124
})) => Literal(true, BooleanType)
125-
// Put exceptional cases(Unary & Binary Expression if it doesn't produce null with constant
126-
// null operand) before here.
127-
case e: UnaryExpression => e.child match {
125+
case e: UnaryMinus => e.child match {
128126
case Literal(null, _) => Literal(null, e.dataType)
129127
case _ => e
130128
}
131-
case e: BinaryExpression => e.children match {
129+
case e: Cast => e.child match {
130+
case Literal(null, _) => Literal(null, e.dataType)
131+
case _ => e
132+
}
133+
case e: Not => e.child match {
134+
case Literal(null, _) => Literal(null, e.dataType)
135+
case _ => e
136+
}
137+
case e: And => e // leave it for BooleanSimplification
138+
case e: Or => e // leave it for BooleanSimplification
139+
// Put exceptional cases above
140+
case e: BinaryArithmetic => e.children match {
141+
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
142+
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
143+
case _ => e
144+
}
145+
case e: BinaryPredicate => e.children match {
146+
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
147+
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
148+
case _ => e
149+
}
150+
case e: StringRegexExpression => e.children match {
132151
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
133152
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
134153
case _ => e

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 122 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class ExpressionEvaluationSuite extends FunSuite {
129129

130130
test("LIKE literal Regular Expression") {
131131
checkEvaluation(Literal(null, StringType).like("a"), null)
132+
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
132133
checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null)
133134
checkEvaluation("abdef" like "abdef", true)
134135
checkEvaluation("a_%b" like "a\\__b", true)
@@ -157,9 +158,14 @@ class ExpressionEvaluationSuite extends FunSuite {
157158
checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
158159
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
159160
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))
161+
162+
checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%")))
160163
}
161164

162165
test("RLIKE literal Regular Expression") {
166+
checkEvaluation(Literal(null, StringType) rlike "abdef", null)
167+
checkEvaluation("abdef" rlike Literal(null, StringType), null)
168+
checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null)
163169
checkEvaluation("abdef" rlike "abdef", true)
164170
checkEvaluation("abbbbc" rlike "a.*c", true)
165171

@@ -244,17 +250,19 @@ class ExpressionEvaluationSuite extends FunSuite {
244250

245251
intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
246252

247-
assert(("abcdef" cast StringType).nullable === false)
248-
assert(("abcdef" cast BinaryType).nullable === false)
249-
assert(("abcdef" cast BooleanType).nullable === false)
250-
assert(("abcdef" cast TimestampType).nullable === true)
251-
assert(("abcdef" cast LongType).nullable === true)
252-
assert(("abcdef" cast IntegerType).nullable === true)
253-
assert(("abcdef" cast ShortType).nullable === true)
254-
assert(("abcdef" cast ByteType).nullable === true)
255-
assert(("abcdef" cast DecimalType).nullable === true)
256-
assert(("abcdef" cast DoubleType).nullable === true)
257-
assert(("abcdef" cast FloatType).nullable === true)
253+
checkEvaluation(("abcdef" cast StringType).nullable, false)
254+
checkEvaluation(("abcdef" cast BinaryType).nullable,false)
255+
checkEvaluation(("abcdef" cast BooleanType).nullable, false)
256+
checkEvaluation(("abcdef" cast TimestampType).nullable, true)
257+
checkEvaluation(("abcdef" cast LongType).nullable, true)
258+
checkEvaluation(("abcdef" cast IntegerType).nullable, true)
259+
checkEvaluation(("abcdef" cast ShortType).nullable, true)
260+
checkEvaluation(("abcdef" cast ByteType).nullable, true)
261+
checkEvaluation(("abcdef" cast DecimalType).nullable, true)
262+
checkEvaluation(("abcdef" cast DoubleType).nullable, true)
263+
checkEvaluation(("abcdef" cast FloatType).nullable, true)
264+
265+
checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null)
258266
}
259267

260268
test("timestamp") {
@@ -285,5 +293,108 @@ class ExpressionEvaluationSuite extends FunSuite {
285293
// A test for higher precision than millis
286294
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
287295
}
296+
297+
test("null checking") {
298+
val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
299+
val c1 = 'a.string.at(0)
300+
val c2 = 'a.string.at(1)
301+
val c3 = 'a.boolean.at(2)
302+
val c4 = 'a.boolean.at(3)
303+
304+
checkEvaluation(IsNull(c1), false, row)
305+
checkEvaluation(IsNotNull(c1), true, row)
306+
307+
checkEvaluation(IsNull(c2), true, row)
308+
checkEvaluation(IsNotNull(c2), false, row)
309+
310+
checkEvaluation(IsNull(Literal(1, ShortType)), false)
311+
checkEvaluation(IsNotNull(Literal(1, ShortType)), true)
312+
313+
checkEvaluation(IsNull(Literal(null, ShortType)), true)
314+
checkEvaluation(IsNotNull(Literal(null, ShortType)), false)
315+
316+
checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
317+
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
318+
checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)
319+
320+
checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row)
321+
checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
322+
checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
323+
checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row)
324+
checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row)
325+
checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row)
326+
checkEvaluation(If(Literal(false, BooleanType),
327+
Literal("a", StringType), Literal("b", StringType)), "b", row)
328+
329+
checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
330+
checkEvaluation(In(Literal("^Ba*n", StringType),
331+
Literal("^Ba*n", StringType) :: Nil), true, row)
332+
checkEvaluation(In(Literal("^Ba*n", StringType),
333+
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
334+
}
335+
336+
test("complex type") {
337+
val row = new GenericRow(Array[Any](
338+
"^Ba*n", // 0
339+
null.asInstanceOf[String], // 1
340+
new GenericRow(Array[Any]("aa", "bb")), // 2
341+
Map("aa"->"bb"), // 3
342+
Seq("aa", "bb") // 4
343+
))
344+
345+
val typeS = StructType(
346+
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
347+
)
348+
val typeMap = MapType(StringType, StringType)
349+
val typeArray = ArrayType(StringType)
350+
351+
checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
352+
Literal("aa")), "bb", row)
353+
checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
354+
checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
355+
checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
356+
Literal(null, StringType)), null, row)
357+
358+
checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
359+
Literal(1)), "bb", row)
360+
checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
361+
checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
362+
checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
363+
Literal(null, IntegerType)), null, row)
364+
365+
checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
366+
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
367+
}
368+
369+
test("arithmetic") {
370+
val row = new GenericRow(Array[Any](1, 2, 3, null))
371+
val c1 = 'a.int.at(0)
372+
val c2 = 'a.int.at(1)
373+
val c3 = 'a.int.at(2)
374+
val c4 = 'a.int.at(3)
375+
376+
checkEvaluation(UnaryMinus(c1), -1, row)
377+
checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100)
378+
379+
checkEvaluation(Add(c1, c4), null, row)
380+
checkEvaluation(Add(c1, c2), 3, row)
381+
checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row)
382+
checkEvaluation(Add(Literal(null, IntegerType), c2), null, row)
383+
checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
384+
}
385+
386+
test("BinaryComparison") {
387+
val row = new GenericRow(Array[Any](1, 2, 3, null))
388+
val c1 = 'a.int.at(0)
389+
val c2 = 'a.int.at(1)
390+
val c3 = 'a.int.at(2)
391+
val c4 = 'a.int.at(3)
392+
393+
checkEvaluation(LessThan(c1, c4), null, row)
394+
checkEvaluation(LessThan(c1, c2), true, row)
395+
checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row)
396+
checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row)
397+
checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
398+
}
288399
}
289400

0 commit comments

Comments
 (0)