@@ -17,21 +17,50 @@
 
 package org.apache.spark.sql
 
+import java.io.ByteArrayOutputStream
+import java.nio.channels.Channels
+
 import scala.collection.JavaConverters._
-import scala.language.implicitConversions
 
 import io.netty.buffer.ArrowBuf
 import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.BaseValueVector.BaseMutator
+import org.apache.arrow.vector.file.ArrowWriter
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
 import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
-object Arrow {
+/**
+ * Intermediate data structure returned from Arrow conversions
+ */
+private[sql] abstract class ArrowPayload extends Iterator[ArrowRecordBatch]
+
+/**
+ * Class that wraps an Arrow RootAllocator used in conversion
+ */
+private[sql] class ArrowConverters {
+  private val _allocator = new RootAllocator(Long.MaxValue)
+
+  private[sql] def allocator: RootAllocator = _allocator
+
+  private class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload {
+    private val iter = batches.iterator
+
+    override def next(): ArrowRecordBatch = iter.next()
+    override def hasNext: Boolean = iter.hasNext
+  }
+
+  def internalRowsToPayload(rows: Array[InternalRow], schema: StructType): ArrowPayload = {
+    val batch = ArrowConverters.internalRowsToArrowRecordBatch(rows, schema, allocator)
+    new ArrowStaticPayload(batch)
+  }
+}
+
+private[sql] object ArrowConverters {
 
   /**
    * Map a Spark Dataset type to ArrowType.
@@ -49,7 +78,7 @@ object Arrow {
       case BinaryType => ArrowType.Binary.INSTANCE
       case DateType => ArrowType.Date.INSTANCE
       case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND)
-      case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
+      case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
     }
   }
 
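A note on the `ArrowPayload` abstraction added in the first hunk: it is simply an `Iterator[ArrowRecordBatch]`, so callers drain it like any Scala iterator. A rough consumption sketch, not part of this diff, assuming the code lives in `org.apache.spark.sql` (the API is `private[sql]`) and that `rows` and `schema` come from the caller:

```scala
import org.apache.arrow.vector.schema.ArrowRecordBatch
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

def consumeBatches(rows: Array[InternalRow], schema: StructType): Unit = {
  val converters = new ArrowConverters
  // Wraps the rows in a single-batch ArrowStaticPayload
  val payload: ArrowPayload = converters.internalRowsToPayload(rows, schema)
  while (payload.hasNext) {
    val batch: ArrowRecordBatch = payload.next()
    // ... hand the batch to a consumer ...
    batch.close() // assumption: the consumer here owns and releases the batch
  }
}
```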
@@ -109,6 +138,25 @@ object Arrow {
     }
     new Schema(arrowFields.toList.asJava)
   }
+
+  /**
+   * Write an ArrowPayload to a byte array
+   */
+  private[sql] def payloadToByteArray(payload: ArrowPayload, schema: StructType): Array[Byte] = {
+    val arrowSchema = ArrowConverters.schemaToArrowSchema(schema)
+    val out = new ByteArrayOutputStream()
+    val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema)
+    try {
+      payload.foreach(writer.writeRecordBatch)
+    } catch {
+      case e: Exception =>
+        throw e
+    } finally {
+      writer.close()
+      payload.foreach(_.close())
+    }
+    out.toByteArray
+  }
 }
 
 private[sql] trait ColumnWriter {
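Together with `internalRowsToPayload` above, `payloadToByteArray` completes the in-memory serialization path; note that the `finally` block closes the writer and then drains any batches left unconsumed on the iterator. A sketch of the intended round trip (same assumptions as the previous snippet):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

def toArrowBytes(rows: Array[InternalRow], schema: StructType): Array[Byte] = {
  val converters = new ArrowConverters
  val payload = converters.internalRowsToPayload(rows, schema)
  // Writes the Arrow schema plus every record batch to an in-memory
  // stream in Arrow file format and returns the raw bytes.
  ArrowConverters.payloadToByteArray(payload, schema)
}
```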
@@ -255,7 +303,7 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
 private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
     extends PrimitiveColumnWriter(allocator) {
   override protected val valueVector: NullableVarBinaryVector
-    = new NullableVarBinaryVector("UTF8StringValue", allocator)
+    = new NullableVarBinaryVector("BinaryValue", allocator)
   override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
 
   override def setNull(): Unit = valueMutator.setNull(count)
@@ -273,6 +321,7 @@ private[sql] class DateColumnWriter(allocator: BaseAllocator)
 
   override protected def setNull(): Unit = valueMutator.setNull(count)
   override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    // TODO: comment on diff btw value representations of date/timestamp
     valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000)
   }
 }
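The arithmetic in `DateColumnWriter.setValue` follows from the internal representations: Spark stores `DateType` as an `Int` counting days since the Unix epoch, while this writer emits milliseconds. A worked example (values illustrative):

```scala
// Spark internal DateType value: days since 1970-01-01
val days: Int = 17000 // 2016-07-18

// Convert to milliseconds; .toLong first so the multiply cannot overflow Int
val millis: Long = days.toLong * 24 * 3600 * 1000 // 1468800000000L
```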
@@ -286,6 +335,7 @@ private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
   override protected def setNull(): Unit = valueMutator.setNull(count)
 
   override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    // TODO: use microsecond timestamp when ARROW-477 is resolved
    valueMutator.setSafe(count, row.getLong(ordinal) / 1000)
   }
 }
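Likewise for `TimeStampColumnWriter`: Spark's `TimestampType` is a `Long` of microseconds since the epoch, but the vector is declared with `TimeUnit.MILLISECOND` in the type mapping earlier in this diff, so the writer divides by 1000 and silently drops sub-millisecond precision (the TODO defers microsecond support to ARROW-477). For example:

```scala
// Spark internal TimestampType value: microseconds since the epoch
val micros: Long = 1468800000123456L

// Truncate to milliseconds for the MILLISECOND-precision Arrow vector
val millis: Long = micros / 1000 // 1468800000123L; the trailing 456 µs are lost
```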