Skip to content

Commit 7e8e0fe

Browse files
committed
Minor Scala style fixes
1 parent: 01e8721 — commit: 7e8e0fe

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,17 @@ private[spark] class HashShuffleReader[K, C](
5353
val recordIterator = wrappedStreams.flatMap { wrappedStream =>
5454
val kvIter = serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
5555
CompletionIterator[(Any, Any), Iterator[(Any, Any)]](kvIter, {
56-
// Close the stream once all the records have been read from it
56+
// Close the stream once all the records have been read from it to free underlying
57+
// ManagedBuffer as soon as possible. Note that in case of task failure, the task's
58+
// TaskCompletionListener will make sure this is released.
5759
wrappedStream.close()
5860
})
5961
}
6062

6163
// Update read metrics for each record materialized
62-
val iter = new InterruptibleIterator[Any](context, recordIterator) {
64+
val iter = new InterruptibleIterator[(Any, Any)](context, recordIterator) {
6365
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
64-
override def next(): Any = {
66+
override def next(): (Any, Any) = {
6567
readMetrics.incRecordsRead(1)
6668
delegate.next()
6769
}
@@ -70,14 +72,14 @@ private[spark] class HashShuffleReader[K, C](
7072
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
7173
if (dep.mapSideCombine) {
7274
// We are reading values that are already combined
73-
val combinedKeyValuesIterator = iter.asInstanceOf[Iterator[(K,C)]]
75+
val combinedKeyValuesIterator = iter.asInstanceOf[Iterator[(K, C)]]
7476
new InterruptibleIterator(context,
7577
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context))
7678
} else {
7779
// We don't know the value type, but also don't care -- the dependency *should*
7880
// have made sure its compatible w/ this aggregator, which will convert the value
7981
// type to the combined type C
80-
val keyValuesIterator = iter.asInstanceOf[Iterator[(K,Nothing)]]
82+
val keyValuesIterator = iter.asInstanceOf[Iterator[(K, Nothing)]]
8183
new InterruptibleIterator(context,
8284
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context))
8385
}

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ package org.apache.spark.storage
2020
import java.io.InputStream
2121
import java.util.concurrent.LinkedBlockingQueue
2222

23-
import scala.collection.mutable
24-
import scala.collection.mutable.{ArrayBuffer, HashSet}
23+
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
2524
import scala.util.{Failure, Try}
2625

2726
import org.apache.spark.network.buffer.ManagedBuffer
@@ -96,7 +95,7 @@ final class ShuffleBlockFetcherIterator(
9695
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
9796
* the number of bytes in flight is limited to maxBytesInFlight.
9897
*/
99-
private[this] val fetchRequests = new mutable.Queue[FetchRequest]
98+
private[this] val fetchRequests = new Queue[FetchRequest]
10099

101100
/** Current bytes in flight from our requests */
102101
private[this] var bytesInFlight = 0L
@@ -275,6 +274,12 @@ final class ShuffleBlockFetcherIterator(
275274

276275
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
277276

277+
/**
278+
* Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
279+
* underlying each InputStream will be freed by the cleanup() method registered with the
280+
* TaskCompletionListener. However, callers should close() these InputStreams
281+
* as soon as they are no longer needed, in order to release memory as early as possible.
282+
*/
278283
override def next(): (BlockId, Try[InputStream]) = {
279284
numBlocksProcessed += 1
280285
val startFetchWait = System.currentTimeMillis()

0 commit comments

Comments (0)