@@ -1227,23 +1227,39 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
 
         # Transferring O(n) objects to Java is too expensive. Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
-        # to Java. Each object is a (splitNumber, [objects]) pair.
+        # to Java. Each object is a (splitNumber, [objects]) pair.
+        # In order to avoid overly large objects, the objects are grouped into chunks.
         outputSerializer = self.ctx._unbatched_serializer
 
-        limit = _parse_memory(self.ctx._conf.get("spark.python.worker.memory")
-                              or "512m")
+        limit = (_parse_memory(self.ctx._conf.get("spark.python.worker.memory")
+                               or "512m") / 2)
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
-            c, batch = 0, 1000
+            c, batch = 0, min(10 * numPartitions, 1000)
+
             for (k, v) in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
                 c += 1
-                if c % batch == 0 and get_used_memory() > limit:
+
+                # check the used memory and the average size of serialized chunks
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
                     for split in buckets.keys():
                         yield pack_long(split)
-                        yield outputSerializer.dumps(buckets[split])
+                        d = outputSerializer.dumps(buckets[split])
                         del buckets[split]
+                        yield d
+                        size += len(d)
+
+                    avg = (size / n) >> 20
+                    # keep the average chunk size between 1M and 10M
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0
 
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
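
Note on the change above: the hunk replaces a fixed flush threshold of 1000 records with an adaptive one. Records are hashed into per-partition buckets; when the batch counter overflows (or memory pressure is detected on every 1000th record), the buckets are serialized and yielded, and the batch size is then grown or shrunk so the average serialized chunk stays roughly between 1 MB and 10 MB. The sketch below illustrates that feedback loop in isolation and is not the PySpark implementation: adaptive_chunks, partition_func, dumps and memory_ok are hypothetical stand-ins for partitionFunc, outputSerializer.dumps and get_used_memory(), and the (split, chunk) tuples it yields stand in for the pack_long(split) / chunk byte stream that the real add_shuffle_key emits.

from collections import defaultdict
import pickle


def adaptive_chunks(iterator, num_partitions, partition_func=hash,
                    dumps=pickle.dumps, memory_ok=lambda: True):
    # Hash each record into a per-partition bucket, flush the buckets as
    # serialized chunks, and tune the flush threshold ("batch") so the
    # average serialized chunk stays roughly between 1 MB and 10 MB.
    buckets = defaultdict(list)
    count, batch = 0, min(10 * num_partitions, 1000)

    for k, v in iterator:
        buckets[partition_func(k) % num_partitions].append((k, v))
        count += 1

        # Flush when the batch is full, or on every 1000th record under memory pressure.
        if count > batch or (count % 1000 == 0 and not memory_ok()):
            n, size = len(buckets), 0
            for split in list(buckets.keys()):
                data = dumps(buckets[split])
                del buckets[split]  # free the bucket before yielding the chunk
                yield split, data
                size += len(data)

            # >> 20 converts bytes to MB; nudge the batch size toward 1-10 MB chunks.
            avg_mb = (size // n) >> 20
            if avg_mb < 1:
                batch *= 1.5
            elif avg_mb > 10:
                batch = max(batch / 1.5, 1)
            count = 0

    # Flush whatever is left once the input is exhausted.
    for split, items in buckets.items():
        yield split, dumps(items)


# Example: partition 100k pairs into 4 buckets and inspect the chunk sizes.
pairs = ((i, i * i) for i in range(100000))
for split, chunk in adaptive_chunks(pairs, num_partitions=4):
    print(split, len(chunk))

The factor of 1.5 and the 1-10 MB window mirror the diff: chunks large enough to amortize per-chunk serialization and transfer overhead, but small enough that no single serialized object blows up worker memory, which is the stated goal of grouping the objects into chunks.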