Skip to content

Commit cd0629f

Browse files
committed
code refactoring and adding a test
1 parent b073ee6 commit cd0629f

File tree

4 files changed

+66
-19
lines changed

4 files changed

+66
-19
lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -882,14 +882,14 @@ class SparkContext(
882882
metadataCleaner.cancel()
883883
cleaner.foreach(_.stop())
884884
dagSchedulerCopy.stop()
885-
listenerBus.stop()
886-
eventLogger.foreach(_.stop())
887885
taskScheduler = null
888886
// TODO: Cache.stop()?
889887
env.stop()
890888
SparkEnv.set(null)
891889
ShuffleMapTask.clearCache()
892890
ResultTask.clearCache()
891+
listenerBus.stop()
892+
eventLogger.foreach(_.stop())
893893
logInfo("Successfully stopped SparkContext")
894894
} else {
895895
logInfo("SparkContext already stopped")

core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,19 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
3636
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY)
3737
private var queueFullErrorMessageLogged = false
3838
private var started = false
39-
private var sparkListenerBus: Option[Thread] = _
39+
private val listenerThread = new Thread("SparkListenerBus") {
40+
setDaemon(true)
41+
override def run() {
42+
while (true) {
43+
val event = eventQueue.take
44+
if (event == SparkListenerShutdown) {
45+
// Get out of the while loop and shutdown the daemon thread
46+
return
47+
}
48+
postToAll(event)
49+
}
50+
}
51+
}
4052

4153
/**
4254
* Start sending events to attached listeners.
@@ -49,21 +61,8 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
4961
if (started) {
5062
throw new IllegalStateException("Listener bus already started!")
5163
}
64+
listenerThread.start()
5265
started = true
53-
sparkListenerBus = Some(new Thread("SparkListenerBus") {
54-
setDaemon(true)
55-
override def run() {
56-
while (true) {
57-
val event = eventQueue.take
58-
if (event == SparkListenerShutdown) {
59-
// Get out of the while loop and shutdown the daemon thread
60-
return
61-
}
62-
postToAll(event)
63-
}
64-
}
65-
})
66-
sparkListenerBus.foreach(_.start())
6766
}
6867

6968
def post(event: SparkListenerEvent) {
@@ -99,6 +98,6 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
9998
throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!")
10099
}
101100
post(SparkListenerShutdown)
102-
sparkListenerBus.foreach(_.join())
101+
listenerThread.join()
103102
}
104103
}

core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.scheduler
1919

20+
import java.util.concurrent.Semaphore
21+
2022
import scala.collection.mutable
2123

2224
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
@@ -72,6 +74,53 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
7274
}
7375
}
7476

77+
test("bus.stop() waits for the event queue to completely drain") {
78+
@volatile var drained = false
79+
80+
class BlockingListener(cond: AnyRef) extends SparkListener {
81+
override def onJobEnd(jobEnd: SparkListenerJobEnd) = {
82+
cond.synchronized { cond.wait() }
83+
drained = true
84+
}
85+
}
86+
87+
val bus = new LiveListenerBus
88+
val blockingListener = new BlockingListener(bus)
89+
val sem = new Semaphore(0)
90+
91+
bus.addListener(blockingListener)
92+
bus.post(SparkListenerJobEnd(0, JobSucceeded))
93+
bus.start()
94+
// the queue should not drain immediately
95+
assert(!drained)
96+
97+
new Thread("ListenerBusStopper") {
98+
override def run() {
99+
// stop() would block until notify() is called below
100+
bus.stop()
101+
sem.release()
102+
}
103+
}.start()
104+
105+
val startTime = System.currentTimeMillis()
106+
val waitTime = 100
107+
var done = false
108+
while (!done) {
109+
if (System.currentTimeMillis() > startTime + waitTime) {
110+
bus.synchronized {
111+
bus.notify()
112+
}
113+
done = true
114+
} else {
115+
Thread.sleep(10)
116+
// bus.stop() should wait until the event queue is drained
117+
assert(!drained)
118+
}
119+
}
120+
sem.acquire()
121+
assert(drained)
122+
}
123+
75124
test("basic creation of StageInfo") {
76125
val listener = new SaveStageAndTaskInfo
77126
sc.addSparkListener(listener)

examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,5 @@ object SparkHdfsLR {
7474

7575
println("Final w: " + w)
7676
sc.stop()
77-
System.exit(0)
7877
}
7978
}

0 commit comments

Comments (0)