Skip to content

Commit 9aa4acf

Browse files
committed
Merge pull request apache#184 from davies/socket
[SPARKR-155] use socket in R worker
2 parents 5300766 + e776324 commit 9aa4acf

File tree

2 files changed

+55
-75
lines changed

2 files changed

+55
-75
lines changed

pkg/inst/worker/worker.R

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
11
# Worker class
22

3-
# NOTE: We use "stdin" to get the process stdin instead of the command line
4-
inputConStdin <- file("stdin", open = "rb")
3+
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
54

6-
outputFileName <- readLines(inputConStdin, n = 1)
7-
outputCon <- file(outputFileName, open="wb")
5+
inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb")
6+
outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb")
87

98
# Set libPaths to include SparkR package as loadNamespace needs this
109
# TODO: Figure out if we can avoid this by not loading any objects that require
1110
# SparkR namespace
12-
rLibDir <- readLines(inputConStdin, n = 1)
11+
rLibDir <- readLines(inputCon, n = 1)
1312
.libPaths(c(rLibDir, .libPaths()))
1413

1514
suppressPackageStartupMessages(library(SparkR))
1615

17-
inFileName <- readLines(inputConStdin, n = 1)
18-
19-
inputCon <- file(inFileName, open = "rb")
20-
2116
# read the index of the current partition inside the RDD
2217
splitIndex <- SparkR:::readInt(inputCon)
2318

@@ -31,10 +26,6 @@ isInputSerialized <- SparkR:::readInt(inputCon)
3126
# read the isOutputSerialized bit flag
3227
isOutputSerialized <- SparkR:::readInt(inputCon)
3328

