17
17
18
18
package org .apache .spark .shuffle
19
19
20
+ import java .util .concurrent .CountDownLatch
21
+ import java .util .concurrent .atomic .AtomicInteger
22
+
23
+ import org .mockito .Mockito ._
20
24
import org .scalatest .concurrent .Timeouts
21
25
import org .scalatest .time .SpanSugar ._
22
- import java .util .concurrent .atomic .AtomicBoolean
23
- import java .util .concurrent .CountDownLatch
24
26
25
- import org .apache .spark .SparkFunSuite
27
+ import org .apache .spark .{ SparkFunSuite , TaskContext }
26
28
27
29
class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
30
+
31
+ val nextTaskAttemptId = new AtomicInteger ()
32
+
28
33
/** Launch a thread with the given body block and return it. */
29
34
private def startThread (name : String )(body : => Unit ): Thread = {
30
35
val thread = new Thread (" ShuffleMemorySuite " + name) {
31
36
override def run () {
32
- body
37
+ try {
38
+ val taskAttemptId = nextTaskAttemptId.getAndIncrement
39
+ val mockTaskContext = mock(classOf [TaskContext ], RETURNS_SMART_NULLS )
40
+ when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId)
41
+ TaskContext .setTaskContext(mockTaskContext)
42
+ body
43
+ } finally {
44
+ TaskContext .unset()
45
+ }
33
46
}
34
47
}
35
48
thread.start()
36
49
thread
37
50
}
38
51
39
- test(" single thread requesting memory" ) {
52
+ test(" single task requesting memory" ) {
40
53
val manager = new ShuffleMemoryManager (1000L )
41
54
42
55
assert(manager.tryToAcquire(100L ) === 100L )
@@ -107,8 +120,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
107
120
}
108
121
109
122
110
- test(" threads cannot grow past 1 / N" ) {
111
- // Two threads request 250 bytes first, wait for each other to get it, and then request
123
+ test(" tasks cannot grow past 1 / N" ) {
124
+ // Two tasks request 250 bytes first, wait for each other to get it, and then request
112
125
// 500 more; we should only grant 250 bytes to each of them on this second request
113
126
114
127
val manager = new ShuffleMemoryManager (1000L )
@@ -158,7 +171,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
158
171
assert(state.t2Result2 === 250L )
159
172
}
160
173
161
- test(" threads can block to get at least 1 / 2N memory" ) {
174
+ test(" tasks can block to get at least 1 / 2N memory" ) {
162
175
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
163
176
// for a bit and releases 250 bytes, which should then be granted to t2. Further requests
164
177
// by t2 will return false right away because it now has 1 / 2N of the memory.
@@ -224,7 +237,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
224
237
}
225
238
}
226
239
227
- test(" releaseMemoryForThisThread " ) {
240
+ test(" releaseMemoryForThisTask " ) {
228
241
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
229
242
// for a bit and releases all its memory. t2 should now be able to grab all the memory.
230
243
@@ -251,7 +264,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
251
264
}
252
265
}
253
266
// Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
254
- // sure the other thread blocks for some time otherwise
267
+ // sure the other task blocks for some time otherwise
255
268
Thread .sleep(300 )
256
269
manager.releaseMemoryForThisTask()
257
270
}
@@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
282
295
t2.join()
283
296
}
284
297
285
- // Both threads should've been able to acquire their memory; the second one will have waited
298
+ // Both tasks should've been able to acquire their memory; the second one will have waited
286
299
// until the first one acquired 1000 bytes and then released all of it
287
300
state.synchronized {
288
301
assert(state.t1Result === 1000L , " t1 could not allocate memory" )
@@ -293,7 +306,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
293
306
}
294
307
}
295
308
296
- test(" threads should not be granted a negative size" ) {
309
+ test(" tasks should not be granted a negative size" ) {
297
310
val manager = new ShuffleMemoryManager (1000L )
298
311
manager.tryToAcquire(700L )
299
312
0 commit comments