@@ -42,10 +42,15 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
42
42
rLibDir : String ,
43
43
broadcastVars : Array [Broadcast [Object ]])
44
44
extends RDD [U ](parent) with Logging {
45
+ protected var dataStream : DataInputStream = _
46
+ private var bootTime : Double = _
45
47
override def getPartitions : Array [Partition ] = parent.partitions
46
48
47
49
override def compute (partition : Partition , context : TaskContext ): Iterator [U ] = {
48
50
51
+ // Timing start
52
+ bootTime = System .currentTimeMillis / 1000.0
53
+
49
54
// The parent may be also an RRDD, so we should launch it first.
50
55
val parentIterator = firstParent[T ].iterator(partition, context)
51
56
@@ -69,7 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
69
74
// the socket used to receive the output of task
70
75
val outSocket = serverSocket.accept()
71
76
val inputStream = new BufferedInputStream (outSocket.getInputStream)
72
- val dataStream = openDataStream (inputStream)
77
+ dataStream = new DataInputStream (inputStream)
73
78
serverSocket.close()
74
79
75
80
try {
@@ -155,6 +160,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
155
160
} else if (deserializer == SerializationFormats .ROW ) {
156
161
dataOut.write(elem.asInstanceOf [Array [Byte ]])
157
162
} else if (deserializer == SerializationFormats .STRING ) {
163
+ // write string(for StringRRDD)
158
164
printOut.println(elem)
159
165
}
160
166
}
@@ -180,9 +186,41 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
180
186
}.start()
181
187
}
182
188
183
- protected def openDataStream ( input : InputStream ): Closeable
189
+ protected def readData ( length : Int ): U
184
190
185
- protected def read (): U
191
+ protected def read (): U = {
192
+ try {
193
+ val length = dataStream.readInt()
194
+
195
+ length match {
196
+ case SpecialLengths .TIMING_DATA =>
197
+ // Timing data from R worker
198
+ val boot = dataStream.readDouble - bootTime
199
+ val init = dataStream.readDouble
200
+ val broadcast = dataStream.readDouble
201
+ val input = dataStream.readDouble
202
+ val compute = dataStream.readDouble
203
+ val output = dataStream.readDouble
204
+ logInfo(
205
+ (" Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
206
+ " read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
207
+ " total = %.3f s" ).format(
208
+ boot,
209
+ init,
210
+ broadcast,
211
+ input,
212
+ compute,
213
+ output,
214
+ boot + init + broadcast + input + compute + output))
215
+ read()
216
+ case length if length >= 0 =>
217
+ readData(length)
218
+ }
219
+ } catch {
220
+ case eof : EOFException =>
221
+ throw new SparkException (" R worker exited unexpectedly (cranshed)" , eof)
222
+ }
223
+ }
186
224
}
187
225
188
226
/**
@@ -202,31 +240,16 @@ private class PairwiseRRDD[T: ClassTag](
202
240
SerializationFormats .BYTE , packageNames, rLibDir,
203
241
broadcastVars.map(x => x.asInstanceOf [Broadcast [Object ]])) {
204
242
205
- private var dataStream : DataInputStream = _
206
-
207
- override protected def openDataStream (input : InputStream ): Closeable = {
208
- dataStream = new DataInputStream (input)
209
- dataStream
210
- }
211
-
212
- override protected def read (): (Int , Array [Byte ]) = {
213
- try {
214
- val length = dataStream.readInt()
215
-
216
- length match {
217
- case length if length == 2 =>
218
- val hashedKey = dataStream.readInt()
219
- val contentPairsLength = dataStream.readInt()
220
- val contentPairs = new Array [Byte ](contentPairsLength)
221
- dataStream.readFully(contentPairs)
222
- (hashedKey, contentPairs)
223
- case _ => null // End of input
224
- }
225
- } catch {
226
- case eof : EOFException => {
227
- throw new SparkException (" R worker exited unexpectedly (crashed)" , eof)
228
- }
229
- }
243
+ override protected def readData (length : Int ): (Int , Array [Byte ]) = {
244
+ length match {
245
+ case length if length == 2 =>
246
+ val hashedKey = dataStream.readInt()
247
+ val contentPairsLength = dataStream.readInt()
248
+ val contentPairs = new Array [Byte ](contentPairsLength)
249
+ dataStream.readFully(contentPairs)
250
+ (hashedKey, contentPairs)
251
+ case _ => null
252
+ }
230
253
}
231
254
232
255
lazy val asJavaPairRDD : JavaPairRDD [Int , Array [Byte ]] = JavaPairRDD .fromRDD(this )
@@ -247,28 +270,13 @@ private class RRDD[T: ClassTag](
247
270
parent, - 1 , func, deserializer, serializer, packageNames, rLibDir,
248
271
broadcastVars.map(x => x.asInstanceOf [Broadcast [Object ]])) {
249
272
250
- private var dataStream : DataInputStream = _
251
-
252
- override protected def openDataStream (input : InputStream ): Closeable = {
253
- dataStream = new DataInputStream (input)
254
- dataStream
255
- }
256
-
257
- override protected def read (): Array [Byte ] = {
258
- try {
259
- val length = dataStream.readInt()
260
-
261
- length match {
262
- case length if length > 0 =>
263
- val obj = new Array [Byte ](length)
264
- dataStream.readFully(obj, 0 , length)
265
- obj
266
- case _ => null
267
- }
268
- } catch {
269
- case eof : EOFException => {
270
- throw new SparkException (" R worker exited unexpectedly (crashed)" , eof)
271
- }
273
+ override protected def readData (length : Int ): Array [Byte ] = {
274
+ length match {
275
+ case length if length > 0 =>
276
+ val obj = new Array [Byte ](length)
277
+ dataStream.readFully(obj)
278
+ obj
279
+ case _ => null
272
280
}
273
281
}
274
282
@@ -289,26 +297,21 @@ private class StringRRDD[T: ClassTag](
289
297
parent, - 1 , func, deserializer, SerializationFormats .STRING , packageNames, rLibDir,
290
298
broadcastVars.map(x => x.asInstanceOf [Broadcast [Object ]])) {
291
299
292
- private var dataStream : BufferedReader = _
293
-
294
- override protected def openDataStream (input : InputStream ): Closeable = {
295
- dataStream = new BufferedReader (new InputStreamReader (input))
296
- dataStream
297
- }
298
-
299
- override protected def read (): String = {
300
- try {
301
- dataStream.readLine()
302
- } catch {
303
- case e : IOException => {
304
- throw new SparkException (" R worker exited unexpectedly (crashed)" , e)
305
- }
300
+ override protected def readData (length : Int ): String = {
301
+ length match {
302
+ case length if length > 0 =>
303
+ SerDe .readStringBytes(dataStream, length)
304
+ case _ => null
306
305
}
307
306
}
308
307
309
308
lazy val asJavaRDD : JavaRDD [String ] = JavaRDD .fromRDD(this )
310
309
}
311
310
311
+ private object SpecialLengths {
312
+ val TIMING_DATA = - 1
313
+ }
314
+
312
315
private [r] class BufferedStreamThread (
313
316
in : InputStream ,
314
317
name : String ,
0 commit comments