34-
# Redirect stdout to stderr to prevent print statements from
35-
# interfering with outputStream
36-
sink(stderr())
37-
3829
# Include packages as required
3930
packageNames <- unserialize(SparkR:::readRaw(inputCon))
4031
for (pkg in packageNames) {
@@ -123,10 +114,3 @@ if (isOutputSerialized) {
123114

124115
close(outputCon)
125116
close(inputCon)
126-
unlink(inFileName)
127-
128-
# Restore stdout
129-
sink()
130-
131-
# Finally print the name of the output file
132-
cat(outputFileName, "\n")

pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/RRDD.scala

Lines changed: 51 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
package edu.berkeley.cs.amplab.sparkr
22

33
import java.io._
4+
import java.net.{ServerSocket}
45
import java.util.{Map => JMap}
56

67
import scala.collection.JavaConversions._
78
import scala.io.Source
89
import scala.reflect.ClassTag
10+
import scala.util.Try
911

1012
import org.apache.spark.{SparkEnv, Partition, SparkException, TaskContext, SparkConf}
1113
import org.apache.spark.api.java.{JavaSparkContext, JavaRDD, JavaPairRDD}
1214
import org.apache.spark.broadcast.Broadcast
1315
import org.apache.spark.rdd.RDD
1416

17+
1518
private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
1619
parent: RDD[T],
1720
numPartitions: Int,
@@ -27,21 +30,35 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
2730

2831
override def compute(split: Partition, context: TaskContext): Iterator[U] = {
2932

33+
// The parent may be also an RRDD, so we should launch it first.
3034
val parentIterator = firstParent[T].iterator(split, context)
3135

32-
val pb = rWorkerProcessBuilder()
36+
// we expect two connections
37+
val serverSocket = new ServerSocket(0, 2)
38+
val listenPort = serverSocket.getLocalPort()
39+
40+
val pb = rWorkerProcessBuilder(listenPort)
41+
pb.redirectErrorStream() // redirect stderr into stdout
3342
val proc = pb.start()
43+
val errThread = startStdoutThread(proc)
44+
45+
// We use two sockets to separate input and output, then it's easy to manage
46+
// the lifecycle of them to avoid deadlock.
47+
// TODO: optimize it to use one socket
3448

35-
val errThread = startStderrThread(proc)
49+
// the socket used to send out the input of task
50+
serverSocket.setSoTimeout(10000)
51+
val inSocket = serverSocket.accept()
52+
startStdinThread(inSocket.getOutputStream(), parentIterator, split.index)
3653

37-
val tempFile = startStdinThread(proc, parentIterator, split.index)
54+
// the socket used to receive the output of task
55+
val outSocket = serverSocket.accept()
56+
val inputStream = new BufferedInputStream(outSocket.getInputStream)
57+
val dataStream = openDataStream(inputStream)
3858

39-
// Return an iterator that read lines from the process's stdout
40-
val inputStream = new BufferedReader(new InputStreamReader(proc.getInputStream))
59+
serverSocket.close()
4160

4261
try {
43-
val stdOutFileName = inputStream.readLine().trim()
44-
val dataStream = openDataStream(stdOutFileName)
4562

4663
return new Iterator[U] {
4764
def next(): U = {
@@ -57,9 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
5774
def hasNext(): Boolean = {
5875
val hasMore = (_nextObj != null)
5976
if (!hasMore) {
60-
// Delete the temporary file we created as we are done reading it
6177
dataStream.close()
62-
tempFile.delete()
6378
}
6479
hasMore
6580
}
@@ -73,7 +88,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
7388
/**
7489
* ProcessBuilder used to launch worker R processes.
7590
*/
76-
private def rWorkerProcessBuilder() = {
91+
private def rWorkerProcessBuilder(port: Int) = {
7792
val rCommand = "Rscript"
7893
val rOptions = "--vanilla"
7994
val rExecScript = rLibDir + "/SparkR/worker/worker.R"
@@ -82,47 +97,42 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
8297
// This is set by R CMD check as startup.Rs
8398
// (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
8499
// and confuses worker script which tries to load a non-existent file
85-
pb.environment().put("R_TESTS", "");
100+
pb.environment().put("R_TESTS", "")
101+
pb.environment().put("SPARKR_WORKER_PORT", port.toString)
86102
pb
87103
}
88104

89105
/**
90106
* Start a thread to print the process's stderr to ours
91107
*/
92-
private def startStderrThread(proc: Process): BufferedStreamThread = {
93-
val ERR_BUFFER_SIZE = 100
94-
val errThread = new BufferedStreamThread(proc.getErrorStream, "stderr reader for R",
95-
ERR_BUFFER_SIZE)
96-
errThread.start()
97-
errThread
108+
private def startStdoutThread(proc: Process): BufferedStreamThread = {
109+
val BUFFER_SIZE = 100
110+
val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
111+
thread.setDaemon(true)
112+
thread.start()
113+
thread
98114
}
99115

100116
/**
101117
* Start a thread to write RDD data to the R process.
102118
*/
103119
private def startStdinThread[T](
104-
proc: Process,
120+
output: OutputStream,
105121
iter: Iterator[T],
106-
splitIndex: Int) : File = {
122+
splitIndex: Int) = {
107123

108124
val env = SparkEnv.get
109-
val conf = env.conf
110-
val tempDir = RRDD.getLocalDir(conf)
111-
val tempFile = File.createTempFile("rSpark", "out", new File(tempDir))
112-
val tempFileIn = File.createTempFile("rSpark", "in", new File(tempDir))
113-
114-
val tempFileName = tempFile.getAbsolutePath()
115125
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
126+
val stream = new BufferedOutputStream(output, bufferSize)
116127

117-
// Start a thread to feed the process input from our parent's iterator
118-
new Thread("stdin writer for R") {
128+
new Thread("writer for R") {
119129
override def run() {
120130
try {
121131
SparkEnv.set(env)
122-
val stream = new BufferedOutputStream(new FileOutputStream(tempFileIn), bufferSize)
123132
val printOut = new PrintStream(stream)
124-
val dataOut = new DataOutputStream(stream)
133+
printOut.println(rLibDir)
125134

135+
val dataOut = new DataOutputStream(stream)
126136
dataOut.writeInt(splitIndex)
127137

128138
dataOut.writeInt(func.length)
@@ -166,35 +176,21 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
166176
printOut.println(elem)
167177
}
168178
}
169-
170-
printOut.flush()
171-
dataOut.flush()
172179
stream.flush()
173-
stream.close()
174-
175-
// NOTE: We need to write out the temp file before writing out the
176-
// file name to stdin. Otherwise the R process could read partial state
177-
val streamStd = new BufferedOutputStream(proc.getOutputStream, bufferSize)
178-
val printOutStd = new PrintStream(streamStd)
179-
printOutStd.println(tempFileName)
180-
printOutStd.println(rLibDir)
181-
printOutStd.println(tempFileIn.getAbsolutePath())
182-
printOutStd.flush()
183-
184-
streamStd.close()
185180
} catch {
186181
// TODO: We should propagate this error to the task thread
187182
case e: Exception =>
188183
System.err.println("R Writer thread got an exception " + e)
189184
e.printStackTrace()
185+
} finally {
186+
Try(output.close())
190187
}
191188
}
192189
}.start()
193-
194-
tempFile
195190
}
196191

197-
protected def openDataStream(stdOutFileName: String): Closeable
192+
protected def openDataStream(input: InputStream): Closeable
193+
198194
protected def read(): U
199195
}
200196

@@ -217,8 +213,8 @@ private class PairwiseRRDD[T: ClassTag](
217213

218214
private var dataStream: DataInputStream = _
219215

220-
override protected def openDataStream(stdOutFileName: String) = {
221-
dataStream = new DataInputStream(new FileInputStream(stdOutFileName))
216+
override protected def openDataStream(input: InputStream) = {
217+
dataStream = new DataInputStream(input)
222218
dataStream
223219
}
224220

@@ -261,9 +257,9 @@ private class RRDD[T: ClassTag](
261257
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
262258

263259
private var dataStream: DataInputStream = _
264-
265-
override protected def openDataStream(stdOutFileName: String) = {
266-
dataStream = new DataInputStream(new FileInputStream(stdOutFileName))
260+
261+
override protected def openDataStream(input: InputStream) = {
262+
dataStream = new DataInputStream(input)
267263
dataStream
268264
}
269265

@@ -305,9 +301,8 @@ private class StringRRDD[T: ClassTag](
305301

306302
private var dataStream: BufferedReader = _
307303

308-
override protected def openDataStream(stdOutFileName: String) = {
309-
dataStream = new BufferedReader(
310-
new InputStreamReader(new FileInputStream(stdOutFileName)))
304+
override protected def openDataStream(input: InputStream) = {
305+
dataStream = new BufferedReader(new InputStreamReader(input))
311306
dataStream
312307
}
313308

@@ -334,6 +329,7 @@ private class BufferedStreamThread(
334329
for (line <- Source.fromInputStream(in).getLines) {
335330
lines(lineIdx) = line
336331
lineIdx = (lineIdx + 1) % errBufferSize
332+
// TODO: use logger
337333
System.err.println(line)
338334
}
339335
}

0 commit comments

Comments
 (0)