
Commit 1a97ce4

limit used memory and size of objects in partitionBy()
1 parent e6cc7f9 commit 1a97ce4
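
For context, a minimal usage sketch (not part of the commit) showing the code path this change affects: partitionBy() on a pair RDD runs the Python-side bucketing patched below, and spark.python.worker.memory is the setting the new flush limit is derived from. The app name, data, and partition count are illustrative.

from pyspark import SparkConf, SparkContext

# spark.python.worker.memory caps the memory used by Python workers; the
# patch below uses half of this value as the threshold for flushing buckets.
conf = (SparkConf()
        .setAppName("partitionBy-demo")
        .set("spark.python.worker.memory", "512m"))
sc = SparkContext(conf=conf)

pairs = sc.parallelize(range(100000)).map(lambda x: (x, x))
# partitionBy() hash-partitions the pair RDD; with this commit the Python
# side streams each partition's data to the JVM in bounded-size chunks.
repartitioned = pairs.partitionBy(16)
print(repartitioned.glom().map(len).collect())
sc.stop()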

File tree

2 files changed: +23 -7 lines changed


python/pyspark/rdd.py

Lines changed: 22 additions & 6 deletions
@@ -1227,23 +1227,39 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
 
         # Transferring O(n) objects to Java is too expensive. Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
-        # to Java. Each object is a (splitNumber, [objects]) pair.
+        # to Java. Each object is a (splitNumber, [objects]) pair.
+        # In order to avoid too huge objects, the objects are grouped into chunks.
         outputSerializer = self.ctx._unbatched_serializer
 
-        limit = _parse_memory(self.ctx._conf.get("spark.python.worker.memory")
-                              or "512m")
+        limit = (_parse_memory(self.ctx._conf.get("spark.python.worker.memory")
+                               or "512m") / 2)
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
-            c, batch = 0, 1000
+            c, batch = 0, min(10 * numPartitions, 1000)
+
             for (k, v) in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
                 c += 1
-                if c % batch == 0 and get_used_memory() > limit:
+
+                # check used memory and avg size of chunk of objects
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
                     for split in buckets.keys():
                         yield pack_long(split)
-                        yield outputSerializer.dumps(buckets[split])
+                        d = outputSerializer.dumps(buckets[split])
                         del buckets[split]
+                        yield d
+                        size += len(d)
+
+                    avg = (size / n) >> 20
+                    # let 1M < avg < 10M
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0
 
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
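
To make the new control flow easier to follow, here is a standalone, hedged sketch of the same idea outside PySpark: buckets are flushed either when the worker's memory use crosses a limit (checked every 1000 records) or when the adaptive batch size is exceeded, and the batch size is then tuned so the average serialized chunk stays roughly between 1 MB and 10 MB. get_used_memory_mb, serialize, and limit_mb are stand-ins for PySpark's internals, not its actual API.

from collections import defaultdict
import pickle
import resource


def get_used_memory_mb():
    # Peak resident set size of this process; ru_maxrss is reported in KB
    # on Linux, so shift by 10 to get MB (a rough stand-in for illustration).
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss >> 10


def chunked_shuffle(iterator, num_partitions, partition_func,
                    limit_mb=256, serialize=pickle.dumps):
    # Yield (split, serialized_chunk) pairs, flushing the buckets whenever
    # memory use crosses limit_mb or the adaptive batch size is reached.
    buckets = defaultdict(list)
    c, batch = 0, min(10 * num_partitions, 1000)

    for k, v in iterator:
        buckets[partition_func(k) % num_partitions].append((k, v))
        c += 1

        # Flush when memory is high (checked every 1000 records) or when
        # more than `batch` records have accumulated since the last flush.
        if (c % 1000 == 0 and get_used_memory_mb() > limit_mb) or c > batch:
            n, size = len(buckets), 0
            for split in list(buckets.keys()):
                data = serialize(buckets[split])
                del buckets[split]
                yield split, data
                size += len(data)

            # Steer the average serialized chunk into the 1-10 MB range by
            # growing or shrinking the batch size.
            avg_mb = (size / n) / float(1 << 20)
            if avg_mb < 1:
                batch = max(int(batch * 1.5), batch + 1)
            elif avg_mb > 10:
                batch = max(int(batch / 1.5), 1)
            c = 0

    # Emit whatever is left once the input is exhausted.
    for split, items in buckets.items():
        yield split, serialize(items)


if __name__ == "__main__":
    records = ((i, "x" * 100) for i in range(50000))
    for split, chunk in chunked_shuffle(records, num_partitions=8,
                                        partition_func=hash):
        pass  # in PySpark these chunks would be streamed to the JVM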

python/pyspark/shuffle.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def _get_dirs(self):
         """ get all the directories """
         path = os.environ.get("SPARK_LOCAL_DIR", "/tmp/spark")
         dirs = path.split(",")
-        return [os.path.join(d, "python", str(os.getpid()))
+        return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
                 for d in dirs]
 
     def _get_spill_dir(self, n):
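
And a brief hedged sketch of why the shuffle.py change appends str(id(self)) to the spill path: it keeps two merger instances running inside the same Python worker process from sharing, and potentially clobbering, each other's spill directories. The class below is a standalone approximation for illustration, not the actual merger in pyspark/shuffle.py.

import os


class SpillDirs(object):
    # Standalone approximation of the spill-directory layout: one
    # subdirectory per configured local dir, keyed by the worker's PID
    # and by the identity of this particular object.

    def _get_dirs(self):
        # SPARK_LOCAL_DIR may contain several comma-separated directories.
        path = os.environ.get("SPARK_LOCAL_DIR", "/tmp/spark")
        dirs = path.split(",")
        # id(self) separates the spill files of different instances that
        # live in the same worker process.
        return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
                for d in dirs]


if __name__ == "__main__":
    print(SpillDirs()._get_dirs())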
