Skip to content

Commit 8767565

Browse files
Davies Liu authored and JoshRosen committed
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Because of the circular reference between JavaObject and JavaMember, a Java object cannot be released until the Python GC kicks in. This causes a memory leak in collect(), which may consume a lot of memory in the JVM. This PR changes the way we send collected data back to Python, from a local file to a socket, which avoids any disk I/O during collect and also avoids keeping any references to the Java object in Python. cc JoshRosen Author: Davies Liu <[email protected]> Closes #4923 from davies/fix_collect and squashes the following commits: d730286 [Davies Liu] address comments 24c92a4 [Davies Liu] fix style ba54614 [Davies Liu] use socket to transfer data from JVM 9517c8f [Davies Liu] fix memory leak in collect()
1 parent 3cac199 commit 8767565

File tree

4 files changed

+82
-51
lines changed

4 files changed

+82
-51
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 59 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -19,26 +19,27 @@ package org.apache.spark.api.python
1919

2020
import java.io._
2121
import java.net._
22-
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
23-
24-
import org.apache.spark.input.PortableDataStream
22+
import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
2523

2624
import scala.collection.JavaConversions._
2725
import scala.collection.mutable
2826
import scala.language.existentials
2927

3028
import com.google.common.base.Charsets.UTF_8
31-
3229
import org.apache.hadoop.conf.Configuration
3330
import org.apache.hadoop.io.compress.CompressionCodec
34-
import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
31+
import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
3532
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
33+
3634
import org.apache.spark._
37-
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
35+
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
3836
import org.apache.spark.broadcast.Broadcast
37+
import org.apache.spark.input.PortableDataStream
3938
import org.apache.spark.rdd.RDD
4039
import org.apache.spark.util.Utils
4140

