
Commit 35e23ff

Davies Liu authored and Josh Rosen committed
[SPARK-4172] [PySpark] Progress API in Python
This patch brings the pull-based progress API into Python, along with an example in Python.

Author: Davies Liu <[email protected]>

Closes #3027 from davies/progress_api and squashes the following commits:

b1ba984 [Davies Liu] fix style
d3b9253 [Davies Liu] add tests, mute the exception after stop
4297327 [Davies Liu] Merge branch 'master' of github.com:apache/spark into progress_api
969fa9d [Davies Liu] Merge branch 'master' of github.com:apache/spark into progress_api
25590c9 [Davies Liu] update with Java API
360de2d [Davies Liu] Merge branch 'master' of github.com:apache/spark into progress_api
c0f1021 [Davies Liu] Merge branch 'master' of github.com:apache/spark into progress_api
023afb3 [Davies Liu] add Python API and example for progress API

(cherry picked from commit 445a755)
Signed-off-by: Josh Rosen <[email protected]>
1 parent e65dc1f commit 35e23ff

6 files changed: +232 -24 lines changed

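Before the file-by-file diff, a rough illustration of how the new Python progress API is used: SparkContext.statusTracker() returns a StatusTracker that wraps the JVM-side tracker, and its getters may return None or empty lists because of the weak consistency semantics documented in python/pyspark/status.py. The sketch below is not part of the commit — the app name and job-group id are made up — and the full polling loop is the new example file added by this patch.

# Illustrative sketch only -- not part of this commit.  Assumes a PySpark
# build that includes this patch; the app name and job group id are made up.
from pyspark import SparkContext

sc = SparkContext(appName="StatusTrackerSketch")
sc.setJobGroup("sketch", "inspect a finished job")
sc.parallelize(range(100), 4).count()                # run any job in the group

tracker = sc.statusTracker()
print "active jobs:", tracker.getActiveJobsIds()     # [] once the job is done
print "active stages:", tracker.getActiveStageIds()
for job_id in tracker.getJobIdsForGroup("sketch"):
    job = tracker.getJobInfo(job_id)                 # may be None (weak consistency)
    if job:
        print "job", job_id, "status:", job.status
        for stage_id in job.stageIds:
            stage = tracker.getStageInfo(stage_id)   # may also be None
            if stage:
                print "  stage %d: %d/%d tasks complete" % (
                    stage_id, stage.numCompletedTasks, stage.numTasks)
sc.stop()

Both getJobInfo and getStageInfo results are checked for None here, which is exactly the behaviour the StatusTracker docstring warns about.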

core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

Lines changed: 23 additions & 17 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
+import java.util.concurrent.RejectedExecutionException
 
 import scala.language.existentials
 import scala.util.control.NonFatal
@@ -95,25 +96,30 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
   def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
     serializedData: ByteBuffer) {
     var reason : TaskEndReason = UnknownReason
-    getTaskResultExecutor.execute(new Runnable {
-      override def run(): Unit = Utils.logUncaughtExceptions {
-        try {
-          if (serializedData != null && serializedData.limit() > 0) {
-            reason = serializer.get().deserialize[TaskEndReason](
-              serializedData, Utils.getSparkClassLoader)
+    try {
+      getTaskResultExecutor.execute(new Runnable {
+        override def run(): Unit = Utils.logUncaughtExceptions {
+          try {
+            if (serializedData != null && serializedData.limit() > 0) {
+              reason = serializer.get().deserialize[TaskEndReason](
+                serializedData, Utils.getSparkClassLoader)
+            }
+          } catch {
+            case cnd: ClassNotFoundException =>
+              // Log an error but keep going here -- the task failed, so not catastrophic
+              // if we can't deserialize the reason.
+              val loader = Utils.getContextOrSparkClassLoader
+              logError(
+                "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
+            case ex: Exception => {}
           }
-        } catch {
-          case cnd: ClassNotFoundException =>
-            // Log an error but keep going here -- the task failed, so not catastrophic if we can't
-            // deserialize the reason.
-            val loader = Utils.getContextOrSparkClassLoader
-            logError(
-              "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
-          case ex: Exception => {}
+          scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
         }
-        scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
-      }
-    })
+      })
+    } catch {
+      case e: RejectedExecutionException if sparkEnv.isStopped =>
+        // ignore it
+    }
   }
 
   def stop() {

examples/src/main/python/status_api_demo.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import threading
+import Queue
+
+from pyspark import SparkConf, SparkContext
+
+
+def delayed(seconds):
+    def f(x):
+        time.sleep(seconds)
+        return x
+    return f
+
+
+def call_in_background(f, *args):
+    result = Queue.Queue(1)
+    t = threading.Thread(target=lambda: result.put(f(*args)))
+    t.daemon = True
+    t.start()
+    return result
+
+
+def main():
+    conf = SparkConf().set("spark.ui.showConsoleProgress", "false")
+    sc = SparkContext(appName="PythonStatusAPIDemo", conf=conf)
+
+    def run():
+        rdd = sc.parallelize(range(10), 10).map(delayed(2))
+        reduced = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
+        return reduced.map(delayed(2)).collect()
+
+    result = call_in_background(run)
+    status = sc.statusTracker()
+    while result.empty():
+        ids = status.getJobIdsForGroup()
+        for id in ids:
+            job = status.getJobInfo(id)
+            print "Job", id, "status: ", job.status
+            for sid in job.stageIds:
+                info = status.getStageInfo(sid)
+                if info:
+                    print "Stage %d: %d tasks total (%d active, %d complete)" % \
+                        (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)
+        time.sleep(1)
+
+    print "Job results are:", result.get()
+    sc.stop()
+
+
+if __name__ == "__main__":
+    main()

python/pyspark/__init__.py

Lines changed: 8 additions & 7 deletions
@@ -22,17 +22,17 @@
 
   - :class:`SparkContext`:
       Main entry point for Spark functionality.
-  - L{RDD}
+  - :class:`RDD`:
       A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
-  - L{Broadcast}
+  - :class:`Broadcast`:
       A broadcast variable that gets reused across tasks.
-  - L{Accumulator}
+  - :class:`Accumulator`:
       An "add-only" shared variable that tasks can only add values to.
-  - L{SparkConf}
+  - :class:`SparkConf`:
       For configuring Spark.
-  - L{SparkFiles}
+  - :class:`SparkFiles`:
       Access files shipped with jobs.
-  - L{StorageLevel}
+  - :class:`StorageLevel`:
       Finer-grained cache persistence levels.
 
 """
@@ -45,6 +45,7 @@
 from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
 from pyspark.serializers import MarshalSerializer, PickleSerializer
+from pyspark.status import *
 from pyspark.profiler import Profiler, BasicProfiler
 
 # for back compatibility
@@ -53,5 +54,5 @@
 __all__ = [
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
-    "Profiler", "BasicProfiler",
+    "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler",
 ]

python/pyspark/context.py

Lines changed: 7 additions & 0 deletions
@@ -32,6 +32,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 from pyspark.traceback_utils import CallSite, first_spark_call
+from pyspark.status import StatusTracker
 from pyspark.profiler import ProfilerCollector, BasicProfiler
 
 from py4j.java_collections import ListConverter
@@ -810,6 +811,12 @@ def cancelAllJobs(self):
         """
         self._jsc.sc().cancelAllJobs()
 
+    def statusTracker(self):
+        """
+        Return :class:`StatusTracker` object
+        """
+        return StatusTracker(self._jsc.statusTracker())
+
     def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
         """
         Executes the given partitionFunc on the specified set of partitions,

python/pyspark/status.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+__all__ = ["SparkJobInfo", "SparkStageInfo", "StatusTracker"]
+
+
+class SparkJobInfo(namedtuple("SparkJobInfo", "jobId stageIds status")):
+    """
+    Exposes information about Spark Jobs.
+    """
+
+
+class SparkStageInfo(namedtuple("SparkStageInfo",
+                                "stageId currentAttemptId name numTasks numActiveTasks "
+                                "numCompletedTasks numFailedTasks")):
+    """
+    Exposes information about Spark Stages.
+    """
+
+
+class StatusTracker(object):
+    """
+    Low-level status reporting APIs for monitoring job and stage progress.
+
+    These APIs intentionally provide very weak consistency semantics;
+    consumers of these APIs should be prepared to handle empty / missing
+    information. For example, a job's stage ids may be known but the status
+    API may not have any information about the details of those stages, so
+    `getStageInfo` could potentially return `None` for a valid stage id.
+
+    To limit memory usage, these APIs only provide information on recent
+    jobs / stages.  These APIs will provide information for the last
+    `spark.ui.retainedStages` stages and `spark.ui.retainedJobs` jobs.
+    """
+    def __init__(self, jtracker):
+        self._jtracker = jtracker
+
+    def getJobIdsForGroup(self, jobGroup=None):
+        """
+        Return a list of all known jobs in a particular job group.  If
+        `jobGroup` is None, then returns all known jobs that are not
+        associated with a job group.
+
+        The returned list may contain running, failed, and completed jobs,
+        and may vary across invocations of this method.  This method does
+        not guarantee the order of the elements in its result.
+        """
+        return list(self._jtracker.getJobIdsForGroup(jobGroup))
+
+    def getActiveStageIds(self):
+        """
+        Returns an array containing the ids of all active stages.
+        """
+        return sorted(list(self._jtracker.getActiveStageIds()))
+
+    def getActiveJobsIds(self):
+        """
+        Returns an array containing the ids of all active jobs.
+        """
+        return sorted((list(self._jtracker.getActiveJobIds())))
+
+    def getJobInfo(self, jobId):
+        """
+        Returns a :class:`SparkJobInfo` object, or None if the job info
+        could not be found or was garbage collected.
+        """
+        job = self._jtracker.getJobInfo(jobId)
+        if job is not None:
+            return SparkJobInfo(jobId, job.stageIds(), str(job.status()))
+
+    def getStageInfo(self, stageId):
+        """
+        Returns a :class:`SparkStageInfo` object, or None if the stage
+        info could not be found or was garbage collected.
+        """
+        stage = self._jtracker.getStageInfo(stageId)
+        if stage is not None:
+            # TODO: fetch them in batch for better performance
+            attrs = [getattr(stage, f)() for f in SparkStageInfo._fields[1:]]
+            return SparkStageInfo(stageId, *attrs)

python/pyspark/tests.py

Lines changed: 31 additions & 0 deletions
@@ -1550,6 +1550,37 @@ def test_with_stop(self):
             sc.stop()
         self.assertEqual(SparkContext._active_spark_context, None)
 
+    def test_progress_api(self):
+        with SparkContext() as sc:
+            sc.setJobGroup('test_progress_api', '', True)
+
+            rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
+            t = threading.Thread(target=rdd.collect)
+            t.daemon = True
+            t.start()
+            # wait for scheduler to start
+            time.sleep(1)
+
+            tracker = sc.statusTracker()
+            jobIds = tracker.getJobIdsForGroup('test_progress_api')
+            self.assertEqual(1, len(jobIds))
+            job = tracker.getJobInfo(jobIds[0])
+            self.assertEqual(1, len(job.stageIds))
+            stage = tracker.getStageInfo(job.stageIds[0])
+            self.assertEqual(rdd.getNumPartitions(), stage.numTasks)
+
+            sc.cancelAllJobs()
+            t.join()
+            # wait for event listener to update the status
+            time.sleep(1)
+
+            job = tracker.getJobInfo(jobIds[0])
+            self.assertEqual('FAILED', job.status)
+            self.assertEqual([], tracker.getActiveJobsIds())
+            self.assertEqual([], tracker.getActiveStageIds())
+
+            sc.stop()
+
 
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(PySparkTestCase):
