Skip to content

Commit 7b0f04b

Browse files
committed
Fix ShuffleMemoryManagerSuite
1 parent f57f3f2 commit 7b0f04b

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,39 @@
1717

1818
package org.apache.spark.shuffle
1919

20+
import java.util.concurrent.CountDownLatch
21+
import java.util.concurrent.atomic.AtomicInteger
22+
23+
import org.mockito.Mockito._
2024
import org.scalatest.concurrent.Timeouts
2125
import org.scalatest.time.SpanSugar._
22-
import java.util.concurrent.atomic.AtomicBoolean
23-
import java.util.concurrent.CountDownLatch
2426

25-
import org.apache.spark.SparkFunSuite
27+
import org.apache.spark.{SparkFunSuite, TaskContext}
2628

2729
class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
30+
31+
val nextTaskAttemptId = new AtomicInteger()
32+
2833
/** Launch a thread with the given body block and return it. */
2934
private def startThread(name: String)(body: => Unit): Thread = {
3035
val thread = new Thread("ShuffleMemorySuite " + name) {
3136
override def run() {
32-
body
37+
try {
38+
val taskAttemptId = nextTaskAttemptId.getAndIncrement
39+
val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS)
40+
when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId)
41+
TaskContext.setTaskContext(mockTaskContext)
42+
body
43+
} finally {
44+
TaskContext.unset()
45+
}
3346
}
3447
}
3548
thread.start()
3649
thread
3750
}
3851

39-
test("single thread requesting memory") {
52+
test("single task requesting memory") {
4053
val manager = new ShuffleMemoryManager(1000L)
4154

4255
assert(manager.tryToAcquire(100L) === 100L)
@@ -107,8 +120,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
107120
}
108121

109122

110-
test("threads cannot grow past 1 / N") {
111-
// Two threads request 250 bytes first, wait for each other to get it, and then request
123+
test("tasks cannot grow past 1 / N") {
124+
// Two tasks request 250 bytes first, wait for each other to get it, and then request
112125
// 500 more; we should only grant 250 bytes to each of them on this second request
113126

114127
val manager = new ShuffleMemoryManager(1000L)
@@ -158,7 +171,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
158171
assert(state.t2Result2 === 250L)
159172
}
160173

161-
test("threads can block to get at least 1 / 2N memory") {
174+
test("tasks can block to get at least 1 / 2N memory") {
162175
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
163176
// for a bit and releases 250 bytes, which should then be granted to t2. Further requests
164177
// by t2 will return false right away because it now has 1 / 2N of the memory.
@@ -224,7 +237,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
224237
}
225238
}
226239

227-
test("releaseMemoryForThisThread") {
240+
test("releaseMemoryForThisTask") {
228241
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
229242
// for a bit and releases all its memory. t2 should now be able to grab all the memory.
230243

@@ -251,7 +264,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
251264
}
252265
}
253266
// Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
254-
// sure the other thread blocks for some time otherwise
267+
// sure the other task blocks for some time otherwise
255268
Thread.sleep(300)
256269
manager.releaseMemoryForThisTask()
257270
}
@@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
282295
t2.join()
283296
}
284297

285-
// Both threads should've been able to acquire their memory; the second one will have waited
298+
// Both tasks should've been able to acquire their memory; the second one will have waited
286299
// until the first one acquired 1000 bytes and then released all of it
287300
state.synchronized {
288301
assert(state.t1Result === 1000L, "t1 could not allocate memory")
@@ -293,7 +306,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
293306
}
294307
}
295308

296-
test("threads should not be granted a negative size") {
309+
test("tasks should not be granted a negative size") {
297310
val manager = new ShuffleMemoryManager(1000L)
298311
manager.tryToAcquire(700L)
299312

0 commit comments

Comments
 (0)