
Commit 14174ab

davies authored and mateiz committed
[SPARK-2538] [PySpark] Hash based disk spilling aggregation
During aggregation in the Python worker, if memory usage goes above spark.executor.memory, it does disk-spilling aggregation. It splits the aggregation into multiple stages; in each stage, it partitions the aggregated data by hash and dumps the partitions to disk. After all the data are aggregated, it merges all the stages together, partition by partition.

Author: Davies Liu <[email protected]>

Closes apache#1460 from davies/spill and squashes the following commits:

cad91bf [Davies Liu] call gc.collect() after data.clear() to release memory as much as possible.
37d71f7 [Davies Liu] balance the partitions
902f036 [Davies Liu] add shuffle.py into run-tests
dcf03a9 [Davies Liu] fix memory_info() of psutil
67e6eba [Davies Liu] comment for MAX_TOTAL_PARTITIONS
f6bd5d6 [Davies Liu] rollback next_limit() again, the performance difference is huge:
e74b785 [Davies Liu] fix code style and change next_limit to memory_limit
400be01 [Davies Liu] address all the comments
6178844 [Davies Liu] refactor and improve docs
fdd0a49 [Davies Liu] add long doc string for ExternalMerger
1a97ce4 [Davies Liu] limit used memory and size of objects in partitionBy()
e6cc7f9 [Davies Liu] Merge branch 'master' into spill
3652583 [Davies Liu] address comments
e78a0a0 [Davies Liu] fix style
24cec6a [Davies Liu] get local directory by SPARK_LOCAL_DIR
57ee7ef [Davies Liu] update docs
286aaff [Davies Liu] let spilled aggregation in Python configurable
e9a40f6 [Davies Liu] recursive merger
6edbd1f [Davies Liu] Hash based disk spilling aggregation
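As a reading aid, here is a minimal, self-contained sketch of the approach the message describes: hash-partition the in-memory combiners to disk when memory fills up, then merge the spilled files one partition at a time. This is not pyspark.shuffle.ExternalMerger; the real merger tracks process memory with get_used_memory(), uses the directories passed via SPARK_LOCAL_DIR, and can recursively re-spill a partition, while this sketch replaces the memory check with a simple distinct-key threshold.

import os
import pickle
import shutil
import tempfile
from collections import defaultdict


def external_merge(pairs, create_combiner, merge_value, merge_combiners,
                   partitions=16, threshold=10000):
    """Aggregate (key, value) pairs, spilling hash partitions to disk whenever
    more than `threshold` distinct keys are held in memory."""
    tmpdir = tempfile.mkdtemp()
    data = {}

    def spill():
        # split the in-memory combiners by hash of key, append each bucket to
        # that partition's file, then drop everything from memory
        buckets = defaultdict(dict)
        for k, v in data.items():
            buckets[hash(k) % partitions][k] = v
        for p, bucket in buckets.items():
            with open(os.path.join(tmpdir, str(p)), "ab") as f:
                pickle.dump(bucket, f)
        data.clear()

    try:
        for k, v in pairs:
            data[k] = merge_value(data[k], v) if k in data else create_combiner(v)
            if len(data) > threshold:
                spill()
        if not os.listdir(tmpdir):
            # nothing was spilled: plain in-memory aggregation
            for item in data.items():
                yield item
            return
        spill()  # push the remaining in-memory combiners to disk as well
        for name in os.listdir(tmpdir):
            # merge one partition at a time, so only ~1/partitions of the keys
            # are ever in memory during this phase
            merged = {}
            with open(os.path.join(tmpdir, name), "rb") as f:
                while True:
                    try:
                        bucket = pickle.load(f)
                    except EOFError:
                        break
                    for k, v in bucket.items():
                        merged[k] = merge_combiners(merged[k], v) if k in merged else v
            for item in merged.items():
                yield item
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)

For a word-count-style sum this would be called as dict(external_merge(pairs, lambda v: v, lambda c, v: c + v, lambda a, b: a + b)).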
1 parent eff9714 commit 14174ab

9 files changed (+611 lines, -25 lines)


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 1 deletion
@@ -57,7 +57,10 @@ private[spark] class PythonRDD[T: ClassTag](
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
     val env = SparkEnv.get
-    val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(
+      f => f.getPath()).mkString(",")
+    val worker: Socket = env.createPythonWorker(pythonExec,
+      envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
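The Scala change above hands the executor's local directories to the Python worker as a comma-separated SPARK_LOCAL_DIR environment variable. A hedged sketch of how the Python side could turn that back into a list of candidate spill directories; the helper name and the /tmp fallback are illustrative, not necessarily what pyspark/shuffle.py does:

import os


def local_spill_dirs():
    # SPARK_LOCAL_DIR is the comma-joined list built in PythonRDD.compute above;
    # fall back to /tmp when the variable is not set (illustrative default)
    raw = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
    return [d for d in raw.split(",") if d.strip()]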

core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
   /* Create one local directory for each path mentioned in spark.local.dir; then, inside this
    * directory, create multiple subdirectories that we will hash files into, in order to avoid
    * having really large inodes at the top level. */
-  private val localDirs: Array[File] = createLocalDirs()
+  val localDirs: Array[File] = createLocalDirs()
   if (localDirs.isEmpty) {
     logError("Failed to create any local dir.")
     System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)

docs/configuration.md

Lines changed: 9 additions & 0 deletions
@@ -197,6 +197,15 @@ Apart from these, the following properties are also available, and may be useful
     Spark's dependencies and user dependencies. It is currently an experimental feature.
   </td>
 </tr>
+<tr>
+  <td><code>spark.python.worker.memory</code></td>
+  <td>512m</td>
+  <td>
+    Amount of memory to use per python worker process during aggregation, in the same
+    format as JVM memory strings (e.g. <code>512m</code>, <code>2g</code>). If the memory
+    used during aggregation goes above this amount, it will spill the data into disks.
+  </td>
+</tr>
 </table>
 
 #### Shuffle Behavior
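For reference, this is how the new property could be set from a PySpark application; a minimal sketch using the standard SparkConf/SparkContext API, with the property name and its spark.shuffle.spill companion taken from this commit (the 2g value and app name are arbitrary):

from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .set("spark.python.worker.memory", "2g")   # spill once a worker's aggregation exceeds ~2g
        .set("spark.shuffle.spill", "true"))       # spilling is on by default; shown for clarity
sc = SparkContext(appName="spill-demo", conf=conf)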

python/epydoc.conf

Lines changed: 1 addition & 1 deletion
@@ -35,4 +35,4 @@ private: no
 exclude: pyspark.cloudpickle pyspark.worker pyspark.join
          pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
          pyspark.rddsampler pyspark.daemon pyspark.mllib._common
-         pyspark.mllib.tests
+         pyspark.mllib.tests pyspark.shuffle

python/pyspark/rdd.py

Lines changed: 71 additions & 21 deletions
@@ -42,6 +42,8 @@
 from pyspark.rddsampler import RDDSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
+    get_used_memory
 
 from py4j.java_collections import ListConverter, MapConverter
 
@@ -197,6 +199,22 @@ def _replaceRoot(self, value):
             self._sink(1)
 
 
+def _parse_memory(s):
+    """
+    Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
+    return the value in MB
+
+    >>> _parse_memory("256m")
+    256
+    >>> _parse_memory("2g")
+    2048
+    """
+    units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
+    if s[-1] not in units:
+        raise ValueError("invalid format: " + s)
+    return int(float(s[:-1]) * units[s[-1].lower()])
+
+
 class RDD(object):
 
     """
@@ -1207,20 +1225,49 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        # Transferring O(n) objects to Java is too expensive. Instead, we'll
-        # form the hash buckets in Python, transferring O(numPartitions) objects
-        # to Java.  Each object is a (splitNumber, [objects]) pair.
+        # Transferring O(n) objects to Java is too expensive.
+        # Instead, we'll form the hash buckets in Python,
+        # transferring O(numPartitions) objects to Java.
+        # Each object is a (splitNumber, [objects]) pair.
+        # In order to avoid too huge objects, the objects are
+        # grouped into chunks.
         outputSerializer = self.ctx._unbatched_serializer
 
+        limit = (_parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m")) / 2)
+
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
+            c, batch = 0, min(10 * numPartitions, 1000)
 
             for (k, v) in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
+                c += 1
+
+                # check used memory and avg size of chunk of objects
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
+                    for split in buckets.keys():
+                        yield pack_long(split)
+                        d = outputSerializer.dumps(buckets[split])
+                        del buckets[split]
+                        yield d
+                        size += len(d)
+
+                    avg = (size / n) >> 20
+                    # let 1M < avg < 10M
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0
+
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)
+
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         with _JavaStackTrace(self.context) as st:
@@ -1230,8 +1277,8 @@ def add_shuffle_key(split, iterator):
                                                           id(partitionFunc))
             jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
-        # This is required so that id(partitionFunc) remains unique, even if
-        # partitionFunc is a lambda:
+        # This is required so that id(partitionFunc) remains unique,
+        # even if partitionFunc is a lambda:
         rdd._partitionFunc = partitionFunc
         return rdd
 
@@ -1265,26 +1312,28 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
+        serializer = self.ctx.serializer
+        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
+                 == 'true')
+        memory = _parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m"))
+        agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
         def combineLocally(iterator):
-            combiners = {}
-            for x in iterator:
-                (k, v) = x
-                if k not in combiners:
-                    combiners[k] = createCombiner(v)
-                else:
-                    combiners[k] = mergeValue(combiners[k], v)
-            return combiners.iteritems()
+            merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.mergeValues(iterator)
+            return merger.iteritems()
+
         locally_combined = self.mapPartitions(combineLocally)
         shuffled = locally_combined.partitionBy(numPartitions)
 
         def _mergeCombiners(iterator):
-            combiners = {}
-            for (k, v) in iterator:
-                if k not in combiners:
-                    combiners[k] = v
-                else:
-                    combiners[k] = mergeCombiners(combiners[k], v)
-            return combiners.iteritems()
+            merger = ExternalMerger(agg, memory, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.mergeCombiners(iterator)
+            return merger.iteritems()
+
         return shuffled.mapPartitions(_mergeCombiners)
 
     def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
@@ -1343,7 +1392,8 @@ def mergeValue(xs, x):
             return xs
 
         def mergeCombiners(a, b):
-            return a + b
+            a.extend(b)
+            return a
 
         return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                                  numPartitions).mapValues(lambda x: ResultIterable(x))
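The rewritten combineLocally and _mergeCombiners above rely only on a small merger interface: mergeValues for raw (key, value) pairs, mergeCombiners for already-combined pairs, and iteritems to read the result. Below is a rough sketch of that interface, assuming Aggregator simply bundles the three user functions; it mirrors the in-memory path (the dict-based code this commit removes), while the ExternalMerger added in pyspark/shuffle.py layers the hash-partitioned disk spilling on top of the same interface.

class Aggregator(object):
    def __init__(self, createCombiner, mergeValue, mergeCombiners):
        self.createCombiner = createCombiner
        self.mergeValue = mergeValue
        self.mergeCombiners = mergeCombiners


class InMemoryMerger(object):
    def __init__(self, aggregator):
        self.agg = aggregator
        self.data = {}

    def mergeValues(self, iterator):
        # combine raw (key, value) pairs on the map side
        d, agg = self.data, self.agg
        for k, v in iterator:
            d[k] = agg.mergeValue(d[k], v) if k in d else agg.createCombiner(v)

    def mergeCombiners(self, iterator):
        # combine already-combined (key, combiner) pairs on the reduce side
        d, agg = self.data, self.agg
        for k, v in iterator:
            d[k] = agg.mergeCombiners(d[k], v) if k in d else v

    def iteritems(self):
        return iter(self.data.items())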

python/pyspark/serializers.py

Lines changed: 28 additions & 1 deletion
@@ -193,7 +193,7 @@ def load_stream(self, stream):
         return chain.from_iterable(self._load_stream_without_unbatching(stream))
 
     def _load_stream_without_unbatching(self, stream):
-        return self.serializer.load_stream(stream)
+        return self.serializer.load_stream(stream)
 
     def __eq__(self, other):
         return (isinstance(other, BatchedSerializer) and
@@ -302,6 +302,33 @@ class MarshalSerializer(FramedSerializer):
     loads = marshal.loads
 
 
+class AutoSerializer(FramedSerializer):
+    """
+    Choose marshal or cPickle as serialization protocol automatically
+    """
+    def __init__(self):
+        FramedSerializer.__init__(self)
+        self._type = None
+
+    def dumps(self, obj):
+        if self._type is not None:
+            return 'P' + cPickle.dumps(obj, -1)
+        try:
+            return 'M' + marshal.dumps(obj)
+        except Exception:
+            self._type = 'P'
+            return 'P' + cPickle.dumps(obj, -1)
+
+    def loads(self, obj):
+        _type = obj[0]
+        if _type == 'M':
+            return marshal.loads(obj[1:])
+        elif _type == 'P':
+            return cPickle.loads(obj[1:])
+        else:
+            raise ValueError("invalid serialization type: %s" % _type)
+
+
 class UTF8Deserializer(Serializer):
     """
     Deserializes streams written by String.getBytes.
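A quick usage sketch for the AutoSerializer added above (not part of the commit, and assuming the class is importable from pyspark.serializers as added here): plain builtins go through marshal ('M' frames), and the first object marshal cannot handle switches the instance to cPickle ('P' frames) for everything afterwards.

from pyspark.serializers import AutoSerializer


class Point(object):
    def __init__(self, x, y):
        self.x, self.y = x, y


ser = AutoSerializer()
frame = ser.dumps({"a": [1, 2, 3]})
assert frame[0] == 'M' and ser.loads(frame) == {"a": [1, 2, 3]}

frame = ser.dumps(Point(1, 2))   # marshal raises on class instances, so this is pickled
p = ser.loads(frame)
assert frame[0] == 'P' and (p.x, p.y) == (1, 2)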
