
Commit 179aa75

Bunch of fixes for longer running jobs
1. Increase the timeout for socket connection to wait for long jobs
2. Add some profiling information in worker.R
3. Put temp file writes before stdin writes in RRDD.scala
1 parent: 227ee42 · commit: 179aa75

3 files changed: 86 additions, 58 deletions

pkg/R/sparkRClient.R

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 # Creates a SparkR client connection object
 # if one doesn't already exist
-connectBackend <- function(hostname, port, timeout = 60) {
+connectBackend <- function(hostname, port, timeout = 6000) {
   if (exists(".sparkRcon", envir = .sparkREnv)) {
     cat("SparkRBackend client connection already exists\n")
     return(get(".sparkRcon", envir = .sparkREnv))
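The hunk ends before the body of connectBackend, but the point of the change is that this timeout is presumably passed through to the socket that talks to the SparkR backend. As a rough sketch (not code from this commit), R's socketConnection() takes a timeout in seconds, so raising the default from 60 to 6000 lets a blocking connection wait through long-running jobs instead of giving up after about a minute; host and port below are placeholders:

    # Illustrative only: socketConnection()'s timeout argument is in seconds,
    # so 6000 tolerates long waits on the backend, while 60 drops out after
    # roughly a minute. "localhost"/12345 are placeholder values.
    con <- socketConnection(host = "localhost", port = 12345,
                            blocking = TRUE, open = "wb", timeout = 6000)
    # ... exchange requests and results with the backend ...
    close(con)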

pkg/inst/worker/worker.R

Lines changed: 19 additions & 0 deletions
@@ -1,5 +1,7 @@
 # Worker class

+begin <- proc.time()[3]
+
 # NOTE: We use "stdin" to get the process stdin instead of the command line
 inputConStdin <- file("stdin", open = "rb")

@@ -65,6 +67,8 @@ numPartitions <- SparkR:::readInt(inputCon)

 isEmpty <- SparkR:::readInt(inputCon)

+metadataEnd <- proc.time()[3]
+
 if (isEmpty != 0) {

   if (numPartitions == -1) {
@@ -74,12 +78,15 @@ if (isEmpty != 0) {
     } else {
       data <- readLines(inputCon)
     }
+    dataReadEnd <- proc.time()[3]
     output <- do.call(execFunctionName, list(splitIndex, data))
+    computeEnd <- proc.time()[3]
     if (isOutputSerialized) {
       SparkR:::writeRawSerialize(outputCon, output)
     } else {
       SparkR:::writeStrings(outputCon, output)
     }
+    writeEnd <- proc.time()[3]
   } else {
     if (isInputSerialized) {
       # Now read as many characters as described in funcLen
@@ -88,6 +95,7 @@ if (isEmpty != 0) {
       data <- readLines(inputCon)
     }

+    dataReadEnd <- proc.time()[3]
     res <- new.env()

     # Step 1: hash the data to an environment
@@ -105,6 +113,8 @@ if (isEmpty != 0) {
     }
     invisible(lapply(data, hashTupleToEnvir))

+    computeEnd <- proc.time()[3]
+
     # Step 2: write out all of the environment as key-value pairs.
     for (name in ls(res)) {
       SparkR:::writeInt(outputCon, 2L)
@@ -113,6 +123,7 @@ if (isEmpty != 0) {
       length(res[[name]]$data) <- res[[name]]$counter
       SparkR:::writeRawSerialize(outputCon, res[[name]]$data)
     }
+    writeEnd <- proc.time()[3]
   }
 }

@@ -128,5 +139,13 @@ unlink(inFileName)
 # Restore stdout
 sink()

+end <- proc.time()[3]
+
+cat("stats: total ", (end-begin), "\n", file=stderr())
+cat("stats: metadata ", (metadataEnd-begin), "\n", file=stderr())
+cat("stats: input read ", (dataReadEnd-metadataEnd), "\n", file=stderr())
+cat("stats: compute ", (computeEnd-dataReadEnd), "\n", file=stderr())
+cat("stats: output write ", (writeEnd-computeEnd), "\n", file=stderr())
+
 # Finally print the name of the output file
 cat(outputFileName, "\n")
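The profiling added above follows a single pattern, shown here as a minimal standalone sketch (not part of the commit): proc.time()[3] is the elapsed wall-clock time in seconds since the R process started, so the difference of two samples times a stage, and writing to stderr keeps the stats out of the data the worker prints to stdout. The Sys.sleep() calls are stand-ins for real work:

    # Minimal sketch of the timing pattern used in worker.R above.
    begin <- proc.time()[3]
    Sys.sleep(1)                              # stands in for reading the input
    dataReadEnd <- proc.time()[3]
    Sys.sleep(2)                              # stands in for computing the partition
    computeEnd <- proc.time()[3]
    cat("stats: input read ", (dataReadEnd - begin), "\n", file = stderr())
    cat("stats: compute ", (computeEnd - dataReadEnd), "\n", file = stderr())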

pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/RRDD.scala

Lines changed: 66 additions & 57 deletions
@@ -117,68 +117,77 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for R") {
       override def run() {
-        SparkEnv.set(env)
-        val streamStd = new BufferedOutputStream(proc.getOutputStream, bufferSize)
-        val printOutStd = new PrintStream(streamStd)
-        printOutStd.println(tempFileName)
-        printOutStd.println(rLibDir)
-        printOutStd.println(tempFileIn.getAbsolutePath())
-        printOutStd.flush()
-
-        streamStd.close()
-
-        val stream = new BufferedOutputStream(new FileOutputStream(tempFileIn), bufferSize)
-        val printOut = new PrintStream(stream)
-        val dataOut = new DataOutputStream(stream)
-
-        dataOut.writeInt(splitIndex)
-
-        dataOut.writeInt(func.length)
-        dataOut.write(func, 0, func.length)
-
-        // R worker process input serialization flag
-        dataOut.writeInt(if (parentSerialized) 1 else 0)
-        // R worker process output serialization flag
-        dataOut.writeInt(if (dataSerialized) 1 else 0)
-
-        dataOut.writeInt(packageNames.length)
-        dataOut.write(packageNames, 0, packageNames.length)
-
-        dataOut.writeInt(functionDependencies.length)
-        dataOut.write(functionDependencies, 0, functionDependencies.length)
-
-        dataOut.writeInt(broadcastVars.length)
-        broadcastVars.foreach { broadcast =>
-          // TODO(shivaram): Read a Long in R to avoid this cast
-          dataOut.writeInt(broadcast.id.toInt)
-          // TODO: Pass a byte array from R to avoid this cast ?
-          val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
-          dataOut.writeInt(broadcastByteArr.length)
-          dataOut.write(broadcastByteArr, 0, broadcastByteArr.length)
-        }
-
-        dataOut.writeInt(numPartitions)
+        try {
+          SparkEnv.set(env)
+          val stream = new BufferedOutputStream(new FileOutputStream(tempFileIn), bufferSize)
+          val printOut = new PrintStream(stream)
+          val dataOut = new DataOutputStream(stream)
+
+          dataOut.writeInt(splitIndex)
+
+          dataOut.writeInt(func.length)
+          dataOut.write(func, 0, func.length)
+
+          // R worker process input serialization flag
+          dataOut.writeInt(if (parentSerialized) 1 else 0)
+          // R worker process output serialization flag
+          dataOut.writeInt(if (dataSerialized) 1 else 0)
+
+          dataOut.writeInt(packageNames.length)
+          dataOut.write(packageNames, 0, packageNames.length)
+
+          dataOut.writeInt(functionDependencies.length)
+          dataOut.write(functionDependencies, 0, functionDependencies.length)
+
+          dataOut.writeInt(broadcastVars.length)
+          broadcastVars.foreach { broadcast =>
+            // TODO(shivaram): Read a Long in R to avoid this cast
+            dataOut.writeInt(broadcast.id.toInt)
+            // TODO: Pass a byte array from R to avoid this cast ?
+            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
+            dataOut.writeInt(broadcastByteArr.length)
+            dataOut.write(broadcastByteArr, 0, broadcastByteArr.length)
+          }

-        if (!iter.hasNext) {
-          dataOut.writeInt(0)
-        } else {
-          dataOut.writeInt(1)
-        }
+          dataOut.writeInt(numPartitions)

-        for (elem <- iter) {
-          if (parentSerialized) {
-            val elemArr = elem.asInstanceOf[Array[Byte]]
-            dataOut.writeInt(elemArr.length)
-            dataOut.write(elemArr, 0, elemArr.length)
+          if (!iter.hasNext) {
+            dataOut.writeInt(0)
           } else {
-            printOut.println(elem)
+            dataOut.writeInt(1)
+          }
+
+          for (elem <- iter) {
+            if (parentSerialized) {
+              val elemArr = elem.asInstanceOf[Array[Byte]]
+              dataOut.writeInt(elemArr.length)
+              dataOut.write(elemArr, 0, elemArr.length)
+            } else {
+              printOut.println(elem)
+            }
           }
-        }

-        printOut.flush()
-        dataOut.flush()
-        stream.flush()
-        stream.close()
+          printOut.flush()
+          dataOut.flush()
+          stream.flush()
+          stream.close()
+
+          // NOTE: We need to write out the temp file before writing out the
+          // file name to stdin. Otherwise the R process could read partial state
+          val streamStd = new BufferedOutputStream(proc.getOutputStream, bufferSize)
+          val printOutStd = new PrintStream(streamStd)
+          printOutStd.println(tempFileName)
+          printOutStd.println(rLibDir)
+          printOutStd.println(tempFileIn.getAbsolutePath())
+          printOutStd.flush()
+
+          streamStd.close()
+        } catch {
+          // TODO: We should propogate this error to the task thread
+          case e: Exception =>
+            System.err.println("R Writer thread got an exception " + e)
+            e.printStackTrace()
+        }
       }
     }.start()

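The reordering above only matters because of how the R worker consumes these writes. A hypothetical reader-side sketch of that ordering in R (the variable names are illustrative, not lifted from worker.R): the worker blocks on stdin until a file name arrives and then immediately opens and reads that file, so the writer thread must finish writing and closing the temp file before sending its path, or the worker can observe a partially written file:

    # Hypothetical reader-side sketch; names are illustrative.
    stdinCon <- file("stdin", open = "rb")
    inFileName <- readLines(stdinCon, n = 1)   # arrives only after the temp file is complete
    inputCon <- file(inFileName, open = "rb")
    splitIndex <- SparkR:::readInt(inputCon)   # safe to read: the writer closed the file first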
