@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
24
24
import org .apache .spark .sql .catalyst .plans .logical ._
25
25
import org .apache .spark .sql .catalyst .rules ._
26
26
import org .apache .spark .sql .internal .SQLConf
27
+ import org .apache .spark .sql .types .{IntegerType , LongType }
27
28
28
29
class InferFiltersFromConstraintsSuite extends PlanTest {
29
30
@@ -46,8 +47,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
46
47
y : LogicalPlan ,
47
48
expectedLeft : LogicalPlan ,
48
49
expectedRight : LogicalPlan ,
49
- joinType : JoinType ) = {
50
- val condition = Some (" x.a" .attr === " y.a" .attr)
50
+ joinType : JoinType ,
51
+ condition : Option [ Expression ] = Some (" x.a" .attr === " y.a" .attr)) = {
51
52
val originalQuery = x.join(y, joinType, condition).analyze
52
53
val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
53
54
val optimized = Optimize .execute(originalQuery)
@@ -263,4 +264,56 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
263
264
val y = testRelation.subquery(' y )
264
265
testConstraintsAfterJoin(x, y, x.where(IsNotNull (' a )), y, RightOuter )
265
266
}
267
+
268
+ test(" Constraints should be inferred from cast equality constraint(filter higher data type)" ) {
269
+ val testRelation1 = LocalRelation (' a .int)
270
+ val testRelation2 = LocalRelation (' b .long)
271
+ val originalLeft = testRelation1.subquery(' left )
272
+ val originalRight = testRelation2.where(' b === 1L ).subquery(' right )
273
+
274
+ val left = testRelation1.where(IsNotNull (' a ) && ' a .cast(LongType ) === 1L ).subquery(' left )
275
+ val right = testRelation2.where(IsNotNull (' b ) && ' b === 1L ).subquery(' right )
276
+
277
+ Seq (Some (" left.a" .attr.cast(LongType ) === " right.b" .attr),
278
+ Some (" right.b" .attr === " left.a" .attr.cast(LongType ))).foreach { condition =>
279
+ testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner , condition)
280
+ }
281
+
282
+ Seq (Some (" left.a" .attr === " right.b" .attr.cast(IntegerType )),
283
+ Some (" right.b" .attr.cast(IntegerType ) === " left.a" .attr)).foreach { condition =>
284
+ testConstraintsAfterJoin(
285
+ originalLeft,
286
+ originalRight,
287
+ testRelation1.where(IsNotNull (' a )).subquery(' left ),
288
+ right,
289
+ Inner ,
290
+ condition)
291
+ }
292
+ }
293
+
294
+ test(" Constraints shouldn't be inferred from cast equality constraint(filter lower data type)" ) {
295
+ val testRelation1 = LocalRelation (' a .int)
296
+ val testRelation2 = LocalRelation (' b .long)
297
+ val originalLeft = testRelation1.where(' a === 1 ).subquery(' left )
298
+ val originalRight = testRelation2.subquery(' right )
299
+
300
+ val left = testRelation1.where(IsNotNull (' a ) && ' a === 1 ).subquery(' left )
301
+ val right = testRelation2.where(IsNotNull (' b )).subquery(' right )
302
+
303
+ Seq (Some (" left.a" .attr.cast(LongType ) === " right.b" .attr),
304
+ Some (" right.b" .attr === " left.a" .attr.cast(LongType ))).foreach { condition =>
305
+ testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner , condition)
306
+ }
307
+
308
+ Seq (Some (" left.a" .attr === " right.b" .attr.cast(IntegerType )),
309
+ Some (" right.b" .attr.cast(IntegerType ) === " left.a" .attr)).foreach { condition =>
310
+ testConstraintsAfterJoin(
311
+ originalLeft,
312
+ originalRight,
313
+ left,
314
+ testRelation2.where(IsNotNull (' b ) && ' b .attr.cast(IntegerType ) === 1 ).subquery(' right ),
315
+ Inner ,
316
+ condition)
317
+ }
318
+ }
266
319
}
0 commit comments