
Commit 5186da0

Revert "Add test to ensure HashShuffleReader is freeing resources"
This reverts commit f98a1b9.
1 parent: f98a1b9

File tree: 3 files changed (+10, -115 lines)


core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala

Lines changed: 1 addition & 2 deletions
@@ -26,8 +26,7 @@ import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
 
-private[hash] class BlockStoreShuffleFetcher extends Logging {
-
+private[hash] object BlockStoreShuffleFetcher extends Logging {
   def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
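
The substantive change here is reverting BlockStoreShuffleFetcher from an injectable class back to a singleton object. A minimal sketch of the trade-off, with hypothetical names rather than Spark code: a Scala object is resolved statically, so a call site that names it directly cannot be handed a test double.

    // Hypothetical sketch, not Spark code: a singleton object is a fixed,
    // compile-time dependency of every call site that references it.
    object Fetcher {
      def fetch(id: Int): String = s"block-$id"
    }

    class Reader {
      // Bound directly to the singleton; a test cannot substitute a stub here.
      def read(id: Int): String = Fetcher.fetch(id)
    }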

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 3 additions & 5 deletions
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.storage.BlockManager
 import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
@@ -28,19 +27,18 @@ private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext,
-    blockManager: BlockManager = SparkEnv.get.blockManager,
-    blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher)
+    context: TaskContext)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
     "Hash shuffle currently only supports fetching one partition")
 
   private val dep = handle.dependency
+  private val blockManager = SparkEnv.get.blockManager
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockStreams = blockStoreShuffleFetcher.fetchBlockStreams(
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
       handle.shuffleId, startPartition, context)
 
     // Wrap the streams for compression based on configuration
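
The constructor parameters removed here were the reverted commit's injection seam: default arguments kept production call sites unchanged while tests could pass mocks. A minimal sketch of that pattern, using hypothetical names (a plain function stands in for the fetcher):

    // Hypothetical sketch of default-argument injection, not the Spark API.
    class Reader(
        partition: Int,
        fetch: Int => Iterator[Int] = id => Iterator(id, id + 1)) {
      def read(): List[Int] = fetch(partition).toList
    }

    // Production: new Reader(0) picks up the default fetch function.
    // A test:     new Reader(0, _ => Iterator(42)) injects a stub instead.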

core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala

Lines changed: 6 additions & 108 deletions
@@ -17,22 +17,16 @@
 
 package org.apache.spark.shuffle.hash
 
-import java.io._
-import java.nio.ByteBuffer
+import java.io.{File, FileWriter}
 
 import scala.language.reflectiveCalls
 
-import org.mockito.Matchers.any
-import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-
-import org.apache.spark._
-import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.serializer._
-import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver}
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FileShuffleBlockResolver
+import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
 
 class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
   private val testConf = new SparkConf(false)
@@ -113,100 +107,4 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
     for (i <- 0 until numBytes) writer.write(i)
     writer.close()
   }
-
-  test("HashShuffleReader.read() releases resources and tracks metrics") {
-    val shuffleId = 1
-    val numMaps = 2
-    val numKeyValuePairs = 10
-
-    val mockContext = mock(classOf[TaskContext])
-
-    val mockTaskMetrics = mock(classOf[TaskMetrics])
-    val mockReadMetrics = mock(classOf[ShuffleReadMetrics])
-    when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics)
-    when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics)
-
-    val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher])
-
-    val mockDep = mock(classOf[ShuffleDependency[_, _, _]])
-    when(mockDep.keyOrdering).thenReturn(None)
-    when(mockDep.aggregator).thenReturn(None)
-    when(mockDep.serializer).thenReturn(Some(new Serializer {
-      override def newInstance(): SerializerInstance = new SerializerInstance {
-
-        override def deserializeStream(s: InputStream): DeserializationStream =
-          new DeserializationStream {
-            override def readObject[T: ClassManifest](): T = null.asInstanceOf[T]
-
-            override def close(): Unit = s.close()
-
-            private val values = {
-              for (i <- 0 to numKeyValuePairs * 2) yield i
-            }.iterator
-
-            private def getValueOrEOF(): Int = {
-              if (values.hasNext) {
-                values.next()
-              } else {
-                throw new EOFException("End of the file: mock deserializeStream")
-              }
-            }
-
-            // NOTE: the readKey and readValue methods are called by asKeyValueIterator()
-            // which is wrapped in a NextIterator
-            override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
-
-            override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
-          }
-
-        override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T =
-          null.asInstanceOf[T]
-
-        override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0)
-
-        override def serializeStream(s: OutputStream): SerializationStream =
-          null.asInstanceOf[SerializationStream]
-
-        override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T]
-      }
-    }))
-
-    val mockBlockManager = {
-      // Create a block manager that isn't configured for compression, just returns input stream
-      val blockManager = mock(classOf[BlockManager])
-      when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]()))
-        .thenAnswer(new Answer[InputStream] {
-          override def answer(invocation: InvocationOnMock): InputStream = {
-            val blockId = invocation.getArguments()(0).asInstanceOf[BlockId]
-            val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream]
-            inputStream
-          }
-        })
-      blockManager
-    }
-
-    val mockInputStream = mock(classOf[InputStream])
-    when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))
-      .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream)))
-
-    val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep)
-
-    val reader = new HashShuffleReader(shuffleHandle, 0, 1,
-      mockContext, mockBlockManager, mockShuffleFetcher)
-
-    val values = reader.read()
-    // Verify that we're reading the correct values
-    var numValuesRead = 0
-    for (((key: Int, value: Int), i) <- values.zipWithIndex) {
-      assert(key == i * 2)
-      assert(value == i * 2 + 1)
-      numValuesRead += 1
-    }
-    // Verify that we read the correct number of values
-    assert(numKeyValuePairs == numValuesRead)
-    // Verify that our input stream was closed
-    verify(mockInputStream, times(1)).close()
-    // Verify that we collected metrics for each key/value pair
-    verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1)
-  }
 }
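
The removed test's core technique is still instructive: it stubs the fetcher to return a mocked InputStream, drains the reader, then uses Mockito's verify to assert the stream was closed exactly once. A minimal, self-contained sketch of that verification pattern, assuming only Mockito on the classpath (names are hypothetical, not Spark code):

    // Hypothetical sketch: verify a resource is released exactly once.
    import java.io.InputStream
    import org.mockito.Mockito.{mock, times, verify, when}

    object CloseVerificationSketch {
      def main(args: Array[String]): Unit = {
        val stream = mock(classOf[InputStream])
        when(stream.read()).thenReturn(-1) // immediately signal end-of-stream

        // Code under test: read to exhaustion, then release the resource.
        try { while (stream.read() != -1) () } finally { stream.close() }

        verify(stream, times(1)).close() // fails if the resource leaked
      }
    }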
