
Commit f93841e

Update shuffle read metrics in ShuffleReader instead of BlockStoreShuffleFetcher.
This commit also includes Scala style cleanup.
1 parent 7e8e0fe · commit f93841e

File tree: 4 files changed, 21 additions, 26 deletions

  core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
  core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala

core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
(1 addition, 6 deletions)

@@ -25,7 +25,6 @@ import scala.util.{Failure, Success, Try}
 import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.CompletionIterator
 
 private[shuffle] object BlockStoreShuffleFetcher extends Logging {
   def fetchBlockStreams(
@@ -80,10 +79,6 @@ private[shuffle] object BlockStoreShuffleFetcher extends Logging {
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
 
-    val itr = blockFetcherItr.map(unpackBlock)
-
-    CompletionIterator[(BlockId, InputStream), Iterator[(BlockId, InputStream)]](itr, {
-      context.taskMetrics().updateShuffleReadMetrics()
-    })
+    blockFetcherItr.map(unpackBlock)
   }
 }
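
For context on what the deleted lines did: a CompletionIterator wraps an iterator and runs a callback exactly once when the wrapped iterator is exhausted, which is how this fetcher used to flush shuffle read metrics. Below is a minimal, self-contained sketch of that completion-callback idea; the names are illustrative, not Spark's actual org.apache.spark.util.CompletionIterator.

```scala
// Sketch of a completion-callback iterator, modeled on the idea behind
// Spark's CompletionIterator. Illustrative names, not Spark's API.
class OnCompleteIterator[A](sub: Iterator[A], completion: () => Unit) extends Iterator[A] {
  private var completed = false

  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) {
      completed = true
      completion() // fires exactly once, after the last element is consumed
    }
    more
  }

  override def next(): A = sub.next()
}

object OnCompleteIteratorDemo extends App {
  val it = new OnCompleteIterator[Int](Iterator(1, 2, 3), () => println("flush metrics here"))
  it.foreach(println) // prints 1, 2, 3, then the completion message
}
```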

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
(13 additions, 15 deletions)

@@ -17,11 +17,11 @@
 
 package org.apache.spark.shuffle.hash
 
+import org.apache.spark.{SparkEnv, TaskContext, InterruptibleIterator}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.shuffle.{ShuffleReader, BaseShuffleHandle}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
 
 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
@@ -51,24 +51,22 @@ private[spark] class HashShuffleReader[K, C](
 
     // Create a key/value iterator for each stream
     val recordIterator = wrappedStreams.flatMap { wrappedStream =>
-      val kvIter = serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
-      CompletionIterator[(Any, Any), Iterator[(Any, Any)]](kvIter, {
-        // Close the stream once all the records have been read from it to free underlying
-        // ManagedBuffer as soon as possible. Note that in case of task failure, the task's
-        // TaskCompletionListener will make sure this is released.
-        wrappedStream.close()
-      })
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
     }
 
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
     // Update read metrics for each record materialized
-    val iter = new InterruptibleIterator[(Any, Any)](context, recordIterator) {
-      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-      override def next(): (Any, Any) = {
-        readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
+    val metricIter = new InterruptibleIterator[(Any, Any)](context, recordIterator) {
+      override def next(): (Any, Any) = {
+        readMetrics.incRecordsRead(1)
+        delegate.next()
+      }
     }
 
+    val iter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](metricIter, {
+      context.taskMetrics().updateShuffleReadMetrics()
+    })
+
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
         // We are reading values that are already combined
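
The new layering counts a record on every next() call and defers the aggregate flush to a single completion callback once the iterator is drained. Here is a sketch of that decorator layering using plain Scala stand-ins; ReadMetrics and CountingIterator are hypothetical, not Spark's metrics classes.

```scala
// Hypothetical stand-in for Spark's shuffle read metrics accumulator.
class ReadMetrics { var recordsRead = 0L }

// Decorator that bumps a counter for every record handed out.
class CountingIterator[A](delegate: Iterator[A], metrics: ReadMetrics) extends Iterator[A] {
  override def hasNext: Boolean = delegate.hasNext
  override def next(): A = {
    metrics.recordsRead += 1
    delegate.next()
  }
}

object MetricsLayeringDemo extends App {
  val metrics = new ReadMetrics
  val records = new CountingIterator(Iterator("a", "b", "c"), metrics)
  records.foreach(_ => ()) // consume everything
  // In the real reader, a final flush happens in a CompletionIterator callback here.
  println(s"records read: ${metrics.recordsRead}") // 3
}
```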

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
(6 additions, 4 deletions)

@@ -23,10 +23,10 @@ import java.util.concurrent.LinkedBlockingQueue
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
 import scala.util.{Failure, Try}
 
+import org.apache.spark.{Logging, TaskContext}
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.util.Utils
-import org.apache.spark.{Logging, TaskContext}
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -306,16 +306,18 @@ final class ShuffleBlockFetcherIterator(
         // not exist, SPARK-4085). In that case, we should propagate the right exception so
         // the scheduler gets a FetchFailedException.
         Try(buf.createInputStream()).map { inputStream =>
-          new WrappedInputStream(inputStream, this)
+          new BufferReleasingInputStream(inputStream, this)
         }
     }
 
     (result.blockId, iteratorTry)
   }
 }
 
-// Helper class that ensures a ManagerBuffer is released upon InputStream.close()
-private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator)
+/** Helper class that ensures a ManagerBuffer is released upon InputStream.close() */
+private class BufferReleasingInputStream(
+    delegate: InputStream,
+    iterator: ShuffleBlockFetcherIterator)
   extends InputStream {
   private var closed = false
 
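
The renamed BufferReleasingInputStream is a close-hook wrapper: reads are forwarded to the underlying stream, and the backing buffer is released the first time close() is called, guarded by a flag so a second close is a no-op. A self-contained sketch of the same shape follows; Releasable and ReleasingInputStream are illustrative stand-ins, not Spark's classes.

```scala
import java.io.InputStream

// Illustrative stand-in for whatever owns the buffer being released.
trait Releasable { def release(): Unit }

// Forwards reads to `delegate`; releases `owner` exactly once on close().
class ReleasingInputStream(delegate: InputStream, owner: Releasable) extends InputStream {
  private var closed = false

  override def read(): Int = delegate.read()

  override def close(): Unit = {
    if (!closed) {
      delegate.close()
      owner.release() // idempotent thanks to the `closed` guard
      closed = true
    }
  }
}

object ReleasingInputStreamDemo extends App {
  val in = new java.io.ByteArrayInputStream(Array[Byte](1, 2, 3))
  val stream = new ReleasingInputStream(in, new Releasable {
    def release(): Unit = println("buffer released")
  })
  while (stream.read() != -1) {} // drain the stream
  stream.close() // prints "buffer released"
  stream.close() // no-op: the guard prevents a double release
}
```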

core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
(1 addition, 1 deletion)

@@ -115,7 +115,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
 
     // Make sure we release buffers when a wrapped input stream is closed.
     val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
-    val wrappedInputStream = new WrappedInputStream(mock(classOf[InputStream]), iterator)
+    val wrappedInputStream = new BufferReleasingInputStream(mock(classOf[InputStream]), iterator)
     verify(mockBuf, times(0)).release()
     wrappedInputStream.close()
     verify(mockBuf, times(1)).release()
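
The test leans on Mockito's times(n) verification to prove the buffer is released once and only once. A hedged sketch of the same check, reusing the hypothetical Releasable and ReleasingInputStream stand-ins from the sketch above and assuming Mockito is on the classpath:

```scala
import java.io.InputStream
import org.mockito.Mockito.{mock, times, verify}

object ReleaseOnCloseCheck extends App {
  val delegate = mock(classOf[InputStream])
  val owner = mock(classOf[Releasable]) // Releasable from the sketch above

  val stream = new ReleasingInputStream(delegate, owner)
  verify(owner, times(0)).release() // nothing released before close

  stream.close()
  stream.close() // a second close must not release again
  verify(owner, times(1)).release()
  println("release-on-close verified")
}
```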
