Commit cfe012a

c21 authored and maropu committed
[SPARK-32629][SQL] Track metrics of BitSet/OpenHashSet in full outer SHJ
### What changes were proposed in this pull request?

This is a followup from #29342, which does two things:

* Per #29342 (comment), change from Java `HashSet` to Spark's in-house `OpenHashSet` to track matched rows for non-unique join keys. I checked the `OpenHashSet` implementation, which is built from a key index (`OpenHashSet._bitset` as `BitSet`) and a key array (`OpenHashSet._data` as `Array`). Java `HashSet` is built on `HashMap`, which stores values in a `Node` linked list and in theory should take more memory than `OpenHashSet`. Reran the same benchmark query used in #29342 and verified that the query has similar performance between `HashSet` and `OpenHashSet`.
* Track metrics of the extra data structure `BitSet`/`OpenHashSet` for full outer SHJ. This depends on the first item, because there seems to be no easy way to get the memory size of a Java `HashSet`.

### Why are the changes needed?

To surface the memory usage of full outer SHJ more accurately. This can help users/developers to debug/improve full outer SHJ.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a unit test in `SQLMetricsSuite.scala`.

Closes #29566 from c21/add-metrics.

Authored-by: Cheng Su <[email protected]>
Signed-off-by: Takeshi Yamamuro <[email protected]>
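The size recorded for the tracking structure follows directly from the layout described above. Below is a minimal sketch of that accounting, assuming Spark's `BitSet` costs one bit per slot (rounded up to whole 64-bit words, so roughly `capacity / 8` bytes) and that `OpenHashSet[Long]` additionally keeps a `Long` array of `capacity` slots at 8 bytes each; the object and method names here are illustrative, not part of the patch.

```scala
// Illustrative sketch of the buildDataSize accounting added in this patch.
// The real code reads the capacities from org.apache.spark.util.collection
// BitSet/OpenHashSet and adds the estimates to the buildDataSize SQL metric
// (for the OpenHashSet, from a task-completion listener).
object FullOuterShjMetricSketch {
  // BitSet used for unique join keys: one bit per key index.
  def bitSetBytes(bitCapacity: Int): Long = bitCapacity.toLong / 8

  // OpenHashSet[Long] used for non-unique join keys: its internal bit set
  // plus a Long array of `capacity` slots (8 bytes per slot).
  def openHashSetBytes(capacity: Int): Long =
    capacity.toLong / 8 + capacity.toLong * 8

  def main(args: Array[String]): Unit = {
    println(bitSetBytes(1 << 16))      // 8192 bytes for a 65536-bit BitSet
    println(openHashSetBytes(1 << 16)) // 532480 bytes for 65536 slots
  }
}
```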
1 parent ccc0250 commit cfe012a

3 files changed, +73 -25 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 15 additions & 12 deletions
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.joins
 
 import java.util.concurrent.TimeUnit._
 
-import scala.collection.mutable
-
 import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -31,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.util.collection.BitSet
+import org.apache.spark.util.collection.{BitSet, OpenHashSet}
 
 /**
  * Performs a hash join of two child relations by first shuffling the data using the join keys.
@@ -136,10 +134,10 @@ case class ShuffledHashJoinExec(
    * Full outer shuffled hash join with unique join keys:
    * 1. Process rows from stream side by looking up hash relation.
    *    Mark the matched rows from build side be looked up.
-   *    A `BitSet` is used to track matched rows with key index.
+   *    A bit set is used to track matched rows with key index.
    * 2. Process rows from build side by iterating hash relation.
    *    Filter out rows from build side being matched already,
-   *    by checking key index from `BitSet`.
+   *    by checking key index from bit set.
    */
   private def fullOuterJoinWithUniqueKey(
       streamIter: Iterator[InternalRow],
@@ -150,9 +148,8 @@
       streamNullJoinRowWithBuild: => InternalRow => JoinedRow,
       buildNullRow: GenericInternalRow,
       streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
-    // TODO(SPARK-32629):record metrics of extra BitSet/HashSet
-    // in full outer shuffled hash join
     val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+    longMetric("buildDataSize") += matchedKeys.capacity / 8
 
     // Process stream side with looking up hash relation
     val streamResultIter = streamIter.map { srow =>
@@ -198,11 +195,11 @@
    * Full outer shuffled hash join with non-unique join keys:
    * 1. Process rows from stream side by looking up hash relation.
    *    Mark the matched rows from build side be looked up.
-   *    A `HashSet[Long]` is used to track matched rows with
+   *    A [[OpenHashSet]] (Long) is used to track matched rows with
    *    key index (Int) and value index (Int) together.
    * 2. Process rows from build side by iterating hash relation.
    *    Filter out rows from build side being matched already,
-   *    by checking key index and value index from `HashSet`.
+   *    by checking key index and value index from [[OpenHashSet]].
    *
    * The "value index" is defined as the index of the tuple in the chain
    * of tuples having the same key. For example, if certain key is found thrice,
@@ -218,9 +215,15 @@
       streamNullJoinRowWithBuild: => InternalRow => JoinedRow,
       buildNullRow: GenericInternalRow,
       streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
-    // TODO(SPARK-32629):record metrics of extra BitSet/HashSet
-    // in full outer shuffled hash join
-    val matchedRows = new mutable.HashSet[Long]
+    val matchedRows = new OpenHashSet[Long]
+    TaskContext.get().addTaskCompletionListener[Unit](_ => {
+      // At the end of the task, update the task's memory usage for this
+      // [[OpenHashSet]] to track matched rows, which has two parts:
+      // [[OpenHashSet._bitset]] and [[OpenHashSet._data]].
+      val bitSetEstimatedSize = matchedRows.getBitSet.capacity / 8
+      val dataEstimatedSize = matchedRows.capacity * 8
+      longMetric("buildDataSize") += bitSetEstimatedSize + dataEstimatedSize
+    })
 
     def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
       val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
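The `markRowMatched` helper in the hunk above packs the two `Int` indexes into a single `Long` before inserting it into the `OpenHashSet`. A self-contained sketch of that encoding follows; the `unpack` inverse and the object name are not part of the patch and are shown only to make the bit layout explicit.

```scala
object RowIndexPacking {
  // Key index in the high 32 bits, value index in the low 32 bits, mirroring
  // `(keyIndex.toLong << 32) | valueIndex` above. Both indexes are non-negative,
  // so OR-ing the value index cannot clobber the high word with sign bits.
  def pack(keyIndex: Int, valueIndex: Int): Long =
    (keyIndex.toLong << 32) | valueIndex

  // Hypothetical inverse, handy for checking the layout.
  def unpack(rowIndex: Long): (Int, Int) =
    ((rowIndex >>> 32).toInt, rowIndex.toInt)

  def main(args: Array[String]): Unit = {
    assert(pack(3, 7) == 0x0000000300000007L)
    val (keyIndex, valueIndex) = unpack(pack(Int.MaxValue, 42))
    assert(keyIndex == Int.MaxValue && valueIndex == 42)
  }
}
```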

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 39 additions & 12 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeSt
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -363,6 +364,41 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
     }
   }
 
+  test("SPARK-32629: ShuffledHashJoin(full outer) metrics") {
+    val uniqueLeftDf = Seq(("1", "1"), ("11", "11")).toDF("key", "value")
+    val nonUniqueLeftDf = Seq(("1", "1"), ("1", "2"), ("11", "11")).toDF("key", "value")
+    val rightDf = (1 to 10).map(i => (i.toString, i.toString)).toDF("key2", "value")
+    Seq(
+      // Test unique key on build side
+      (uniqueLeftDf, rightDf, 11, 134228048, 10, 134221824),
+      // Test non-unique key on build side
+      (nonUniqueLeftDf, rightDf, 12, 134228552, 11, 134221824)
+    ).foreach { case (leftDf, rightDf, fojRows, fojBuildSize, rojRows, rojBuildSize) =>
+      val fojDf = leftDf.hint("shuffle_hash").join(
+        rightDf, $"key" === $"key2", "full_outer")
+      fojDf.collect()
+      val fojPlan = fojDf.queryExecution.executedPlan.collectFirst {
+        case s: ShuffledHashJoinExec => s
+      }
+      assert(fojPlan.isDefined, "The query plan should have shuffled hash join")
+      testMetricsInSparkPlanOperator(fojPlan.get,
+        Map("numOutputRows" -> fojRows, "buildDataSize" -> fojBuildSize))
+
+      // Test right outer join as well to verify build data size to be different
+      // from full outer join. This makes sure we take extra BitSet/OpenHashSet
+      // for full outer join into account.
+      val rojDf = leftDf.hint("shuffle_hash").join(
+        rightDf, $"key" === $"key2", "right_outer")
+      rojDf.collect()
+      val rojPlan = rojDf.queryExecution.executedPlan.collectFirst {
+        case s: ShuffledHashJoinExec => s
+      }
+      assert(rojPlan.isDefined, "The query plan should have shuffled hash join")
+      testMetricsInSparkPlanOperator(rojPlan.get,
+        Map("numOutputRows" -> rojRows, "buildDataSize" -> rojBuildSize))
+    }
+  }
+
   test("BroadcastHashJoin(outer) metrics") {
     val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
     val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
@@ -686,16 +722,6 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
   }
 
   test("SPARK-28332: SQLMetric merge should handle -1 properly") {
-    def checkSparkPlanMetrics(plan: SparkPlan, expected: Map[String, Long]): Unit = {
-      expected.foreach { case (metricName: String, metricValue: Long) =>
-        assert(plan.metrics.contains(metricName), s"The query plan should have metric $metricName")
-        val actualMetric = plan.metrics.get(metricName).get
-        assert(actualMetric.value == metricValue,
-          s"The query plan metric $metricName did not match, " +
-          s"expected:$metricValue, actual:${actualMetric.value}")
-      }
-    }
-
     val df = testData.join(testData2.filter('b === 0), $"key" === $"a", "left_outer")
     df.collect()
     val plan = df.queryExecution.executedPlan
@@ -706,7 +732,8 @@
 
     assert(exchanges.size == 2, "The query plan should have two shuffle exchanges")
 
-    checkSparkPlanMetrics(exchanges(0), Map("dataSize" -> 3200, "shuffleRecordsWritten" -> 100))
-    checkSparkPlanMetrics(exchanges(1), Map("dataSize" -> 0, "shuffleRecordsWritten" -> 0))
+    testMetricsInSparkPlanOperator(exchanges.head,
+      Map("dataSize" -> 3200, "shuffleRecordsWritten" -> 100))
+    testMetricsInSparkPlanOperator(exchanges(1), Map("dataSize" -> 0, "shuffleRecordsWritten" -> 0))
   }
 }
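As a rough reading of the expected values in the new test above (an observation about the constants, not something asserted by the patch): the right outer join builds the same hash relation but skips the matched-row tracking, so the gap between the full outer and right outer `buildDataSize` should be the extra structure's estimated size.

```scala
// Differences between the expected buildDataSize constants in the test above.
val uniqueKeyExtra = 134228048L - 134221824L    // 6224 bytes: BitSet for the unique-key case
val nonUniqueKeyExtra = 134228552L - 134221824L // 6728 bytes: OpenHashSet for the non-unique case
```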

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala

Lines changed: 19 additions & 1 deletion
@@ -25,7 +25,7 @@ import org.apache.spark.TestUtils
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.execution.SparkPlanInfo
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanInfo}
 import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore}
 import org.apache.spark.sql.internal.SQLConf.WHOLESTAGE_CODEGEN_ENABLED
 import org.apache.spark.sql.test.SQLTestUtils
@@ -254,6 +254,24 @@ trait SQLMetricsTestUtils extends SQLTestUtils
       }
     }
   }
+
+  /**
+   * Verify if the metrics in `SparkPlan` operator are same as expected metrics.
+   *
+   * @param plan `SparkPlan` operator to check metrics
+   * @param expectedMetrics the expected metrics. The format is `metric name -> metric value`.
+   */
+  protected def testMetricsInSparkPlanOperator(
+      plan: SparkPlan,
+      expectedMetrics: Map[String, Long]): Unit = {
+    expectedMetrics.foreach { case (metricName: String, metricValue: Long) =>
+      assert(plan.metrics.contains(metricName), s"The query plan should have metric $metricName")
+      val actualMetric = plan.metrics(metricName)
+      assert(actualMetric.value == metricValue,
+        s"The query plan metric $metricName did not match, " +
+        s"expected:$metricValue, actual:${actualMetric.value}")
+    }
+  }
 }
 
 
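A minimal sketch of how a suite mixing in this trait might call the new helper; the query, the expected value, and the use of `spark`/`$` implicits below are illustrative and assume adaptive execution is disabled (as in `SQLMetricsSuite`), not something taken from the patch.

```scala
// Hypothetical test body: run a query, pick an operator out of the executed
// plan, and compare its metrics through testMetricsInSparkPlanOperator.
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec

val df = spark.range(10).toDF("id").hint("shuffle_hash")
  .join(spark.range(5).toDF("id2"), $"id" === $"id2", "full_outer")
df.collect()

val shj = df.queryExecution.executedPlan.collectFirst {
  case s: ShuffledHashJoinExec => s
}
assert(shj.isDefined, "The query plan should have shuffled hash join")
// 5 matched rows plus 5 unmatched rows from the larger side.
testMetricsInSparkPlanOperator(shj.get, Map("numOutputRows" -> 10L))
```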