Commit c5414b6

davies authored and JoshRosen committed
[SPARK-3478] [PySpark] Profile the Python tasks
This patch adds profiling support for PySpark. The profiling results are shown before the driver exits; here is one example:

```
============================================================
Profile of RDD<id=3>
============================================================
      5146507 function calls (5146487 primitive calls) in 71.094 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  5144576   68.331    0.000   68.331    0.000 statcounter.py:44(merge)
       20    2.735    0.137   71.071    3.554 statcounter.py:33(__init__)
       20    0.017    0.001    0.017    0.001 {cPickle.dumps}
     1024    0.003    0.000    0.003    0.000 t.py:16(<lambda>)
       20    0.001    0.000    0.001    0.000 {reduce}
       21    0.001    0.000    0.001    0.000 {cPickle.loads}
       20    0.001    0.000    0.001    0.000 copy_reg.py:95(_slotnames)
       41    0.001    0.000    0.001    0.000 serializers.py:461(read_int)
       40    0.001    0.000    0.002    0.000 serializers.py:179(_batched)
       62    0.000    0.000    0.000    0.000 {method 'read' of 'file' objects}
       20    0.000    0.000   71.072    3.554 rdd.py:863(<lambda>)
       20    0.000    0.000    0.001    0.000 serializers.py:198(load_stream)
    40/20    0.000    0.000   71.072    3.554 rdd.py:2093(pipeline_func)
       41    0.000    0.000    0.002    0.000 serializers.py:130(load_stream)
       40    0.000    0.000   71.072    1.777 rdd.py:304(func)
       20    0.000    0.000   71.094    3.555 worker.py:82(process)
```

Users can also show the profile results manually with `sc.show_profiles()`, or dump them to disk with `sc.dump_profiles(path)`:

```python
>>> sc._conf.set("spark.python.profile", "true")
>>> rdd = sc.parallelize(range(100)).map(str)
>>> rdd.count()
100
>>> sc.show_profiles()
============================================================
Profile of RDD<id=1>
============================================================
      284 function calls (276 primitive calls) in 0.001 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        4    0.000    0.000    0.000    0.000 serializers.py:198(load_stream)
        4    0.000    0.000    0.000    0.000 {reduce}
     12/4    0.000    0.000    0.001    0.000 rdd.py:2092(pipeline_func)
        4    0.000    0.000    0.000    0.000 {cPickle.loads}
        4    0.000    0.000    0.000    0.000 {cPickle.dumps}
      104    0.000    0.000    0.000    0.000 rdd.py:852(<genexpr>)
        8    0.000    0.000    0.000    0.000 serializers.py:461(read_int)
       12    0.000    0.000    0.000    0.000 rdd.py:303(func)
```

Profiling is disabled by default and can be enabled with "spark.python.profile=true". The results can also be dumped to disk automatically for later analysis by setting "spark.python.profile.dump=path_to_dump"; in that case they are not displayed automatically at exit.

This is a bugfix of apache#2351.

cc JoshRosen

Author: Davies Liu <[email protected]>

Closes apache#2556 from davies/profiler and squashes the following commits:

e68df5a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler
858e74c [Davies Liu] compatitable with python 2.6
7ef2aa0 [Davies Liu] bugfix, add tests for show_profiles and dump_profiles()
2b0daf2 [Davies Liu] fix docs
7a56c24 [Davies Liu] bugfix
cba9463 [Davies Liu] move show_profiles and dump_profiles to SparkContext
fb9565b [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler
116d52a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler
09d02c3 [Davies Liu] Merge branch 'master' into profiler
c23865c [Davies Liu] Merge branch 'master' into profiler
15d6f18 [Davies Liu] add docs for two configs
dadee1a [Davies Liu] add docs string and clear profiles after show or dump
4f8309d [Davies Liu] address comment, add tests
0a5b6eb [Davies Liu] fix Python UDF
4b20494 [Davies Liu] add profile for python
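
For reference, a minimal end-to-end sketch of the workflow described above (the master URL, app name, and dump directory are placeholders, and setting `spark.python.profile.dump` is optional):

```python
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setMaster("local[2]")                      # placeholder master
        .setAppName("profile-demo")                 # placeholder app name
        .set("spark.python.profile", "true")        # enable the Python worker profiler
        .set("spark.python.profile.dump", "/tmp/spark_profiles"))  # optional dump directory

sc = SparkContext(conf=conf)
sc.parallelize(range(100)).map(str).count()      # run a job so the workers collect stats

sc.show_profiles()                       # print the accumulated cProfile stats per RDD
sc.dump_profiles("/tmp/spark_profiles")  # or write one rdd_<id>.pstats file per RDD
sc.stop()
```
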
1 parent d75496b commit c5414b6

File tree

7 files changed (+127 additions, -7 deletions)


docs/configuration.md

Lines changed: 19 additions & 0 deletions

```diff
@@ -206,6 +206,25 @@ Apart from these, the following properties are also available, and may be useful
     used during aggregation goes above this amount, it will spill the data into disks.
   </td>
 </tr>
+<tr>
+  <td><code>spark.python.profile</code></td>
+  <td>false</td>
+  <td>
+    Enable profiling in the Python workers. The profile results will show up via
+    `sc.show_profiles()`, or they will be displayed before the driver exits. They can
+    also be dumped to disk with `sc.dump_profiles(path)`. If some of the profile results
+    have been displayed manually, they will not be displayed automatically before the
+    driver exits.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.python.profile.dump</code></td>
+  <td>(none)</td>
+  <td>
+    The directory used to dump the profile results before the driver exits. The results
+    are dumped as one separate file per RDD and can be loaded with `pstats.Stats()`. If
+    this is specified, the profile results will not be displayed automatically.
+  </td>
+</tr>
 <tr>
   <td><code>spark.python.worker.reuse</code></td>
   <td>true</td>
```
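
Since the dumped files are plain `pstats` dumps, they can be analyzed offline with the standard library. A small sketch, assuming profiles were dumped to a placeholder directory `/tmp/spark_profiles`:

```python
import pstats

# load one "rdd_<id>.pstats" file written by sc.dump_profiles(path)
stats = pstats.Stats("/tmp/spark_profiles/rdd_1.pstats")
stats.sort_stats("time", "cumulative").print_stats(10)  # top 10 by internal time
```
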

python/pyspark/accumulators.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -215,6 +215,21 @@ def addInPlace(self, value1, value2):
 COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
 
 
+class PStatsParam(AccumulatorParam):
+    """PStatsParam is used to merge pstats.Stats"""
+
+    @staticmethod
+    def zero(value):
+        return None
+
+    @staticmethod
+    def addInPlace(value1, value2):
+        if value1 is None:
+            return value2
+        value1.add(value2)
+        return value1
+
+
 class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
 
     """
```

python/pyspark/context.py

Lines changed: 38 additions & 1 deletion

```diff
@@ -20,6 +20,7 @@
 import sys
 from threading import Lock
 from tempfile import NamedTemporaryFile
+import atexit
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -30,7 +31,6 @@
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
     PairDeserializer, CompressedSerializer
 from pyspark.storagelevel import StorageLevel
-from pyspark import rdd
 from pyspark.rdd import RDD
 from pyspark.traceback_utils import CallSite, first_spark_call
 
@@ -192,6 +192,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         self._temp_dir = \
             self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
 
+        # profiling stats collected for each PythonRDD
+        self._profile_stats = []
+
     def _initialize_context(self, jconf):
         """
         Initialize SparkContext in function to allow subclass specific initialization
@@ -792,6 +795,40 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
         it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
         return list(mappedRDD._collect_iterator_through_file(it))
 
+    def _add_profile(self, id, profileAcc):
+        if not self._profile_stats:
+            dump_path = self._conf.get("spark.python.profile.dump")
+            if dump_path:
+                atexit.register(self.dump_profiles, dump_path)
+            else:
+                atexit.register(self.show_profiles)
+
+        self._profile_stats.append([id, profileAcc, False])
+
+    def show_profiles(self):
+        """ Print the profile stats to stdout """
+        for i, (id, acc, showed) in enumerate(self._profile_stats):
+            stats = acc.value
+            if not showed and stats:
+                print "=" * 60
+                print "Profile of RDD<id=%d>" % id
+                print "=" * 60
+                stats.sort_stats("time", "cumulative").print_stats()
+                # mark it as showed
+                self._profile_stats[i][2] = True
+
+    def dump_profiles(self, path):
+        """ Dump the profile stats into directory `path`
+        """
+        if not os.path.exists(path):
+            os.makedirs(path)
+        for id, acc, _ in self._profile_stats:
+            stats = acc.value
+            if stats:
+                p = os.path.join(path, "rdd_%d.pstats" % id)
+                stats.dump_stats(p)
+        self._profile_stats = []
+
 
 def _test():
     import atexit
```
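
The `_add_profile` helper above registers an exit hook only once, choosing between dumping and printing based on `spark.python.profile.dump`. A stripped-down sketch of that registration pattern (the function names below are illustrative stand-ins, not the Spark API):

```python
import atexit

def show_profiles():
    print("printing collected profiles...")

def dump_profiles(path):
    print("dumping collected profiles to %s" % path)

dump_path = None  # stands in for conf.get("spark.python.profile.dump")
if dump_path:
    # atexit.register accepts extra positional args for the callback
    atexit.register(dump_profiles, dump_path)
else:
    atexit.register(show_profiles)
```
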

python/pyspark/rdd.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
 from itertools import chain, ifilter, imap
@@ -32,6 +31,7 @@
 from random import Random
 from math import sqrt, log, isinf, isnan
 
+from pyspark.accumulators import PStatsParam
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer
@@ -2080,7 +2080,9 @@ def _jrdd(self):
             return self._jrdd_val
         if self._bypass_serializer:
             self._jrdd_deserializer = NoOpSerializer()
-        command = (self.func, self._prev_jrdd_deserializer,
+        enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
+        profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
+        command = (self.func, profileStats, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
         # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
@@ -2102,6 +2104,10 @@ def _jrdd(self):
                                              self.ctx.pythonExec,
                                              broadcast_vars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()
+
+        if enable_profile:
+            self._id = self._jrdd_val.id()
+            self.ctx._add_profile(self._id, profileStats)
         return self._jrdd_val
 
     def id(self):
```
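
The key piece above is `self.ctx.accumulator(None, PStatsParam)`: the profiler rides on a regular PySpark accumulator with a custom `AccumulatorParam`. A small sketch of that mechanism with a toy param (the max-param below is illustrative and not part of this patch):

```python
from pyspark import SparkConf, SparkContext
from pyspark.accumulators import AccumulatorParam

class MaxParam(AccumulatorParam):
    def zero(self, value):
        return value
    def addInPlace(self, v1, v2):
        return max(v1, v2)

sc = SparkContext(conf=SparkConf().setMaster("local[2]").setAppName("acc-demo"))
acc = sc.accumulator(0, MaxParam())
sc.parallelize(range(100)).foreach(lambda x: acc.add(x))
print(acc.value)  # 99: each task's local result was folded in on the driver
sc.stop()
```
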

python/pyspark/sql.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -960,7 +960,7 @@ def registerFunction(self, name, f, returnType=StringType()):
         [Row(c0=4)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func,
+        command = (func, None,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
         ser = CloudPickleSerializer()
```

python/pyspark/tests.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -632,6 +632,36 @@ def test_distinct(self):
         self.assertEquals(result.count(), 3)
 
 
+class TestProfiler(PySparkTestCase):
+
+    def setUp(self):
+        self._old_sys_path = list(sys.path)
+        class_name = self.__class__.__name__
+        conf = SparkConf().set("spark.python.profile", "true")
+        self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf)
+
+    def test_profiler(self):
+
+        def heavy_foo(x):
+            for i in range(1 << 20):
+                x = 1
+        rdd = self.sc.parallelize(range(100))
+        rdd.foreach(heavy_foo)
+        profiles = self.sc._profile_stats
+        self.assertEqual(1, len(profiles))
+        id, acc, _ = profiles[0]
+        stats = acc.value
+        self.assertTrue(stats is not None)
+        width, stat_list = stats.get_print_list([])
+        func_names = [func_name for fname, n, func_name in stat_list]
+        self.assertTrue("heavy_foo" in func_names)
+
+        self.sc.show_profiles()
+        d = tempfile.gettempdir()
+        self.sc.dump_profiles(d)
+        self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
+
+
 class TestSQL(PySparkTestCase):
 
     def setUp(self):
```

python/pyspark/worker.py

Lines changed: 16 additions & 3 deletions

```diff
@@ -23,6 +23,8 @@
 import time
 import socket
 import traceback
+import cProfile
+import pstats
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -90,10 +92,21 @@ def main(infile, outfile):
         command = pickleSer._read_with_length(infile)
         if isinstance(command, Broadcast):
             command = pickleSer.loads(command.value)
-        (func, deserializer, serializer) = command
+        (func, stats, deserializer, serializer) = command
         init_time = time.time()
-        iterator = deserializer.load_stream(infile)
-        serializer.dump_stream(func(split_index, iterator), outfile)
+
+        def process():
+            iterator = deserializer.load_stream(infile)
+            serializer.dump_stream(func(split_index, iterator), outfile)
+
+        if stats:
+            p = cProfile.Profile()
+            p.runcall(process)
+            st = pstats.Stats(p)
+            st.stream = None  # make it picklable
+            stats.add(st.strip_dirs())
+        else:
+            process()
     except Exception:
         try:
             write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
```
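
The worker-side wrapper above is a small, reusable pattern: run the task body under `cProfile`, convert the result to `pstats.Stats`, and clear the unpicklable stream handle before shipping the stats back through the accumulator. A standalone sketch of that pattern, where the body of `process` is only a stand-in for deserializing input and serializing output:

```python
import cProfile
import pstats

def process():
    # stand-in for deserializing the input stream and serializing func's output
    return sum(i * i for i in range(100000))

p = cProfile.Profile()
p.runcall(process)                  # profile only the task body, not worker setup
st = pstats.Stats(p).strip_dirs()   # shorten file paths before sending to the driver
st.sort_stats("time").print_stats(5)
st.stream = None                    # drop the output stream so the Stats object pickles
# `st` is now in the same shape worker.py adds into the PStatsParam accumulator
```
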