41+
import scala.util.control.NonFatal
42+
4243
private[spark] class PythonRDD(
4344
@transient parent: RDD[_],
4445
command: Array[Byte],
@@ -341,21 +342,33 @@ private[spark] object PythonRDD extends Logging {
341342
/**
342343
* Adapter for calling SparkContext#runJob from Python.
343344
*
344-
* This method will return an iterator of an array that contains all elements in the RDD
345+
* This method will serve an iterator of an array that contains all elements in the RDD
345346
* (effectively a collect()), but allows you to run on a certain subset of partitions,
346347
* or to enable local execution.
348+
*
349+
* @return the port number of a local socket which serves the data collected from this job.
347350
*/
348351
def runJob(
349352
sc: SparkContext,
350353
rdd: JavaRDD[Array[Byte]],
351354
partitions: JArrayList[Int],
352-
allowLocal: Boolean): Iterator[Array[Byte]] = {
355+
allowLocal: Boolean): Int = {
353356
type ByteArray = Array[Byte]
354357
type UnrolledPartition = Array[ByteArray]
355358
val allPartitions: Array[UnrolledPartition] =
356359
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
357360
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
358-
flattenedPartition.iterator
361+
serveIterator(flattenedPartition.iterator,
362+
s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
363+
}
364+
365+
/**
366+
* A helper function to collect an RDD as an iterator, then serve it via socket.
367+
*
368+
* @return the port number of a local socket which serves the data collected from this job.
369+
*/
370+
def collectAndServe[T](rdd: RDD[T]): Int = {
371+
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
359372
}
360373

361374
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -575,15 +588,44 @@ private[spark] object PythonRDD extends Logging {
575588
dataOut.write(bytes)
576589
}
577590

578-
def writeToFile[T](items: java.util.Iterator[T], filename: String) {
579-
import scala.collection.JavaConverters._
580-
writeToFile(items.asScala, filename)
581-
}
591+
/**
592+
* Create a socket server and a background thread to serve the data in `items`,
593+
*
594+
* The socket server can only accept one connection, or close if no connection
595+
* in 3 seconds.
596+
*
597+
* Once a connection comes in, it tries to serialize all the data in `items`
598+
* and send them into this connection.
599+
*
600+
* The thread will terminate after all the data are sent or any exceptions happen.
601+
*/
602+
private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
603+
val serverSocket = new ServerSocket(0, 1)
604+
serverSocket.setReuseAddress(true)
605+
// Close the socket if no connection in 3 seconds
606+
serverSocket.setSoTimeout(3000)
607+
608+
new Thread(threadName) {
609+
setDaemon(true)
610+
override def run() {
611+
try {
612+
val sock = serverSocket.accept()
613+
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
614+
try {
615+
writeIteratorToStream(items, out)
616+
} finally {
617+
out.close()
618+
}
619+
} catch {
620+
case NonFatal(e) =>
621+
logError(s"Error while sending iterator", e)
622+
} finally {
623+
serverSocket.close()
624+
}
625+
}
626+
}.start()
582627

583-
def writeToFile[T](items: Iterator[T], filename: String) {
584-
val file = new DataOutputStream(new FileOutputStream(filename))
585-
writeIteratorToStream(items, file)
586-
file.close()
628+
serverSocket.getLocalPort
587629
}
588630

589631
private def getMergedConf(confAsMap: java.util.HashMap[String, String],

python/pyspark/context.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@
2121
from threading import Lock
2222
from tempfile import NamedTemporaryFile
2323

24+
from py4j.java_collections import ListConverter
25+
2426
from pyspark import accumulators
2527
from pyspark.accumulators import Accumulator
2628
from pyspark.broadcast import Broadcast
@@ -30,13 +32,11 @@
3032
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
3133
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
3234
from pyspark.storagelevel import StorageLevel
33-
from pyspark.rdd import RDD
35+
from pyspark.rdd import RDD, _load_from_socket
3436
from pyspark.traceback_utils import CallSite, first_spark_call
3537
from pyspark.status import StatusTracker
3638
from pyspark.profiler import ProfilerCollector, BasicProfiler
3739

38-
from py4j.java_collections import ListConverter
39-
4040

4141
__all__ = ['SparkContext']
4242

@@ -59,7 +59,6 @@ class SparkContext(object):
5959

6060
_gateway = None
6161
_jvm = None
62-
_writeToFile = None
6362
_next_accum_id = 0
6463
_active_spark_context = None
6564
_lock = Lock()
@@ -221,7 +220,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
221220
if not SparkContext._gateway:
222221
SparkContext._gateway = gateway or launch_gateway()
223222
SparkContext._jvm = SparkContext._gateway.jvm
224-
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
225223

226224
if instance:
227225
if (SparkContext._active_spark_context and
@@ -840,8 +838,9 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
840838
# by runJob() in order to avoid having to pass a Python lambda into
841839
# SparkContext#runJob.
842840
mappedRDD = rdd.mapPartitions(partitionFunc)
843-
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
844-
return list(mappedRDD._collect_iterator_through_file(it))
841+
port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
842+
allowLocal)
843+
return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
845844

846845
def show_profiles(self):
847846
""" Print the profile stats to stdout """

python/pyspark/rdd.py

Lines changed: 14 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
from collections import defaultdict
2020
from itertools import chain, ifilter, imap
2121
import operator
22-
import os
2322
import sys
2423
import shlex
2524
from subprocess import Popen, PIPE
@@ -29,6 +28,7 @@
2928
import heapq
3029
import bisect
3130
import random
31+
import socket
3232
from math import sqrt, log, isinf, isnan, pow, ceil
3333

3434
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
@@ -111,6 +111,17 @@ def _parse_memory(s):
111111
return int(float(s[:-1]) * units[s[-1].lower()])
112112

113113

114+
def _load_from_socket(port, serializer):
115+
sock = socket.socket()
116+
try:
117+
sock.connect(("localhost", port))
118+
rf = sock.makefile("rb", 65536)
119+
for item in serializer.load_stream(rf):
120+
yield item
121+
finally:
122+
sock.close()
123+
124+
114125
class Partitioner(object):
115126
def __init__(self, numPartitions, partitionFunc):
116127
self.numPartitions = numPartitions
@@ -698,21 +709,8 @@ def collect(self):
698709
Return a list that contains all of the elements in this RDD.
699710
"""
700711
with SCCallSiteSync(self.context) as css:
701-
bytesInJava = self._jrdd.collect().iterator()
702-
return list(self._collect_iterator_through_file(bytesInJava))
703-
704-
def _collect_iterator_through_file(self, iterator):
705-
# Transferring lots of data through Py4J can be slow because
706-
# socket.readline() is inefficient. Instead, we'll dump the data to a
707-
# file and read it back.
708-
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
709-
tempFile.close()
710-
self.ctx._writeToFile(iterator, tempFile.name)
711-
# Read the data into Python and deserialize it:
712-
with open(tempFile.name, 'rb') as tempFile:
713-
for item in self._jrdd_deserializer.load_stream(tempFile):
714-
yield item
715-
os.unlink(tempFile.name)
712+
port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
713+
return list(_load_from_socket(port, self._jrdd_deserializer))
716714

717715
def reduce(self, f):
718716
"""

python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,11 @@
1919
import itertools
2020
import warnings
2121
import random
22-
import os
23-
from tempfile import NamedTemporaryFile
2422

2523
from py4j.java_collections import ListConverter, MapConverter
2624

2725
from pyspark.context import SparkContext
28-
from pyspark.rdd import RDD
26+
from pyspark.rdd import RDD, _load_from_socket
2927
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
3028
from pyspark.storagelevel import StorageLevel
3129
from pyspark.traceback_utils import SCCallSiteSync
@@ -310,14 +308,8 @@ def collect(self):
310308
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
311309
"""
312310
with SCCallSiteSync(self._sc) as css:
313-
bytesInJava = self._jdf.javaToPython().collect().iterator()
314-
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
315-
tempFile.close()
316-
self._sc._writeToFile(bytesInJava, tempFile.name)
317-
# Read the data into Python and deserialize it:
318-
with open(tempFile.name, 'rb') as tempFile:
319-
rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
320-
os.unlink(tempFile.name)
311+
port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
312+
rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
321313
cls = _create_cls(self.schema)
322314
return [cls(r) for r in rs]
323315

0 commit comments

Comments
 (0)