@@ -117,68 +117,77 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
117
117
// Start a thread to feed the process input from our parent's iterator
118
118
new Thread (" stdin writer for R" ) {
119
119
override def run () {
120
- SparkEnv .set(env)
121
- val streamStd = new BufferedOutputStream (proc.getOutputStream, bufferSize)
122
- val printOutStd = new PrintStream (streamStd)
123
- printOutStd.println(tempFileName)
124
- printOutStd.println(rLibDir)
125
- printOutStd.println(tempFileIn.getAbsolutePath())
126
- printOutStd.flush()
127
-
128
- streamStd.close()
129
-
130
- val stream = new BufferedOutputStream (new FileOutputStream (tempFileIn), bufferSize)
131
- val printOut = new PrintStream (stream)
132
- val dataOut = new DataOutputStream (stream)
133
-
134
- dataOut.writeInt(splitIndex)
135
-
136
- dataOut.writeInt(func.length)
137
- dataOut.write(func, 0 , func.length)
138
-
139
- // R worker process input serialization flag
140
- dataOut.writeInt(if (parentSerialized) 1 else 0 )
141
- // R worker process output serialization flag
142
- dataOut.writeInt(if (dataSerialized) 1 else 0 )
143
-
144
- dataOut.writeInt(packageNames.length)
145
- dataOut.write(packageNames, 0 , packageNames.length)
146
-
147
- dataOut.writeInt(functionDependencies.length)
148
- dataOut.write(functionDependencies, 0 , functionDependencies.length)
149
-
150
- dataOut.writeInt(broadcastVars.length)
151
- broadcastVars.foreach { broadcast =>
152
- // TODO(shivaram): Read a Long in R to avoid this cast
153
- dataOut.writeInt(broadcast.id.toInt)
154
- // TODO: Pass a byte array from R to avoid this cast ?
155
- val broadcastByteArr = broadcast.value.asInstanceOf [Array [Byte ]]
156
- dataOut.writeInt(broadcastByteArr.length)
157
- dataOut.write(broadcastByteArr, 0 , broadcastByteArr.length)
158
- }
159
-
160
- dataOut.writeInt(numPartitions)
120
+ try {
121
+ SparkEnv .set(env)
122
+ val stream = new BufferedOutputStream (new FileOutputStream (tempFileIn), bufferSize)
123
+ val printOut = new PrintStream (stream)
124
+ val dataOut = new DataOutputStream (stream)
125
+
126
+ dataOut.writeInt(splitIndex)
127
+
128
+ dataOut.writeInt(func.length)
129
+ dataOut.write(func, 0 , func.length)
130
+
131
+ // R worker process input serialization flag
132
+ dataOut.writeInt(if (parentSerialized) 1 else 0 )
133
+ // R worker process output serialization flag
134
+ dataOut.writeInt(if (dataSerialized) 1 else 0 )
135
+
136
+ dataOut.writeInt(packageNames.length)
137
+ dataOut.write(packageNames, 0 , packageNames.length)
138
+
139
+ dataOut.writeInt(functionDependencies.length)
140
+ dataOut.write(functionDependencies, 0 , functionDependencies.length)
141
+
142
+ dataOut.writeInt(broadcastVars.length)
143
+ broadcastVars.foreach { broadcast =>
144
+ // TODO(shivaram): Read a Long in R to avoid this cast
145
+ dataOut.writeInt(broadcast.id.toInt)
146
+ // TODO: Pass a byte array from R to avoid this cast ?
147
+ val broadcastByteArr = broadcast.value.asInstanceOf [Array [Byte ]]
148
+ dataOut.writeInt(broadcastByteArr.length)
149
+ dataOut.write(broadcastByteArr, 0 , broadcastByteArr.length)
150
+ }
161
151
162
- if (! iter.hasNext) {
163
- dataOut.writeInt(0 )
164
- } else {
165
- dataOut.writeInt(1 )
166
- }
152
+ dataOut.writeInt(numPartitions)
167
153
168
- for (elem <- iter) {
169
- if (parentSerialized) {
170
- val elemArr = elem.asInstanceOf [Array [Byte ]]
171
- dataOut.writeInt(elemArr.length)
172
- dataOut.write(elemArr, 0 , elemArr.length)
154
+ if (! iter.hasNext) {
155
+ dataOut.writeInt(0 )
173
156
} else {
174
- printOut.println(elem)
157
+ dataOut.writeInt(1 )
158
+ }
159
+
160
+ for (elem <- iter) {
161
+ if (parentSerialized) {
162
+ val elemArr = elem.asInstanceOf [Array [Byte ]]
163
+ dataOut.writeInt(elemArr.length)
164
+ dataOut.write(elemArr, 0 , elemArr.length)
165
+ } else {
166
+ printOut.println(elem)
167
+ }
175
168
}
176
- }
177
169
178
- printOut.flush()
179
- dataOut.flush()
180
- stream.flush()
181
- stream.close()
170
+ printOut.flush()
171
+ dataOut.flush()
172
+ stream.flush()
173
+ stream.close()
174
+
175
+ // NOTE: We need to write out the temp file before writing out the
176
+ // file name to stdin. Otherwise the R process could read partial state
177
+ val streamStd = new BufferedOutputStream (proc.getOutputStream, bufferSize)
178
+ val printOutStd = new PrintStream (streamStd)
179
+ printOutStd.println(tempFileName)
180
+ printOutStd.println(rLibDir)
181
+ printOutStd.println(tempFileIn.getAbsolutePath())
182
+ printOutStd.flush()
183
+
184
+ streamStd.close()
185
+ } catch {
186
+ // TODO: We should propogate this error to the task thread
187
+ case e : Exception =>
188
+ System .err.println(" R Writer thread got an exception " + e)
189
+ e.printStackTrace()
190
+ }
182
191
}
183
192
}.start()
184
193
0 commit comments