This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 7bac2fe

massie authored and kayousterhout committed
[SPARK-7884] Move block deserialization from BlockStoreShuffleFetcher to ShuffleReader
This commit updates the shuffle read path to give ShuffleReader implementations more control over the deserialization process. The BlockStoreShuffleFetcher.fetch() method has been renamed to BlockStoreShuffleFetcher.fetchBlockStreams(). Previously, this method returned a record iterator; now it returns an iterator of (BlockId, InputStream) pairs. Deserialization of records is now handled in the ShuffleReader.read() method. This change creates a cleaner separation of concerns and gives ShuffleReader implementations more flexibility in how records are retrieved.

Author: Matt Massie <[email protected]>
Author: Kay Ousterhout <[email protected]>

Closes apache#6423 from massie/shuffle-api-cleanup and squashes the following commits:

8b0632c [Matt Massie] Minor Scala style fixes
d0a1b39 [Matt Massie] Merge pull request #1 from kayousterhout/massie_shuffle-api-cleanup
290f1eb [Kay Ousterhout] Added test for HashShuffleReader.read()
5186da0 [Kay Ousterhout] Revert "Add test to ensure HashShuffleReader is freeing resources"
f98a1b9 [Matt Massie] Add test to ensure HashShuffleReader is freeing resources
a011bfa [Matt Massie] Use PrivateMethodTester on check that delegate stream is closed
4ea1712 [Matt Massie] Small code cleanup for readability
7429a98 [Matt Massie] Update tests to check that BufferReleasingStream is closing delegate InputStream
f458489 [Matt Massie] Remove unnecessary map() on return Iterator
4abb855 [Matt Massie] Consolidate metric code. Make it clear why InterruptibleIterator is needed.
5c30405 [Matt Massie] Return visibility of BlockStoreShuffleFetcher to private[hash]
7eedd1d [Matt Massie] Small Scala import cleanup
28f8085 [Matt Massie] Small import nit
f93841e [Matt Massie] Update shuffle read metrics in ShuffleReader instead of BlockStoreShuffleFetcher.
7e8e0fe [Matt Massie] Minor Scala style fixes
01e8721 [Matt Massie] Explicitly cast iterator in branches for type clarity
7c8f73e [Matt Massie] Close Block InputStream immediately after all records are read
208b7a5 [Matt Massie] Small code style changes
b70c945 [Matt Massie] Make BlockStoreShuffleFetcher visible to shuffle package
19135f2 [Matt Massie] [SPARK-7884] Allow Spark shuffle APIs to be more customizable
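For readers skimming the diffs below, here is a minimal, self-contained Scala sketch of the pattern this commit moves to: the fetcher hands back raw (BlockId, InputStream) pairs and the reader owns deserialization. Every name in the sketch is an illustrative stand-in, not Spark's actual API.

import java.io.{ByteArrayInputStream, InputStream}
import scala.io.Source

object ShuffleReadSketch {
  final case class BlockId(name: String)

  // Stand-in for fetchBlockStreams(): yield raw streams, do no deserialization here.
  def fetchBlockStreams(blocks: Seq[(BlockId, Array[Byte])]): Iterator[(BlockId, InputStream)] =
    blocks.iterator.map { case (id, bytes) => (id, new ByteArrayInputStream(bytes)) }

  // Stand-in for ShuffleReader.read(): the reader decides how streams become records
  // (here, one record per text line), and closes each stream once it is drained.
  def read(streams: Iterator[(BlockId, InputStream)]): Iterator[String] =
    streams.flatMap { case (_, in) =>
      try Source.fromInputStream(in, "UTF-8").getLines().toList
      finally in.close()
    }

  def main(args: Array[String]): Unit = {
    val blocks = Seq(BlockId("shuffle_0_0_0") -> "a\nb".getBytes("UTF-8"))
    read(fetchBlockStreams(blocks)).foreach(println)  // prints "a" then "b"
  }
}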
1 parent 82f80c1 commit 7bac2fe

5 files changed (+314, -96 lines)

core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala

Lines changed: 23 additions & 36 deletions
@@ -17,29 +17,29 @@
 
 package org.apache.spark.shuffle.hash
 
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.util.{Failure, Success, Try}
+import java.io.InputStream
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.util.{Failure, Success}
 
 import org.apache.spark._
-import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
+  ShuffleBlockId}
 
 private[hash] object BlockStoreShuffleFetcher extends Logging {
-  def fetch[T](
+  def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer)
-    : Iterator[T] =
+      blockManager: BlockManager,
+      mapOutputTracker: MapOutputTracker)
+    : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val blockManager = SparkEnv.get.blockManager
 
     val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+    val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))
 
@@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
     }
 
-    def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
+    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+      context,
+      blockManager.shuffleClient,
+      blockManager,
+      blocksByAddress,
+      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
+
+    // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
+    blockFetcherItr.map { blockPair =>
       val blockId = blockPair._1
       val blockOption = blockPair._2
       blockOption match {
-        case Success(block) => {
-          block.asInstanceOf[Iterator[T]]
+        case Success(inputStream) => {
+          (blockId, inputStream)
         }
         case Failure(e) => {
           blockId match {
@@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
         }
       }
     }
-
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
-      context,
-      SparkEnv.get.blockManager.shuffleClient,
-      blockManager,
-      blocksByAddress,
-      serializer,
-      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
-      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
-    val itr = blockFetcherItr.flatMap(unpackBlock)
-
-    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      context.taskMetrics.updateShuffleReadMetrics()
-    })
-
-    new InterruptibleIterator[T](context, completionIter) {
-      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-      override def next(): T = {
-        readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
-    }
   }
 }
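The blockFetcherItr.map { ... } block above is what keeps fetch failures wrapped for the scheduler: each Success becomes a usable (blockId, stream) pair, each Failure is rethrown as a FetchFailedException. Below is a hedged, self-contained sketch of that wrapping pattern, using a stand-in FetchFailed exception rather than Spark's real class.

import java.io.{ByteArrayInputStream, InputStream}
import scala.util.{Failure, Success, Try}

object FetchFailureSketch {
  final case class BlockId(name: String)

  // Stand-in for Spark's FetchFailedException.
  final class FetchFailed(blockId: BlockId, cause: Throwable)
    extends RuntimeException(s"Failed to fetch $blockId", cause)

  // Turn each Try into either a (blockId, stream) pair or a wrapped failure.
  def wrapFailures(
      fetched: Iterator[(BlockId, Try[InputStream])]): Iterator[(BlockId, InputStream)] =
    fetched.map {
      case (id, Success(in)) => (id, in)
      case (id, Failure(e)) => throw new FetchFailed(id, e)
    }

  def main(args: Array[String]): Unit = {
    val ok: (BlockId, Try[InputStream]) =
      BlockId("shuffle_0_0_0") -> Try(new ByteArrayInputStream("abc".getBytes("UTF-8")))
    wrapFailures(Iterator(ok)).foreach { case (id, in) =>
      println(s"$id -> ${in.available()} bytes available")
    }
  }
}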

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 44 additions & 8 deletions
@@ -17,16 +17,20 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.{InterruptibleIterator, TaskContext}
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext)
+    context: TaskContext,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
@@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+      handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
+
+    // Wrap the streams for compression based on configuration
+    val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+      blockManager.wrapForCompression(blockId, inputStream)
+    }
+
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+    val serializerInstance = ser.newInstance()
+
+    // Create a key/value iterator for each stream
+    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+      // NextIterator. The NextIterator makes sure that close() is called on the
+      // underlying InputStream when all records have been read.
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update the context task metrics for each record read.
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map(record => {
+        readMetrics.incRecordsRead(1)
+        record
+      }),
+      context.taskMetrics().updateShuffleReadMetrics())
+
+    // An interruptible iterator must be used here in order to support task cancellation
+    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
-        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
+        // We are reading values that are already combined
+        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
       } else {
-        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
+        // We don't know the value type, but also don't care -- the dependency *should*
+        // have made sure its compatible w/ this aggregator, which will convert the value
+        // type to the combined type C
+        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
     } else {
       require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
-
-      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
-      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
+      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
     }
 
     // Sort the output if there is a sort ordering defined.
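Two ideas in the new read() are worth calling out: records are deserialized lazily per stream, and the record iterator is wrapped so shuffle-read metrics are updated per record before any aggregation. Here is a self-contained sketch of both, with a CountingIterator and a line-based "deserializer" as illustrative stand-ins (they are not Spark classes).

import java.io.{ByteArrayInputStream, InputStream}
import scala.io.Source

object ReaderMetricsSketch {
  // Counts records as they flow past, the way read() wraps recordIter for metrics.
  final class CountingIterator[A](underlying: Iterator[A]) extends Iterator[A] {
    var recordsRead: Long = 0L
    def hasNext: Boolean = underlying.hasNext
    def next(): A = { recordsRead += 1; underlying.next() }
  }

  // One stream per block, one record per line; each stream is consumed lazily
  // and closed once its records have been read.
  def records(streams: Iterator[InputStream]): Iterator[String] =
    streams.flatMap { in =>
      try Source.fromInputStream(in, "UTF-8").getLines().toList
      finally in.close()
    }

  def main(args: Array[String]): Unit = {
    val streams = Iterator("x\ny", "z").map(s => new ByteArrayInputStream(s.getBytes("UTF-8")))
    val counted = new CountingIterator(records(streams))
    counted.foreach(println)
    println(s"records read: ${counted.recordsRead}")  // 3
  }
}

Note also the constructor change above: blockManager and mapOutputTracker become parameters with SparkEnv defaults, which makes it straightforward for tests (such as the HashShuffleReader.read() test mentioned in the commit message) to pass in test doubles instead of relying on a live SparkEnv.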

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 62 additions & 28 deletions
@@ -17,23 +17,23 @@
 
 package org.apache.spark.storage
 
+import java.io.InputStream
 import java.util.concurrent.LinkedBlockingQueue
 
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
 import scala.util.{Failure, Try}
 
 import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.util.Utils
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
  * manager. For remote blocks, it fetches them using the provided BlockTransferService.
  *
- * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
- * pipelined fashion as they are received.
+ * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
  *
  * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid
  * using too much memory.
@@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
- * @param serializer serializer used to deserialize the data.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  */
 private[spark]
@@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-    serializer: Serializer,
     maxBytesInFlight: Long)
-  extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
+  extends Iterator[(BlockId, Try[InputStream])] with Logging {
 
   import ShuffleBlockFetcherIterator._
 
@@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator(
 
   /**
    * A queue to hold our results. This turns the asynchronous model provided by
-   * [[BlockTransferService]] into a synchronous model (iterator).
+   * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
    */
   private[this] val results = new LinkedBlockingQueue[FetchResult]
 
@@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator(
   /** Current bytes in flight from our requests */
   private[this] var bytesInFlight = 0L
 
-  private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-
-  private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
+  private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator(
 
   initialize()
 
-  /**
-   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
-   */
-  private[this] def cleanup() {
-    isZombie = true
+  // Decrements the buffer reference count.
+  // The currentResult is set to null to prevent releasing the buffer again on cleanup()
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
     currentResult match {
       case SuccessFetchResult(_, _, buf) => buf.release()
      case _ =>
     }
+    currentResult = null
+  }
 
+  /**
+   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+   */
+  private[this] def cleanup() {
+    isZombie = true
+    releaseCurrentResultBuffer()
     // Release buffers in the results queue
     val iter = results.iterator()
     while (iter.hasNext) {
@@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator(
 
   override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
 
-  override def next(): (BlockId, Try[Iterator[Any]]) = {
+  /**
+   * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
+   * underlying each InputStream will be freed by the cleanup() method registered with the
+   * TaskCompletionListener. However, callers should close() these InputStreams
+   * as soon as they are no longer needed, in order to release memory as early as possible.
+   */
+  override def next(): (BlockId, Try[InputStream]) = {
     numBlocksProcessed += 1
     val startFetchWait = System.currentTimeMillis()
     currentResult = results.take()
@@ -290,29 +298,55 @@ final class ShuffleBlockFetcherIterator(
       sendRequest(fetchRequests.dequeue())
     }
 
-    val iteratorTry: Try[Iterator[Any]] = result match {
+    val iteratorTry: Try[InputStream] = result match {
       case FailureFetchResult(_, e) =>
         Failure(e)
       case SuccessFetchResult(blockId, _, buf) =>
        // There is a chance that createInputStream can fail (e.g. fetching a local file that does
        // not exist, SPARK-4085). In that case, we should propagate the right exception so
        // the scheduler gets a FetchFailedException.
-        Try(buf.createInputStream()).map { is0 =>
-          val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
-          CompletionIterator[Any, Iterator[Any]](iter, {
-            // Once the iterator is exhausted, release the buffer and set currentResult to null
-            // so we don't release it again in cleanup.
-            currentResult = null
-            buf.release()
-          })
+        Try(buf.createInputStream()).map { inputStream =>
+          new BufferReleasingInputStream(inputStream, this)
         }
     }
 
     (result.blockId, iteratorTry)
   }
 }
 
+/**
+ * Helper class that ensures a ManagedBuffer is release upon InputStream.close()
+ */
+private class BufferReleasingInputStream(
+    private val delegate: InputStream,
+    private val iterator: ShuffleBlockFetcherIterator)
+  extends InputStream {
+  private[this] var closed = false
+
+  override def read(): Int = delegate.read()
+
+  override def close(): Unit = {
+    if (!closed) {
+      delegate.close()
+      iterator.releaseCurrentResultBuffer()
+      closed = true
+    }
+  }
+
+  override def available(): Int = delegate.available()
+
+  override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+  override def skip(n: Long): Long = delegate.skip(n)
+
+  override def markSupported(): Boolean = delegate.markSupported()
+
+  override def read(b: Array[Byte]): Int = delegate.read(b)
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+
+  override def reset(): Unit = delegate.reset()
+}
 
 private[storage]
 object ShuffleBlockFetcherIterator {
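BufferReleasingInputStream above is a thin delegating wrapper whose single job is to release the fetched buffer exactly once when the stream is closed. Here is a generic, self-contained sketch of the same delegate-and-release-on-close idea (the names and the release callback are illustrative, not Spark's).

import java.io.{ByteArrayInputStream, InputStream}

// Delegates all reads and runs a release callback exactly once on close().
final class ReleasingInputStream(delegate: InputStream, release: () => Unit) extends InputStream {
  private[this] var closed = false

  override def read(): Int = delegate.read()
  override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
  override def available(): Int = delegate.available()

  override def close(): Unit = {
    if (!closed) {
      closed = true
      delegate.close()
      release() // e.g. decrement a buffer's reference count
    }
  }
}

object ReleasingInputStreamDemo {
  def main(args: Array[String]): Unit = {
    var released = false
    val in = new ReleasingInputStream(
      new ByteArrayInputStream("hi".getBytes("UTF-8")), () => released = true)
    while (in.read() != -1) {} // drain the stream
    in.close()
    in.close() // second close is a no-op: the release runs only once
    println(s"released = $released") // true
  }
}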
