Skip to content

Commit 88e9d8b

Browse files
hongyzhangmingmwang
authored andcommitted
[SPARK-29231][SQL] Constraints should be inferred from cast equality constraint (#13)
### What changes were proposed in this pull request? This PR add support infer constraints from cast equality constraint. For example: ```scala scala> spark.sql("create table spark_29231_1(c1 bigint, c2 bigint)") res0: org.apache.spark.sql.DataFrame = [] scala> spark.sql("create table spark_29231_2(c1 int, c2 bigint)") res1: org.apache.spark.sql.DataFrame = [] scala> spark.sql("select t1.* from spark_29231_1 t1 join spark_29231_2 t2 on (t1.c1 = t2.c1 and t1.c1 = 1)").explain == Physical Plan == *(2) Project [c1#5L, c2#6L] +- *(2) BroadcastHashJoin [c1#5L], [cast(c1#7 as bigint)], Inner, BuildRight :- *(2) Project [c1#5L, c2#6L] : +- *(2) Filter (isnotnull(c1#5L) AND (c1#5L = 1)) : +- *(2) ColumnarToRow : +- FileScan parquet default.spark_29231_1[c1#5L,c2#6L] Batched: true, DataFilters: [isnotnull(c1#5L), (c1#5L = 1)], Format: Parquet, Location: InMemoryFileIndex[file:/root/spark-3.0.0-preview2-bin-hadoop2.7/spark-warehouse/spark_29231_1], PartitionFilters: [], PushedFilters: [IsNotNull(c1), EqualTo(c1,1)], ReadSchema: struct<c1:bigint,c2:bigint> +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#209] +- *(1) Project [c1#7] +- *(1) Filter isnotnull(c1#7) +- *(1) ColumnarToRow +- FileScan parquet default.spark_29231_2[c1#7] Batched: true, DataFilters: [isnotnull(c1#7)], Format: Parquet, Location: InMemoryFileIndex[file:/root/spark-3.0.0-preview2-bin-hadoop2.7/spark-warehouse/spark_29231_2], PartitionFilters: [], PushedFilters: [IsNotNull(c1)], ReadSchema: struct<c1:int> ``` After this PR: ```scala scala> spark.sql("select t1.* from spark_29231_1 t1 join spark_29231_2 t2 on (t1.c1 = t2.c1 and t1.c1 = 1)").explain == Physical Plan == *(2) Project [c1#0L, c2#1L] +- *(2) BroadcastHashJoin [c1#0L], [cast(c1#2 as bigint)], Inner, BuildRight :- *(2) Project [c1#0L, c2#1L] : +- *(2) Filter (isnotnull(c1#0L) AND (c1#0L = 1)) : +- *(2) ColumnarToRow : +- FileScan parquet default.spark_29231_1[c1#0L,c2#1L] Batched: true, DataFilters: [isnotnull(c1#0L), (c1#0L = 1)], Format: Parquet, Location: InMemoryFileIndex[file:/root/opensource/spark/spark-warehouse/spark_29231_1], PartitionFilters: [], PushedFilters: [IsNotNull(c1), EqualTo(c1,1)], ReadSchema: struct<c1:bigint,c2:bigint> +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#99] +- *(1) Project [c1#2] +- *(1) Filter ((cast(c1#2 as bigint) = 1) AND isnotnull(c1#2)) +- *(1) ColumnarToRow +- FileScan parquet default.spark_29231_2[c1#2] Batched: true, DataFilters: [(cast(c1#2 as bigint) = 1), isnotnull(c1#2)], Format: Parquet, Location: InMemoryFileIndex[file:/root/opensource/spark/spark-warehouse/spark_29231_2], PartitionFilters: [], PushedFilters: [IsNotNull(c1)], ReadSchema: struct<c1:int> ``` ### Why are the changes needed? Improve query performance. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Unit test. Closes #27252 from wangyum/SPARK-29231. Authored-by: Yuming Wang <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent e933ef6 commit 88e9d8b

File tree

2 files changed

+64
-5
lines changed

2 files changed

+64
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,17 @@ trait ConstraintHelper {
6262
*/
6363
def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
6464
var inferredConstraints = Set.empty[Expression]
65-
constraints.foreach {
65+
// IsNotNull should be constructed by `constructIsNotNullConstraints`.
66+
val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
67+
predicates.foreach {
6668
case eq @ EqualTo(l: Attribute, r: Attribute) =>
67-
val candidateConstraints = constraints - eq
69+
val candidateConstraints = predicates - eq
6870
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
6971
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
72+
case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) =>
73+
inferredConstraints ++= replaceConstraints(predicates - eq, r, l)
74+
case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
75+
inferredConstraints ++= replaceConstraints(predicates - eq, l, r)
7076
case _ => // No inference
7177
}
7278
inferredConstraints -- constraints
@@ -75,7 +81,7 @@ trait ConstraintHelper {
7581
private def replaceConstraints(
7682
constraints: Set[Expression],
7783
source: Expression,
78-
destination: Attribute): Set[Expression] = constraints.map(_ transform {
84+
destination: Expression): Set[Expression] = constraints.map(_ transform {
7985
case e: Expression if e.semanticEquals(source) => destination
8086
})
8187

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
2424
import org.apache.spark.sql.catalyst.plans.logical._
2525
import org.apache.spark.sql.catalyst.rules._
2626
import org.apache.spark.sql.internal.SQLConf
27+
import org.apache.spark.sql.types.{IntegerType, LongType}
2728

2829
class InferFiltersFromConstraintsSuite extends PlanTest {
2930

@@ -46,8 +47,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
4647
y: LogicalPlan,
4748
expectedLeft: LogicalPlan,
4849
expectedRight: LogicalPlan,
49-
joinType: JoinType) = {
50-
val condition = Some("x.a".attr === "y.a".attr)
50+
joinType: JoinType,
51+
condition: Option[Expression] = Some("x.a".attr === "y.a".attr)) = {
5152
val originalQuery = x.join(y, joinType, condition).analyze
5253
val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
5354
val optimized = Optimize.execute(originalQuery)
@@ -263,4 +264,56 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
263264
val y = testRelation.subquery('y)
264265
testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
265266
}
267+
268+
test("Constraints should be inferred from cast equality constraint(filter higher data type)") {
269+
val testRelation1 = LocalRelation('a.int)
270+
val testRelation2 = LocalRelation('b.long)
271+
val originalLeft = testRelation1.subquery('left)
272+
val originalRight = testRelation2.where('b === 1L).subquery('right)
273+
274+
val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left)
275+
val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right)
276+
277+
Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
278+
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
279+
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
280+
}
281+
282+
Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
283+
Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
284+
testConstraintsAfterJoin(
285+
originalLeft,
286+
originalRight,
287+
testRelation1.where(IsNotNull('a)).subquery('left),
288+
right,
289+
Inner,
290+
condition)
291+
}
292+
}
293+
294+
test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") {
295+
val testRelation1 = LocalRelation('a.int)
296+
val testRelation2 = LocalRelation('b.long)
297+
val originalLeft = testRelation1.where('a === 1).subquery('left)
298+
val originalRight = testRelation2.subquery('right)
299+
300+
val left = testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left)
301+
val right = testRelation2.where(IsNotNull('b)).subquery('right)
302+
303+
Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
304+
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
305+
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
306+
}
307+
308+
Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
309+
Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
310+
testConstraintsAfterJoin(
311+
originalLeft,
312+
originalRight,
313+
left,
314+
testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right),
315+
Inner,
316+
condition)
317+
}
318+
}
266319
}

0 commit comments

Comments
 (0)