17
17
18
18
package org .apache .spark .scheduler
19
19
20
- import java .io .{NotSerializableException , PrintWriter , StringWriter }
20
+ import java .io .{NotSerializableException }
21
21
import java .util .Properties
22
22
import java .util .concurrent .atomic .AtomicInteger
23
23
@@ -35,6 +35,7 @@ import akka.pattern.ask
35
35
import akka .util .Timeout
36
36
37
37
import org .apache .spark ._
38
+ import org .apache .spark .broadcast .Broadcast
38
39
import org .apache .spark .executor .TaskMetrics
39
40
import org .apache .spark .partial .{ApproximateActionListener , ApproximateEvaluator , PartialResult }
40
41
import org .apache .spark .rdd .RDD
@@ -694,7 +695,21 @@ class DAGScheduler(
694
695
// Get our pending tasks and remember them in our pendingTasks entry
695
696
stage.pendingTasks.clear()
696
697
var tasks = ArrayBuffer [Task [_]]()
697
- val broadcastRddBinary = stage.rdd.createBroadcastBinary()
698
+
699
+ var broadcastRddBinary : Broadcast [Array [Byte ]] = null
700
+ try {
701
+ broadcastRddBinary = stage.rdd.createBroadcastBinary()
702
+ } catch {
703
+ case e : NotSerializableException =>
704
+ abortStage(stage, " Task not serializable: " + e.toString)
705
+ runningStages -= stage
706
+ return
707
+ case NonFatal (e) =>
708
+ abortStage(stage, s " Task serialization failed: $e\n ${e.getStackTraceString}" )
709
+ runningStages -= stage
710
+ return
711
+ }
712
+
698
713
if (stage.isShuffleMap) {
699
714
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil ) {
700
715
val locs = getPreferredLocs(stage.rdd, p)
0 commit comments