|
17 | 17 |
|
18 | 18 | package org.apache.spark.shuffle.hash
|
19 | 19 |
|
20 |
| -import java.io._ |
21 |
| -import java.nio.ByteBuffer |
| 20 | +import java.io.{File, FileWriter} |
22 | 21 |
|
23 | 22 | import scala.language.reflectiveCalls
|
24 | 23 |
|
25 |
| -import org.mockito.Matchers.any |
26 |
| -import org.mockito.Mockito._ |
27 |
| -import org.mockito.invocation.InvocationOnMock |
28 |
| -import org.mockito.stubbing.Answer |
29 |
| - |
30 |
| -import org.apache.spark._ |
31 |
| -import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics} |
| 24 | +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} |
| 25 | +import org.apache.spark.executor.ShuffleWriteMetrics |
32 | 26 | import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
|
33 |
| -import org.apache.spark.serializer._ |
34 |
| -import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver} |
35 |
| -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment} |
| 27 | +import org.apache.spark.serializer.JavaSerializer |
| 28 | +import org.apache.spark.shuffle.FileShuffleBlockResolver |
| 29 | +import org.apache.spark.storage.{ShuffleBlockId, FileSegment} |
36 | 30 |
|
37 | 31 | class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
|
38 | 32 | private val testConf = new SparkConf(false)
|
@@ -113,100 +107,4 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
|
113 | 107 | for (i <- 0 until numBytes) writer.write(i)
|
114 | 108 | writer.close()
|
115 | 109 | }
|
116 |
| - |
117 |
| - test("HashShuffleReader.read() releases resources and tracks metrics") { |
118 |
| - val shuffleId = 1 |
119 |
| - val numMaps = 2 |
120 |
| - val numKeyValuePairs = 10 |
121 |
| - |
122 |
| - val mockContext = mock(classOf[TaskContext]) |
123 |
| - |
124 |
| - val mockTaskMetrics = mock(classOf[TaskMetrics]) |
125 |
| - val mockReadMetrics = mock(classOf[ShuffleReadMetrics]) |
126 |
| - when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics) |
127 |
| - when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics) |
128 |
| - |
129 |
| - val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher]) |
130 |
| - |
131 |
| - val mockDep = mock(classOf[ShuffleDependency[_, _, _]]) |
132 |
| - when(mockDep.keyOrdering).thenReturn(None) |
133 |
| - when(mockDep.aggregator).thenReturn(None) |
134 |
| - when(mockDep.serializer).thenReturn(Some(new Serializer { |
135 |
| - override def newInstance(): SerializerInstance = new SerializerInstance { |
136 |
| - |
137 |
| - override def deserializeStream(s: InputStream): DeserializationStream = |
138 |
| - new DeserializationStream { |
139 |
| - override def readObject[T: ClassManifest](): T = null.asInstanceOf[T] |
140 |
| - |
141 |
| - override def close(): Unit = s.close() |
142 |
| - |
143 |
| - private val values = { |
144 |
| - for (i <- 0 to numKeyValuePairs * 2) yield i |
145 |
| - }.iterator |
146 |
| - |
147 |
| - private def getValueOrEOF(): Int = { |
148 |
| - if (values.hasNext) { |
149 |
| - values.next() |
150 |
| - } else { |
151 |
| - throw new EOFException("End of the file: mock deserializeStream") |
152 |
| - } |
153 |
| - } |
154 |
| - |
155 |
| - // NOTE: the readKey and readValue methods are called by asKeyValueIterator() |
156 |
| - // which is wrapped in a NextIterator |
157 |
| - override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] |
158 |
| - |
159 |
| - override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] |
160 |
| - } |
161 |
| - |
162 |
| - override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T = |
163 |
| - null.asInstanceOf[T] |
164 |
| - |
165 |
| - override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0) |
166 |
| - |
167 |
| - override def serializeStream(s: OutputStream): SerializationStream = |
168 |
| - null.asInstanceOf[SerializationStream] |
169 |
| - |
170 |
| - override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T] |
171 |
| - } |
172 |
| - })) |
173 |
| - |
174 |
| - val mockBlockManager = { |
175 |
| - // Create a block manager that isn't configured for compression, just returns input stream |
176 |
| - val blockManager = mock(classOf[BlockManager]) |
177 |
| - when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]())) |
178 |
| - .thenAnswer(new Answer[InputStream] { |
179 |
| - override def answer(invocation: InvocationOnMock): InputStream = { |
180 |
| - val blockId = invocation.getArguments()(0).asInstanceOf[BlockId] |
181 |
| - val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream] |
182 |
| - inputStream |
183 |
| - } |
184 |
| - }) |
185 |
| - blockManager |
186 |
| - } |
187 |
| - |
188 |
| - val mockInputStream = mock(classOf[InputStream]) |
189 |
| - when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]())) |
190 |
| - .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream))) |
191 |
| - |
192 |
| - val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep) |
193 |
| - |
194 |
| - val reader = new HashShuffleReader(shuffleHandle, 0, 1, |
195 |
| - mockContext, mockBlockManager, mockShuffleFetcher) |
196 |
| - |
197 |
| - val values = reader.read() |
198 |
| - // Verify that we're reading the correct values |
199 |
| - var numValuesRead = 0 |
200 |
| - for (((key: Int, value: Int), i) <- values.zipWithIndex) { |
201 |
| - assert(key == i * 2) |
202 |
| - assert(value == i * 2 + 1) |
203 |
| - numValuesRead += 1 |
204 |
| - } |
205 |
| - // Verify that we read the correct number of values |
206 |
| - assert(numKeyValuePairs == numValuesRead) |
207 |
| - // Verify that our input stream was closed |
208 |
| - verify(mockInputStream, times(1)).close() |
209 |
| - // Verify that we collected metrics for each key/value pair |
210 |
| - verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1) |
211 |
| - } |
212 | 110 | }
|
0 commit comments