@@ -19,7 +19,12 @@ package org.apache.spark.scheduler
19
19
20
20
import java .nio .ByteBuffer
21
21
22
- import org .scalatest .{BeforeAndAfter , BeforeAndAfterAll , FunSuite }
22
+ import scala .concurrent .duration ._
23
+ import scala .language .postfixOps
24
+ import scala .util .control .NonFatal
25
+
26
+ import org .scalatest .{BeforeAndAfter , FunSuite }
27
+ import org .scalatest .concurrent .Eventually ._
23
28
24
29
import org .apache .spark .{LocalSparkContext , SparkConf , SparkContext , SparkEnv }
25
30
import org .apache .spark .storage .TaskResultBlockId
@@ -34,6 +39,8 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule
34
39
extends TaskResultGetter (sparkEnv, scheduler) {
35
40
var removedResult = false
36
41
42
+ @ volatile var removeBlockSuccessfully = false
43
+
37
44
override def enqueueSuccessfulTask (
38
45
taskSetManager : TaskSetManager , tid : Long , serializedData : ByteBuffer ) {
39
46
if (! removedResult) {
@@ -42,6 +49,15 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule
42
49
serializer.get().deserialize[TaskResult [_]](serializedData) match {
43
50
case IndirectTaskResult (blockId, size) =>
44
51
sparkEnv.blockManager.master.removeBlock(blockId)
52
+ // removeBlock is asynchronous. Need to wait it's removed successfully
53
+ try {
54
+ eventually(timeout(3 seconds), interval(200 milliseconds)) {
55
+ assert(! sparkEnv.blockManager.master.contains(blockId))
56
+ }
57
+ removeBlockSuccessfully = true
58
+ } catch {
59
+ case NonFatal (e) => removeBlockSuccessfully = false
60
+ }
45
61
case directResult : DirectTaskResult [_] =>
46
62
taskSetManager.abort(" Internal error: expect only indirect results" )
47
63
}
@@ -92,10 +108,12 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSpark
92
108
assert(false , " Expect local cluster to use TaskSchedulerImpl" )
93
109
throw new ClassCastException
94
110
}
95
- scheduler.taskResultGetter = new ResultDeletingTaskResultGetter (sc.env, scheduler)
111
+ val resultGetter = new ResultDeletingTaskResultGetter (sc.env, scheduler)
112
+ scheduler.taskResultGetter = resultGetter
96
113
val akkaFrameSize =
97
114
sc.env.actorSystem.settings.config.getBytes(" akka.remote.netty.tcp.maximum-frame-size" ).toInt
98
115
val result = sc.parallelize(Seq (1 ), 1 ).map(x => 1 .to(akkaFrameSize).toArray).reduce((x, y) => x)
116
+ assert(resultGetter.removeBlockSuccessfully)
99
117
assert(result === 1 .to(akkaFrameSize).toArray)
100
118
101
119
// Make sure two tasks were run (one failed one, and a second retried one).
0 commit comments