@@ -19,26 +19,27 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
-
-import org.apache.spark.input.PortableDataStream
+import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.language.existentials
 
 import com.google.common.base.Charsets.UTF_8
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
+import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
+
 import org.apache.spark._
-import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.input.PortableDataStream
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
+import scala.util.control.NonFatal
+
 private[spark] class PythonRDD(
     @transient parent: RDD[_],
     command: Array[Byte],
@@ -341,21 +342,33 @@ private[spark] object PythonRDD extends Logging {
   /**
    * Adapter for calling SparkContext#runJob from Python.
    *
-   * This method will return an iterator of an array that contains all elements in the RDD
+   * This method will serve an iterator of an array that contains all elements in the RDD
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
+   *
+   * @return the port number of a local socket which serves the data collected from this job.
    */
   def runJob(
       sc: SparkContext,
       rdd: JavaRDD[Array[Byte]],
       partitions: JArrayList[Int],
-      allowLocal: Boolean): Iterator[Array[Byte]] = {
+      allowLocal: Boolean): Int = {
     type ByteArray = Array[Byte]
     type UnrolledPartition = Array[ByteArray]
     val allPartitions: Array[UnrolledPartition] =
       sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
     val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
-    flattenedPartition.iterator
+    serveIterator(flattenedPartition.iterator,
+      s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
+  }
+
+  /**
+   * A helper function to collect an RDD as an iterator, then serve it via a socket.
+   *
+   * @return the port number of a local socket which serves the data collected from this job.
+   */
+  def collectAndServe[T](rdd: RDD[T]): Int = {
+    serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
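
With this change, `runJob` and `collectAndServe` no longer hand results back through Py4J; they return the port of a short-lived local socket, and the caller is expected to connect and drain it. Below is a rough sketch of what such a client does, assuming the framing `writeIteratorToStream` uses for `Array[Byte]` elements (a 4-byte length prefix, then the payload) and that the server closing the connection marks end-of-stream. The object and method names are invented for illustration; the real client in PySpark lives on the Python side.

    import java.io.{DataInputStream, EOFException}
    import java.net.Socket

    // Hypothetical client for the port returned by runJob/collectAndServe.
    // Assumes each element arrives as a 4-byte length prefix followed by
    // the payload, and that the server closing the connection marks
    // end-of-stream.
    object ServedPortClientSketch {
      def readAll(port: Int): Seq[Array[Byte]] = {
        val sock = new Socket("localhost", port)
        val in = new DataInputStream(sock.getInputStream)
        val chunks = Seq.newBuilder[Array[Byte]]
        try {
          while (true) {
            val length = in.readInt()   // throws EOFException once the server is done
            val bytes = new Array[Byte](length)
            in.readFully(bytes)
            chunks += bytes
          }
        } catch {
          case _: EOFException => // stream closed: all data received
        } finally {
          sock.close()
        }
        chunks.result()
      }
    }
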
@@ -575,15 +588,44 @@ private[spark] object PythonRDD extends Logging {
     dataOut.write(bytes)
   }
 
-  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
-    import scala.collection.JavaConverters._
-    writeToFile(items.asScala, filename)
-  }
+  /**
+   * Create a socket server and a background thread to serve the data in `items`.
+   *
+   * The socket server can only accept one connection, and it closes if no
+   * connection comes in within 3 seconds.
+   *
+   * Once a connection comes in, it tries to serialize all the data in `items`
+   * and send them into this connection.
+   *
+   * The thread will terminate after all the data are sent or any exception happens.
+   */
+  private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+    val serverSocket = new ServerSocket(0, 1)
+    serverSocket.setReuseAddress(true)
+    // Close the socket if no connection in 3 seconds
+    serverSocket.setSoTimeout(3000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run() {
+        try {
+          val sock = serverSocket.accept()
+          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+          try {
+            writeIteratorToStream(items, out)
+          } finally {
+            out.close()
+          }
+        } catch {
+          case NonFatal(e) =>
+            logError(s"Error while sending iterator", e)
+        } finally {
+          serverSocket.close()
+        }
+      }
+    }.start()
 
-  def writeToFile[T](items: Iterator[T], filename: String) {
-    val file = new DataOutputStream(new FileOutputStream(filename))
-    writeIteratorToStream(items, file)
-    file.close()
+    serverSocket.getLocalPort
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
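
One detail of `serveIterator` worth making concrete: `setSoTimeout(3000)` applies to `accept()`, so if the Python side never connects, the background thread fails with a `SocketTimeoutException` (caught by the `NonFatal` handler) and the `finally` block releases the port instead of leaking it. A standalone sketch of just that behavior, with invented names and no Spark dependencies:

    import java.net.{ServerSocket, SocketTimeoutException}

    // Standalone illustration (not Spark code) of the accept-timeout
    // behavior serveIterator relies on: an un-contacted server gives up
    // after 3 seconds instead of blocking forever and leaking the socket.
    object AcceptTimeoutSketch {
      def main(args: Array[String]): Unit = {
        val serverSocket = new ServerSocket(0, 1)   // ephemeral port, backlog of 1
        serverSocket.setReuseAddress(true)
        serverSocket.setSoTimeout(3000)             // accept() fails after 3 seconds
        println(s"listening on port ${serverSocket.getLocalPort}")
        try {
          serverSocket.accept()                     // nobody connects in this demo...
          println("client connected")
        } catch {
          case _: SocketTimeoutException =>
            println("no connection in 3 seconds, giving up")  // ...so this branch runs
        } finally {
          serverSocket.close()                      // the port is released either way
        }
      }
    }

Running it prints the chosen port, waits three seconds, and then reports the timeout.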