1
1
package edu .berkeley .cs .amplab .sparkr
2
2
3
3
import java .io ._
4
+ import java .net .{ServerSocket }
4
5
import java .util .{Map => JMap }
5
6
6
7
import scala .collection .JavaConversions ._
7
8
import scala .io .Source
8
9
import scala .reflect .ClassTag
10
+ import scala .util .Try
9
11
10
12
import org .apache .spark .{SparkEnv , Partition , SparkException , TaskContext , SparkConf }
11
13
import org .apache .spark .api .java .{JavaSparkContext , JavaRDD , JavaPairRDD }
12
14
import org .apache .spark .broadcast .Broadcast
13
15
import org .apache .spark .rdd .RDD
14
16
17
+
15
18
private abstract class BaseRRDD [T : ClassTag , U : ClassTag ](
16
19
parent : RDD [T ],
17
20
numPartitions : Int ,
@@ -27,21 +30,35 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
27
30
28
31
override def compute (split : Partition , context : TaskContext ): Iterator [U ] = {
29
32
33
+ // The parent may be also an RRDD, so we should launch it first.
30
34
val parentIterator = firstParent[T ].iterator(split, context)
31
35
32
- val pb = rWorkerProcessBuilder()
36
+ // we expect two connections
37
+ val serverSocket = new ServerSocket (0 , 2 )
38
+ val listenPort = serverSocket.getLocalPort()
39
+
40
+ val pb = rWorkerProcessBuilder(listenPort)
41
+ pb.redirectErrorStream() // redirect stderr into stdout
33
42
val proc = pb.start()
43
+ val errThread = startStdoutThread(proc)
44
+
45
+ // We use two sockets to separate input and output, then it's easy to manage
46
+ // the lifecycle of them to avoid deadlock.
47
+ // TODO: optimize it to use one socket
34
48
35
- val errThread = startStderrThread(proc)
49
+ // the socket used to send out the input of task
50
+ serverSocket.setSoTimeout(10000 )
51
+ val inSocket = serverSocket.accept()
52
+ startStdinThread(inSocket.getOutputStream(), parentIterator, split.index)
36
53
37
- val tempFile = startStdinThread(proc, parentIterator, split.index)
54
+ // the socket used to receive the output of task
55
+ val outSocket = serverSocket.accept()
56
+ val inputStream = new BufferedInputStream (outSocket.getInputStream)
57
+ val dataStream = openDataStream(inputStream)
38
58
39
- // Return an iterator that read lines from the process's stdout
40
- val inputStream = new BufferedReader (new InputStreamReader (proc.getInputStream))
59
+ serverSocket.close()
41
60
42
61
try {
43
- val stdOutFileName = inputStream.readLine().trim()
44
- val dataStream = openDataStream(stdOutFileName)
45
62
46
63
return new Iterator [U ] {
47
64
def next (): U = {
@@ -57,9 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
57
74
def hasNext (): Boolean = {
58
75
val hasMore = (_nextObj != null )
59
76
if (! hasMore) {
60
- // Delete the temporary file we created as we are done reading it
61
77
dataStream.close()
62
- tempFile.delete()
63
78
}
64
79
hasMore
65
80
}
@@ -73,7 +88,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
73
88
/**
74
89
* ProcessBuilder used to launch worker R processes.
75
90
*/
76
- private def rWorkerProcessBuilder () = {
91
+ private def rWorkerProcessBuilder (port : Int ) = {
77
92
val rCommand = " Rscript"
78
93
val rOptions = " --vanilla"
79
94
val rExecScript = rLibDir + " /SparkR/worker/worker.R"
@@ -82,47 +97,42 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
82
97
// This is set by R CMD check as startup.Rs
83
98
// (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
84
99
// and confuses worker script which tries to load a non-existent file
85
- pb.environment().put(" R_TESTS" , " " );
100
+ pb.environment().put(" R_TESTS" , " " )
101
+ pb.environment().put(" SPARKR_WORKER_PORT" , port.toString)
86
102
pb
87
103
}
88
104
89
105
/**
90
106
* Start a thread to print the process's stderr to ours
91
107
*/
92
/**
 * Start a daemon thread that drains the R worker process's stdout.
 * Since the caller redirects stderr into stdout (via `redirectErrorStream`),
 * this single reader covers both streams; `BufferedStreamThread` echoes each
 * line and retains the most recent `BUFFER_SIZE` lines in a ring buffer.
 *
 * Marking the thread as a daemon ensures it cannot keep the JVM alive
 * after the task finishes.
 */
private def startStdoutThread(proc: Process): BufferedStreamThread = {
  // Number of recent output lines to retain for diagnostics.
  val BUFFER_SIZE = 100
  val readerThread =
    new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
  readerThread.setDaemon(true)
  readerThread.start()
  readerThread
}
99
115
100
116
/**
101
117
* Start a thread to write RDD data to the R process.
102
118
*/
103
119
private def startStdinThread [T ](
104
- proc : Process ,
120
+ output : OutputStream ,
105
121
iter : Iterator [T ],
106
- splitIndex : Int ) : File = {
122
+ splitIndex : Int ) = {
107
123
108
124
val env = SparkEnv .get
109
- val conf = env.conf
110
- val tempDir = RRDD .getLocalDir(conf)
111
- val tempFile = File .createTempFile(" rSpark" , " out" , new File (tempDir))
112
- val tempFileIn = File .createTempFile(" rSpark" , " in" , new File (tempDir))
113
-
114
- val tempFileName = tempFile.getAbsolutePath()
115
125
val bufferSize = System .getProperty(" spark.buffer.size" , " 65536" ).toInt
126
+ val stream = new BufferedOutputStream (output, bufferSize)
116
127
117
- // Start a thread to feed the process input from our parent's iterator
118
- new Thread (" stdin writer for R" ) {
128
+ new Thread (" writer for R" ) {
119
129
override def run () {
120
130
try {
121
131
SparkEnv .set(env)
122
- val stream = new BufferedOutputStream (new FileOutputStream (tempFileIn), bufferSize)
123
132
val printOut = new PrintStream (stream)
124
- val dataOut = new DataOutputStream (stream )
133
+ printOut.println(rLibDir )
125
134
135
+ val dataOut = new DataOutputStream (stream)
126
136
dataOut.writeInt(splitIndex)
127
137
128
138
dataOut.writeInt(func.length)
@@ -166,35 +176,21 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
166
176
printOut.println(elem)
167
177
}
168
178
}
169
-
170
- printOut.flush()
171
- dataOut.flush()
172
179
stream.flush()
173
- stream.close()
174
-
175
- // NOTE: We need to write out the temp file before writing out the
176
- // file name to stdin. Otherwise the R process could read partial state
177
- val streamStd = new BufferedOutputStream (proc.getOutputStream, bufferSize)
178
- val printOutStd = new PrintStream (streamStd)
179
- printOutStd.println(tempFileName)
180
- printOutStd.println(rLibDir)
181
- printOutStd.println(tempFileIn.getAbsolutePath())
182
- printOutStd.flush()
183
-
184
- streamStd.close()
185
180
} catch {
186
181
// TODO: We should propogate this error to the task thread
187
182
case e : Exception =>
188
183
System .err.println(" R Writer thread got an exception " + e)
189
184
e.printStackTrace()
185
+ } finally {
186
+ Try (output.close())
190
187
}
191
188
}
192
189
}.start()
193
-
194
- tempFile
195
190
}
196
191
197
- protected def openDataStream (stdOutFileName : String ): Closeable
192
+ protected def openDataStream (input : InputStream ): Closeable
193
+
198
194
protected def read (): U
199
195
}
200
196
@@ -217,8 +213,8 @@ private class PairwiseRRDD[T: ClassTag](
217
213
218
214
private var dataStream : DataInputStream = _
219
215
220
/**
 * Open the task-output stream coming back from the R worker socket,
 * wrapping it for binary reads. The stream is cached in `dataStream`
 * so the iterator can close it when exhausted, and also returned as
 * the `Closeable` handle required by the base class.
 */
override protected def openDataStream(input: InputStream) = {
  val binaryIn = new DataInputStream(input)
  dataStream = binaryIn
  binaryIn
}
224
220
@@ -261,9 +257,9 @@ private class RRDD[T: ClassTag](
261
257
broadcastVars.map(x => x.asInstanceOf [Broadcast [Object ]])) {
262
258
263
259
private var dataStream : DataInputStream = _
264
-
265
/**
 * Wrap the worker's output socket stream for binary reads.
 * The wrapped stream is stored in `dataStream` (closed by the iterator
 * once drained) and returned as the base class's `Closeable` handle.
 */
override protected def openDataStream(input: InputStream) = {
  val binaryIn = new DataInputStream(input)
  dataStream = binaryIn
  binaryIn
}
269
265
@@ -305,9 +301,8 @@ private class StringRRDD[T: ClassTag](
305
301
306
302
private var dataStream : BufferedReader = _
307
303
308
/**
 * Open the worker's output socket stream as a line-oriented reader,
 * since this RDD yields its elements as text. The reader is cached in
 * `dataStream` (closed by the iterator when exhausted) and returned as
 * the `Closeable` handle required by the base class.
 */
override protected def openDataStream(input: InputStream) = {
  val textIn = new BufferedReader(new InputStreamReader(input))
  dataStream = textIn
  textIn
}
313
308
@@ -334,6 +329,7 @@ private class BufferedStreamThread(
334
329
for (line <- Source .fromInputStream(in).getLines) {
335
330
lines(lineIdx) = line
336
331
lineIdx = (lineIdx + 1 ) % errBufferSize
332
+ // TODO: user logger
337
333
System .err.println(line)
338
334
}
339
335
}
0 commit comments