Commit 09d02c3

Merge branch 'master' into profiler

Conflicts: docs/configuration.md

2 parents c23865c + 2aea0da

44 files changed: +924 additions, -496 deletions


core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 8 additions & 0 deletions
@@ -108,6 +108,14 @@ class SparkEnv (
       pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
+
+  private[spark]
+  def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+    synchronized {
+      val key = (pythonExec, envVars)
+      pythonWorkers.get(key).foreach(_.releaseWorker(worker))
+    }
+  }
 }
 
 object SparkEnv extends Logging {
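
SparkEnv gains releasePythonWorker alongside the existing createPythonWorker/destroyPythonWorker pair, keyed by (pythonExec, envVars), so a cleanly finished task can hand its worker back for reuse instead of closing it. A minimal, self-contained sketch of that keyed acquire/release/stop idea (Worker and WorkerPool are hypothetical stand-ins, not Spark classes):

    import scala.collection.mutable

    // Illustrative pool keyed by (exec, env), mirroring the create/release/stop
    // trio that SparkEnv forwards to its per-key PythonWorkerFactory.
    final case class Worker(id: Int) { def close(): Unit = println(s"closed worker $id") }

    class WorkerPool {
      private val idle = mutable.Map.empty[(String, Map[String, String]), mutable.Queue[Worker]]
      private var nextId = 0

      // Reuse an idle worker for this key if one exists, otherwise create one.
      def acquire(exec: String, env: Map[String, String]): Worker = synchronized {
        idle.get((exec, env)).filter(_.nonEmpty).map(_.dequeue()).getOrElse {
          nextId += 1; Worker(nextId)
        }
      }

      // Return a healthy worker so the next task with the same key can reuse it.
      def release(exec: String, env: Map[String, String], w: Worker): Unit = synchronized {
        idle.getOrElseUpdate((exec, env), mutable.Queue.empty) += w
      }

      // Destroy a worker outright (the path taken when a task did not finish cleanly).
      def stop(w: Worker): Unit = w.close()
    }

    object WorkerPoolDemo extends App {
      val pool = new WorkerPool
      val env = Map("SPARK_REUSE_WORKER" -> "1")
      val w1 = pool.acquire("python3", env)
      pool.release("python3", env, w1)
      assert(pool.acquire("python3", env) == w1) // the released worker is handed back out
    }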

core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 15 additions & 3 deletions
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.TaskCompletionListener
+import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
 
 
 /**
@@ -41,7 +41,7 @@ class TaskContext(
     val attemptId: Long,
     val runningLocally: Boolean = false,
     private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
-  extends Serializable {
+  extends Serializable with Logging {
 
   @deprecated("use partitionId", "0.8.1")
   def splitId = partitionId
@@ -103,8 +103,20 @@ class TaskContext(
   /** Marks the task as completed and triggers the listeners. */
   private[spark] def markTaskCompleted(): Unit = {
     completed = true
+    val errorMsgs = new ArrayBuffer[String](2)
     // Process complete callbacks in the reverse order of registration
-    onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) }
+    onCompleteCallbacks.reverse.foreach { listener =>
+      try {
+        listener.onTaskCompletion(this)
+      } catch {
+        case e: Throwable =>
+          errorMsgs += e.getMessage
+          logError("Error in TaskCompletionListener", e)
+      }
+    }
+    if (errorMsgs.nonEmpty) {
+      throw new TaskCompletionListenerException(errorMsgs)
+    }
   }
 
   /** Marks the task for interruption, i.e. cancellation. */
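
With this change a failing TaskCompletionListener no longer aborts the remaining listeners: each failure is logged, its message collected, and a single TaskCompletionListenerException is thrown once all listeners have run. A standalone sketch of that collect-then-rethrow pattern (Listener, AggregateException and runAll are illustrative names, not Spark API):

    import scala.collection.mutable.ArrayBuffer

    object ListenerAggregationDemo extends App {

      trait Listener { def onCompletion(): Unit }

      class AggregateException(msgs: Seq[String]) extends Exception(msgs.mkString("; "))

      // Run every listener even if some throw, then surface all failures at once.
      def runAll(listeners: Seq[Listener]): Unit = {
        val errors = new ArrayBuffer[String]()
        listeners.reverse.foreach { l =>   // reverse order of registration, as in TaskContext
          try l.onCompletion()
          catch { case e: Throwable => errors += e.getMessage }
        }
        if (errors.nonEmpty) throw new AggregateException(errors.toSeq)
      }

      val listeners = Seq(
        new Listener { def onCompletion(): Unit = println("cleanup A") },
        new Listener { def onCompletion(): Unit = throw new RuntimeException("boom") },
        new Listener { def onCompletion(): Unit = println("cleanup B") }
      )

      try runAll(listeners)
      catch { case e: AggregateException => println(s"all listeners ran; failures: ${e.getMessage}") }
      // prints "cleanup B", then "cleanup A", then "all listeners ran; failures: boom"
    }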

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 6 additions & 1 deletion
@@ -17,6 +17,7 @@
 
 package org.apache.spark.api.java
 
+import java.io.Closeable
 import java.util
 import java.util.{Map => JMap}
 
@@ -40,7 +41,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
  * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
  * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones.
  */
-class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
+class JavaSparkContext(val sc: SparkContext)
+  extends JavaSparkContextVarargsWorkaround with Closeable {
+
   /**
    * Create a JavaSparkContext that loads settings from system properties (for instance, when
    * launching with ./bin/spark-submit).
@@ -534,6 +537,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
     sc.stop()
   }
 
+  override def close(): Unit = stop()
+
   /**
    * Get Spark's home location from either a value set through the constructor,
    * or the spark.home Java property, or the SPARK_HOME environment variable
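
Because JavaSparkContext now implements java.io.Closeable (close() simply delegates to stop()), it can be driven by any resource-management idiom that understands Closeable, such as Java 7's try-with-resources. A small Scala sketch of the equivalent loan pattern, with withCloseable as an illustrative helper:

    import java.io.Closeable

    object CloseableDemo extends App {

      // Run a block against a Closeable and always close it afterwards,
      // mirroring what try-with-resources does on the Java side.
      def withCloseable[C <: Closeable, T](resource: C)(body: C => T): T =
        try body(resource) finally resource.close()

      // With this commit, a JavaSparkContext could be used the same way, e.g.
      //   withCloseable(new JavaSparkContext(conf)) { jsc => jsc.parallelize(data).count() }
      withCloseable(new java.io.StringReader("hi")) { reader =>
        println(reader.read().toChar)   // prints 'h'; the reader is closed afterwards
      }
    }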

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 46 additions & 12 deletions
@@ -23,6 +23,7 @@ import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
 import scala.collection.JavaConversions._
+import scala.collection.mutable
 import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
+  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
 
   override def getPartitions = parent.partitions
 
@@ -63,19 +65,26 @@ private[spark] class PythonRDD(
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
     envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
+    if (reuse_worker) {
+      envVars += ("SPARK_REUSE_WORKER" -> "1")
+    }
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
 
+    var complete_cleanly = false
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
-
-      // Cleanup the worker socket. This will also cause the Python worker to exit.
-      try {
-        worker.close()
-      } catch {
-        case e: Exception => logWarning("Failed to close worker socket", e)
+      if (reuse_worker && complete_cleanly) {
+        env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+      } else {
+        try {
+          worker.close()
+        } catch {
+          case e: Exception =>
+            logWarning("Failed to close worker socket", e)
+        }
       }
     }
 
@@ -133,6 +142,7 @@ private[spark] class PythonRDD(
             stream.readFully(update)
             accumulator += Collections.singletonList(update)
           }
+          complete_cleanly = true
           null
         }
       } catch {
@@ -195,29 +205,45 @@ private[spark] class PythonRDD(
           PythonRDD.writeUTF(include, dataOut)
         }
         // Broadcast variables
-        dataOut.writeInt(broadcastVars.length)
+        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+        val newBids = broadcastVars.map(_.id).toSet
+        // number of different broadcasts
+        val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+        dataOut.writeInt(cnt)
+        for (bid <- oldBids) {
+          if (!newBids.contains(bid)) {
+            // remove the broadcast from worker
+            dataOut.writeLong(- bid - 1)  // bid >= 0
+            oldBids.remove(bid)
+          }
+        }
         for (broadcast <- broadcastVars) {
-          dataOut.writeLong(broadcast.id)
-          dataOut.writeInt(broadcast.value.length)
-          dataOut.write(broadcast.value)
+          if (!oldBids.contains(broadcast.id)) {
+            // send new broadcast
+            dataOut.writeLong(broadcast.id)
+            dataOut.writeInt(broadcast.value.length)
+            dataOut.write(broadcast.value)
+            oldBids.add(broadcast.id)
+          }
         }
         dataOut.flush()
         // Serialized command:
         dataOut.writeInt(command.length)
         dataOut.write(command)
         // Data values
         PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
         dataOut.flush()
       } catch {
         case e: Exception if context.isCompleted || context.isInterrupted =>
           logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+          worker.shutdownOutput()
 
         case e: Exception =>
          // We must avoid throwing exceptions here, because the thread uncaught exception handler
          // will kill the whole executor (see org.apache.spark.executor.Executor).
          _exception = e
-      } finally {
-        Try(worker.shutdownOutput()) // kill Python worker process
+          worker.shutdownOutput()
       }
     }
   }
@@ -278,6 +304,14 @@ private object SpecialLengths {
 private[spark] object PythonRDD extends Logging {
   val UTF8 = Charset.forName("UTF-8")
 
+  // remember the broadcasts sent to each worker
+  private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+  private def getWorkerBroadcasts(worker: Socket) = {
+    synchronized {
+      workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
+    }
+  }
+
   /**
    * Adapter for calling SparkContext#runJob from Python.
   *
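
The writer thread now sends only a broadcast delta to a reused worker: the count written first is the number of broadcasts to retract plus the number to add, retired ids are encoded as -bid - 1 (so id 0 still maps to a distinct negative value), and only broadcasts the worker has not already seen are shipped in full. The END_OF_DATA_SECTION marker then lets the worker finish the stream cleanly so it can be returned to the pool. A small sketch of that delta computation, with made-up broadcast ids:

    // Sketch of the "delta" sent to a reused Python worker: retractions are
    // written as negative ids (-bid - 1), additions keep their positive ids.
    object BroadcastDeltaDemo extends App {
      val oldBids = scala.collection.mutable.Set(0L, 1L, 2L)   // already on the worker
      val newBids = Set(2L, 3L)                                // needed by this task

      val toRemove = oldBids.toSet.diff(newBids)               // Set(0, 1)
      val toAdd    = newBids.diff(oldBids.toSet)               // Set(3)
      val count    = toRemove.size + toAdd.size                // 3, written first

      val wire = toRemove.toSeq.sorted.map(bid => -bid - 1) ++ toAdd.toSeq.sorted
      println(s"count=$count ids=${wire.mkString(",")}")       // count=3 ids=-1,-2,3
    }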

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 74 additions & 11 deletions
@@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
-  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  val idleWorkers = new mutable.Queue[Socket]()
+  var lastActivity = 0L
+  new MonitorThread().start()
 
   var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
@@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
   def create(): Socket = {
     if (useDaemon) {
+      synchronized {
+        if (idleWorkers.size > 0) {
+          return idleWorkers.dequeue()
+        }
+      }
       createThroughDaemon()
     } else {
       createSimpleWorker()
@@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     }
   }
 
+  /**
+   * Monitor all the idle workers, kill them after timeout.
+   */
+  private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      while (true) {
+        synchronized {
+          if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
+            cleanupIdleWorkers()
+            lastActivity = System.currentTimeMillis()
+          }
+        }
+        Thread.sleep(10000)
+      }
+    }
+  }
+
+  private def cleanupIdleWorkers() {
+    while (idleWorkers.length > 0) {
+      val worker = idleWorkers.dequeue()
+      try {
+        // the worker will exit after closing the socket
+        worker.close()
+      } catch {
+        case e: Exception =>
+          logWarning("Failed to close worker socket", e)
+      }
+    }
+  }
+
   private def stopDaemon() {
     synchronized {
       if (useDaemon) {
+        cleanupIdleWorkers()
+
         // Request shutdown of existing daemon by sending SIGTERM
         if (daemon != null) {
           daemon.destroy()
@@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   }
 
   def stopWorker(worker: Socket) {
-    if (useDaemon) {
-      if (daemon != null) {
-        daemonWorkers.get(worker).foreach { pid =>
-          // tell daemon to kill worker by pid
-          val output = new DataOutputStream(daemon.getOutputStream)
-          output.writeInt(pid)
-          output.flush()
-          daemon.getOutputStream.flush()
+    synchronized {
+      if (useDaemon) {
+        if (daemon != null) {
+          daemonWorkers.get(worker).foreach { pid =>
+            // tell daemon to kill worker by pid
+            val output = new DataOutputStream(daemon.getOutputStream)
+            output.writeInt(pid)
+            output.flush()
+            daemon.getOutputStream.flush()
+          }
         }
+      } else {
+        simpleWorkers.get(worker).foreach(_.destroy())
       }
-    } else {
-      simpleWorkers.get(worker).foreach(_.destroy())
     }
     worker.close()
   }
+
+  def releaseWorker(worker: Socket) {
+    if (useDaemon) {
+      synchronized {
+        lastActivity = System.currentTimeMillis()
+        idleWorkers.enqueue(worker)
+      }
+    } else {
+      // Cleanup the worker socket. This will also cause the Python worker to exit.
+      try {
+        worker.close()
+      } catch {
+        case e: Exception =>
+          logWarning("Failed to close worker socket", e)
+      }
+    }
+  }
 }
 
 private object PythonWorkerFactory {
   val PROCESS_WAIT_TIMEOUT_MS = 10000
+  val IDLE_WORKER_TIMEOUT_MS = 60000  // kill idle workers after 1 minute
 }
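
Released daemon workers are parked in the idleWorkers queue, create() hands them out before spawning new ones, and a daemon MonitorThread closes everything in the queue once the factory has been idle for IDLE_WORKER_TIMEOUT_MS (one minute). A generic sketch of that idle-pool-with-timeout idea (IdlePool and its names are illustrative, not Spark code):

    import scala.collection.mutable

    // Released items queue up for reuse; a daemon thread evicts everything once
    // the pool has been quiet for longer than the timeout.
    class IdlePool[A](timeoutMs: Long)(close: A => Unit) {
      private val idle = mutable.Queue.empty[A]
      private var lastActivity = System.currentTimeMillis()

      def release(a: A): Unit = synchronized {
        lastActivity = System.currentTimeMillis()
        idle.enqueue(a)
      }

      def acquire(): Option[A] = synchronized {
        if (idle.nonEmpty) Some(idle.dequeue()) else None
      }

      private def evictAll(): Unit = synchronized {
        while (idle.nonEmpty) close(idle.dequeue())
      }

      // Daemon monitor, like PythonWorkerFactory.MonitorThread: wake up
      // periodically and clear the pool if nothing has been released recently.
      private val monitor = new Thread("idle-pool-monitor") {
        setDaemon(true)
        override def run(): Unit = while (true) {
          IdlePool.this.synchronized {
            if (lastActivity + timeoutMs < System.currentTimeMillis()) {
              evictAll()
              lastActivity = System.currentTimeMillis()
            }
          }
          Thread.sleep(1000)
        }
      }
      monitor.start()
    }

A caller would try acquire() first and fall back to creating a fresh worker only when the pool is empty, which matches the order used in create() above.
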
core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala (new file)

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * Exception thrown when there is an exception in
+ * executing the callback in TaskCompletionListener.
+ */
+private[spark]
+class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception {
+
+  override def getMessage: String = {
+    if (errorMessages.size == 1) {
+      errorMessages.head
+    } else {
+      errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+    }
+  }
+}
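
getMessage returns a single collected message unchanged and numbers multiple messages on separate lines. A quick illustration of both cases, placed in org.apache.spark.util only so the private[spark] class is visible (the demo object and messages are made up):

    package org.apache.spark.util

    object TaskCompletionListenerExceptionDemo extends App {
      println(new TaskCompletionListenerException(Seq("disk quota exceeded")).getMessage)
      // disk quota exceeded

      println(new TaskCompletionListenerException(Seq("disk quota exceeded", "socket closed")).getMessage)
      // Exception 0: disk quota exceeded
      // Exception 1: socket closed
    }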
