[SPARK-12700] [SQL] embed condition into SMJ and BroadcastHashJoin #10653

Closed
wants to merge 2 commits
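In short: for inner equi-joins with an extra non-equi predicate, the planner used to emit a separate Filter on top of BroadcastHashJoin or SortMergeJoin; after this change the predicate is passed into the join operators themselves, which evaluate it while matching rows.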
Changes from all commits
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.{execution, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object EquiJoinSelection extends Strategy with PredicateHelper {
 
-    private[this] def makeBroadcastHashJoin(
-        leftKeys: Seq[Expression],
-        rightKeys: Seq[Expression],
-        left: LogicalPlan,
-        right: LogicalPlan,
-        condition: Option[Expression],
-        side: joins.BuildSide): Seq[SparkPlan] = {
-      val broadcastHashJoin = execution.joins.BroadcastHashJoin(
-        leftKeys, rightKeys, side, planLater(left), planLater(right))
-      condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
-    }
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
       // --- Inner joins --------------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
 
      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
-        val mergeJoin =
-          joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+        joins.SortMergeJoin(
+          leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
 
       // --- Outer joins --------------------------------------------------------------------------
 
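To make the planner change concrete: before this patch, a join condition with a non-equi part was planned as a Filter on top of the join; now the join operator carries the condition itself. A minimal sketch of the two plan shapes, using simplified stand-in types (Scan, Filter, and BroadcastHashJoin below are illustrative, not the Spark classes):

    // Illustrative stand-ins only; these are not the Spark planner types.
    sealed trait Plan
    case class Scan(table: String) extends Plan
    case class Filter(condition: String, child: Plan) extends Plan
    case class BroadcastHashJoin(condition: Option[String], left: Plan, right: Plan) extends Plan

    object PlanShapes {
      def main(args: Array[String]): Unit = {
        val (a, b) = (Scan("a"), Scan("b"))
        val extra = "a.x > b.y" // the non-equi part of the join condition

        // Old strategy: join on the equi-keys, then filter the joined rows.
        val before: Plan = Filter(extra, BroadcastHashJoin(None, a, b))

        // New strategy: the join evaluates the extra condition while matching.
        val after: Plan = BroadcastHashJoin(Some(extra), a, b)

        println(before)
        println(after)
      }
    }

The practical difference is where rows die: with the embedded condition, a candidate pair that fails the predicate is skipped inside the join's matching loop, before any output row is projected.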
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

@@ -39,6 +39,7 @@ case class BroadcastHashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     buildSide: BuildSide,
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
   extends BinaryNode with HashJoin {
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.joins
 
+import java.util.NoSuchElementException
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +31,7 @@ trait HashJoin {
   val leftKeys: Seq[Expression]
   val rightKeys: Seq[Expression]
   val buildSide: BuildSide
+  val condition: Option[Expression]
   val left: SparkPlan
   val right: SparkPlan
 
@@ -50,6 +53,12 @@
   protected def streamSideKeyGenerator: Projection =
     UnsafeProjection.create(streamedKeys, streamedPlan.output)
 
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (r: InternalRow) => true
+  }
+
   protected def hashJoin(
       streamIter: Iterator[InternalRow],
       numStreamRows: LongSQLMetric,
@@ -68,44 +77,52 @@
 
       private[this] val joinKeys = streamSideKeyGenerator
 
-      override final def hasNext: Boolean =
-        (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
-          (streamIter.hasNext && fetchNext())
+      override final def hasNext: Boolean = {
+        while (true) {
+          // check if it's end of current matches
+          if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
+            currentHashMatches = null
+            currentMatchPosition = -1
+          }
 
-      override final def next(): InternalRow = {
-        val ret = buildSide match {
-          case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
-          case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
-        }
-        currentMatchPosition += 1
-        numOutputRows += 1
-        resultProjection(ret)
-      }
+          // find the next match
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            numStreamRows += 1
+            val key = joinKeys(currentStreamedRow)
+            if (!key.anyNull) {
+              currentHashMatches = hashedRelation.get(key)
+              if (currentHashMatches != null) {
+                currentMatchPosition = 0
+              }
+            }
+          }
+          if (currentHashMatches == null) {
+            return false
+          }
 
-      /**
-       * Searches the streamed iterator for the next row that has at least one match in hashtable.
-       *
-       * @return true if the search is successful, and false if the streamed iterator runs out of
-       *         tuples.
-       */
-      private final def fetchNext(): Boolean = {
-        currentHashMatches = null
-        currentMatchPosition = -1
-
-        while (currentHashMatches == null && streamIter.hasNext) {
-          currentStreamedRow = streamIter.next()
-          numStreamRows += 1
-          val key = joinKeys(currentStreamedRow)
-          if (!key.anyNull) {
-            currentHashMatches = hashedRelation.get(key)
-          }
-        }
+          // found some matches
+          buildSide match {
+            case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+            case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+          }
+          if (boundCondition(joinRow)) {
+            return true
+          } else {
+            currentMatchPosition += 1
+          }
+        }
+        false // unreachable
+      }
 
-        if (currentHashMatches == null) {
-          false
+      override final def next(): InternalRow = {
+        // next() could be called without calling hasNext()
+        if (hasNext) {
+          currentMatchPosition += 1
+          numOutputRows += 1
+          resultProjection(joinRow)
         } else {
-          currentMatchPosition = 0
-          true
+          throw new NoSuchElementException
        }
      }
    }
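The rewritten iterator folds three jobs into hasNext: advance the streamed side, probe the hash table, and skip matched pairs that fail the embedded condition; next() consumes whatever pair hasNext positioned, and calls hasNext itself so it is safe without a prior hasNext call. A self-contained model of that contract, using plain Scala collections as stand-ins for HashedRelation and InternalRow (all names here are mine, not Spark's):

    import java.util.NoSuchElementException

    // Model of the hasNext/next contract used above. hasNext is safe to call
    // repeatedly: it only advances past pairs that fail the condition.
    class FilteringHashJoinIterator[K, L, R](
        stream: Iterator[L],
        key: L => K,
        hashed: Map[K, IndexedSeq[R]],
        condition: (L, R) => Boolean) extends Iterator[(L, R)] {

      private[this] var streamedRow: L = _
      private[this] var matches: IndexedSeq[R] = _
      private[this] var pos = -1

      override def hasNext: Boolean = {
        while (true) {
          // end of the current row's matches: reset and look for the next row
          if (matches != null && pos == matches.length) {
            matches = null
            pos = -1
          }
          // pull streamed rows until one has matches in the hash table
          while (matches == null && stream.hasNext) {
            streamedRow = stream.next()
            hashed.get(key(streamedRow)).filter(_.nonEmpty).foreach { m =>
              matches = m
              pos = 0
            }
          }
          if (matches == null) return false
          // found a candidate pair: keep it only if the condition passes
          if (condition(streamedRow, matches(pos))) return true
          pos += 1
        }
        false // unreachable
      }

      override def next(): (L, R) = {
        if (!hasNext) throw new NoSuchElementException
        val result = (streamedRow, matches(pos))
        pos += 1
        result
      }
    }

For instance, joining Iterator((1, "a"), (2, "b")) against Map(1 -> IndexedSeq(10, 11), 2 -> IndexedSeq(21)) with condition (_, r) => r % 2 == 0 yields only ((1, "a"), 10): 11 and 21 fail the condition and are skipped without ever becoming output rows.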
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala

@@ -78,8 +78,11 @@ trait HashOuterJoin {
 
   @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-  @transient private[this] lazy val boundCondition =
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
Review thread on the line above:

Contributor: why did you change this?

Author: It's not related to this PR; I just changed it to align with the other operators (it is a little faster when there is no condition).

     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (row: InternalRow) => true
+  }
 
   // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
   // iterator for performance purpose.
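On the review thread above: the condition.isDefined branch is a micro-optimization rather than a behavior change. With no condition, the operator installs a constant-true closure instead of building a predicate from Literal(true) and invoking it for every row. A small sketch of the idea, with an illustrative stand-in for newPredicate (not Spark's implementation):

    object EmptyConditionFastPath {
      type Row = Seq[Any]

      // Stand-in for Spark's newPredicate: pretend binding/compiling is costly.
      def newPredicate(expr: String): Row => Boolean = {
        println(s"compiling: $expr")
        _ => true // placeholder body; a real predicate would evaluate expr
      }

      def bind(condition: Option[String]): Row => Boolean = {
        if (condition.isDefined) {
          newPredicate(condition.get) // compiled once, invoked per row
        } else {
          (_: Row) => true // nothing compiled, trivially cheap per row
        }
      }

      def main(args: Array[String]): Unit = {
        val noCond = bind(None) // prints nothing
        val withCond = bind(Some("a > b")) // prints "compiling: a > b"
        println(noCond(Seq(1)) && withCond(Seq(1, 2)))
      }
    }

Either way boundCondition has the same function type, so the join code that calls it does not care which branch was taken.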
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
 case class SortMergeJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
@@ -64,6 +65,13 @@ case class SortMergeJoin(
     val numOutputRows = longMetric("numOutputRows")
 
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      val boundCondition: (InternalRow) => Boolean = {
+        condition.map { cond =>
+          newPredicate(cond, left.output ++ right.output)
+        }.getOrElse {
+          (r: InternalRow) => true
+        }
+      }
       new RowIterator {
         // The projection used to extract keys from input rows of the left child.
         private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
@@ -89,26 +97,34 @@
         private[this] val resultProjection: (InternalRow) => InternalRow =
           UnsafeProjection.create(schema)
 
+        if (smjScanner.findNextInnerJoinRows()) {
+          currentRightMatches = smjScanner.getBufferedMatches
+          currentLeftRow = smjScanner.getStreamedRow
+          currentMatchIdx = 0
+        }
+
         override def advanceNext(): Boolean = {
-          if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
-            if (smjScanner.findNextInnerJoinRows()) {
-              currentRightMatches = smjScanner.getBufferedMatches
-              currentLeftRow = smjScanner.getStreamedRow
-              currentMatchIdx = 0
-            } else {
-              currentRightMatches = null
-              currentLeftRow = null
-              currentMatchIdx = -1
+          while (currentMatchIdx >= 0) {
+            if (currentMatchIdx == currentRightMatches.length) {
+              if (smjScanner.findNextInnerJoinRows()) {
+                currentRightMatches = smjScanner.getBufferedMatches
+                currentLeftRow = smjScanner.getStreamedRow
+                currentMatchIdx = 0
+              } else {
+                currentRightMatches = null
+                currentLeftRow = null
+                currentMatchIdx = -1
+                return false
+              }
             }
-          }
-          if (currentLeftRow != null) {
             joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
             currentMatchIdx += 1
-            numOutputRows += 1
-            true
-          } else {
-            false
+            if (boundCondition(joinRow)) {
+              numOutputRows += 1
+              return true
+            }
           }
+          false
         }
 
         override def getRow: InternalRow = resultProjection(joinRow)
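The reshaped advanceNext follows the same pattern as the hash-join iterator above: iterate over the buffered matches for the current streamed row, pull the next group when the current one is exhausted, and only report success once a pair passes the embedded condition. A condensed model, where the groups iterator stands in for SortMergeJoinScanner.findNextInnerJoinRows and its getters (names are illustrative, not Spark's):

    // groups yields (streamedRow, bufferedMatches) pairs in key order, with
    // non-empty bufferedMatches, as the real scanner guarantees.
    class CondMergeJoinIterator[L, R](
        groups: Iterator[(L, IndexedSeq[R])],
        condition: (L, R) => Boolean) {

      private[this] var leftRow: L = _
      private[this] var matches: IndexedSeq[R] = _
      private[this] var idx = -1
      private[this] var current: (L, R) = _

      advanceGroup() // mirrors the initial findNextInnerJoinRows() call

      private def advanceGroup(): Unit = {
        if (groups.hasNext) {
          val (l, m) = groups.next()
          leftRow = l
          matches = m
          idx = 0
        } else {
          matches = null
          idx = -1
        }
      }

      def advanceNext(): Boolean = {
        while (idx >= 0) {
          if (idx == matches.length) {
            advanceGroup() // current group exhausted: fetch the next one
            if (idx < 0) return false
          }
          val candidate = (leftRow, matches(idx))
          idx += 1
          if (condition(candidate._1, candidate._2)) {
            current = candidate
            return true
          }
        }
        false
      }

      def getRow: (L, R) = current
    }

Note that numOutputRows in the real operator is incremented only when the condition passes, so the metric counts rows actually emitted rather than candidate pairs.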
sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala

@@ -17,14 +17,14 @@
 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
 
 class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
   import testImplicits.localSeqToDataFrameHolder
@@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      val broadcastHashJoin =
-        execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
-      boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+      joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
     }
 
     def makeSortMergeJoin(
@@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan) = {
       val sortMergeJoin =
-        execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
-      val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
-      EnsureRequirements(sqlContext).apply(filteredJoin)
+        joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
+      EnsureRequirements(sqlContext).apply(sortMergeJoin)
     }
 
     test(s"$testName using BroadcastHashJoin (build=left)") {
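A hedged sketch of the kind of plan assertion these changes enable, written as a hypothetical suite helper (not from the PR). It assumes DataFrames with columns key, a, and b, and that broadcast joins are disabled (spark.sql.autoBroadcastJoinThreshold = -1) so SortMergeJoin is selected:

    // Hypothetical helper, intended to live inside a suite like InnerJoinSuite.
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.execution.{Filter => PhysicalFilter}
    import org.apache.spark.sql.execution.joins.SortMergeJoin

    def assertConditionEmbedded(left: DataFrame, right: DataFrame): Unit = {
      val joined = left.join(right, left("key") === right("key") && left("a") < right("b"))
      val plan = joined.queryExecution.executedPlan
      // the non-equi predicate should live inside the join, not in a Filter above it
      assert(plan.collect { case j: SortMergeJoin => j }.exists(_.condition.isDefined))
      assert(plan.collect { case f: PhysicalFilter => f }.isEmpty)
    }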