@@ -96,6 +96,30 @@ class DAGScheduler(
   // Stages that must be resubmitted due to fetch failures
   private[scheduler] val failedStages = new HashSet[Stage]

+  // The maximum number of times to retry a stage before aborting
+  val maxStageFailures = 5
+
+  // To avoid cyclical stage failures (see SPARK-5945) we limit the number of times that a stage
+  // may be retried. However, it only makes sense to limit the number of times that a stage fails
+  // if it's failing for the same reason every time. Therefore, track why a stage fails as well as
+  // how many times it has failed.
+  case class StageFailure(failureReason: String) {
+    var count = 1
+    def fail() = { count += 1 }
+    def shouldAbort(): Boolean = { count >= maxStageFailures }
+
+    override def equals(other: Any): Boolean =
+      other match {
+        case that: StageFailure => that.failureReason.equals(this.failureReason)
+        case _ => false
+      }
+
+    override def hashCode: Int = failureReason.hashCode()
+  }
+
+  // Map to track failure reasons for a given stage (keyed by Stage)
+  private[scheduler] val stageFailureReasons = new HashMap[Stage, HashSet[StageFailure]]
+
   private[scheduler] val activeJobs = new HashSet[ActiveJob]

   /**
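Reviewer note: because `StageFailure` overrides `equals` and `hashCode` to use only `failureReason`, a `HashSet[StageFailure]` holds at most one entry per distinct reason; a repeated failure for the same reason finds the existing entry and bumps its mutable `count` rather than adding a new element. A minimal standalone sketch of that behavior (script-style, names local to the sketch, not part of the patch):

```scala
import scala.collection.mutable.HashSet

val maxStageFailures = 5

case class StageFailure(failureReason: String) {
  var count = 1
  def fail(): Unit = { count += 1 }
  def shouldAbort(): Boolean = count >= maxStageFailures

  // Equality and hashing deliberately ignore `count` so that the set
  // deduplicates failures by reason alone.
  override def equals(other: Any): Boolean = other match {
    case that: StageFailure => that.failureReason == this.failureReason
    case _ => false
  }
  override def hashCode: Int = failureReason.hashCode()
}

val failures = new HashSet[StageFailure]
failures.add(StageFailure("FetchFailed(hostA, shuffle 0)"))
failures.add(StageFailure("FetchFailed(hostA, shuffle 0)")) // equal by reason: not added again
println(failures.size) // 1
```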
@@ -460,6 +484,10 @@ class DAGScheduler(
           logDebug("Removing stage %d from failed set.".format(stageId))
           failedStages -= stage
         }
+        if (stageFailureReasons.contains(stage)) {
+          logDebug("Removing stage %d from failure reasons set.".format(stageId))
+          stageFailureReasons -= stage
+        }
       }
       // data structures based on StageId
       stageIdToStage -= stageId
@@ -940,6 +968,29 @@ class DAGScheduler(
     }
   }

+  /**
+   * Check whether we should abort the failedStage due to multiple failures for the same reason.
+   * This method updates the running count of failures for a particular stage and returns the
+   * failure reason if the number of failures for any single reason has reached the maximum
+   * allowable number of failures.
+   * @return An Option that contains the failure reason that caused the abort, or None otherwise
+   */
+  def shouldAbortStage(failedStage: Stage, failureReason: String): Option[String] = {
+    if (!stageFailureReasons.contains(failedStage))
+      stageFailureReasons.put(failedStage, new HashSet[StageFailure]())
+
+    val failures = stageFailureReasons.get(failedStage).get
+    val failure = StageFailure(failureReason)
+    failures.find(s => s.equals(failure)) match {
+      case Some(f) => f.fail()
+      case None => failures.add(failure)
+    }
+    failures.find(_.shouldAbort()) match {
+      case Some(f) => Some(f.failureReason)
+      case None => None
+    }
+  }
+
   /**
    * Responds to a task finishing. This is called inside the event loop so it assumes that it can
    * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
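Reviewer note: the first call for a given reason inserts an entry with `count = 1`, and each subsequent call for the same reason increments it, so with `maxStageFailures = 5` the fifth identical failure returns `Some(reason)`, while distinct reasons are counted independently. A rough sketch of the expected call pattern, continuing the script above and using a plain `String` key as a hypothetical stand-in for `Stage`:

```scala
import scala.collection.mutable.{HashMap, HashSet}

// "stage-3" stands in for a real Stage object; StageFailure and
// maxStageFailures come from the sketch above.
val stageFailureReasons = new HashMap[String, HashSet[StageFailure]]

def shouldAbortStage(stage: String, reason: String): Option[String] = {
  val failures = stageFailureReasons.getOrElseUpdate(stage, new HashSet[StageFailure])
  failures.find(_ == StageFailure(reason)) match {
    case Some(f) => f.fail()               // same reason seen before: bump its count
    case None    => failures.add(StageFailure(reason)) // first time: count starts at 1
  }
  failures.find(_.shouldAbort()).map(_.failureReason)
}

val results = (1 to 5).map(_ => shouldAbortStage("stage-3", "FetchFailed(hostA, shuffle 0)"))
println(results)
// Vector(None, None, None, None, Some(FetchFailed(hostA, shuffle 0)))

// Distinct reasons for the same stage are tracked as separate entries:
println(shouldAbortStage("stage-7", "FetchFailed(hostA, shuffle 1)")) // None (count 1)
println(shouldAbortStage("stage-7", "FetchFailed(hostB, shuffle 1)")) // None (count 1, new entry)
```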
@@ -1083,8 +1134,13 @@ class DAGScheduler(
             markStageAsFinished(failedStage, Some(failureMessage))
           }

+          val shouldAbort = shouldAbortStage(failedStage, failureMessage)
           if (disallowStageRetryForTest) {
             abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+          } else if (shouldAbort.isDefined) {
+            abortStage(failedStage, s"Fetch failure - aborting stage. Stage ${failedStage.name} " +
+              s"has failed the maximum allowable number of times: ${maxStageFailures}. " +
+              s"Failure reason: ${shouldAbort.get}")
           } else if (failedStages.isEmpty) {
             // Don't schedule an event to resubmit failed stages if failed isn't empty, because
             // in that case the event will already have been scheduled.
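Reviewer note: `shouldAbortStage` is called unconditionally before any branch is taken, so the failure is recorded even when the testing config wins or the stage is resubmitted; the abort branch only fires once some single reason has hit the cap. A small sketch of that branch ordering, reusing the simplified helper above (the flag value here is hypothetical):

```scala
// The count is updated as a side effect even when another branch wins.
val disallowStageRetryForTest = true
val abort = shouldAbortStage("stage-9", "FetchFailed(hostC, shuffle 2)") // recorded regardless
if (disallowStageRetryForTest) {
  println("abort: testing config")
} else if (abort.isDefined) {
  println(s"abort: ${abort.get}")
} else {
  println("resubmit stage")
}
```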