
Commit 102cf3f

renamed to ArrowConverters

defined ArrowPayload and encapsulated Arrow classes in ArrowConverters
addressed some minor comments in code review
closes apache#21
1 parent fadf588 commit 102cf3f


3 files changed: +77 additions, -44 deletions


sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala renamed to sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala

Lines changed: 54 additions & 4 deletions
@@ -17,21 +17,50 @@
 
 package org.apache.spark.sql
 
+import java.io.ByteArrayOutputStream
+import java.nio.channels.Channels
+
 import scala.collection.JavaConverters._
-import scala.language.implicitConversions
 
 import io.netty.buffer.ArrowBuf
 import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.BaseValueVector.BaseMutator
+import org.apache.arrow.vector.file.ArrowWriter
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
 import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
-object Arrow {
+/**
+ * Intermediate data structure returned from Arrow conversions
+ */
+private[sql] abstract class ArrowPayload extends Iterator[ArrowRecordBatch]
+
+/**
+ * Class that wraps an Arrow RootAllocator used in conversion
+ */
+private[sql] class ArrowConverters {
+  private val _allocator = new RootAllocator(Long.MaxValue)
+
+  private[sql] def allocator: RootAllocator = _allocator
+
+  private class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload {
+    private val iter = batches.iterator
+
+    override def next(): ArrowRecordBatch = iter.next()
+    override def hasNext: Boolean = iter.hasNext
+  }
+
+  def internalRowsToPayload(rows: Array[InternalRow], schema: StructType): ArrowPayload = {
+    val batch = ArrowConverters.internalRowsToArrowRecordBatch(rows, schema, allocator)
+    new ArrowStaticPayload(batch)
+  }
+}
+
+private[sql] object ArrowConverters {
 
   /**
    * Map a Spark Dataset type to ArrowType.
@@ -49,7 +78,7 @@ object Arrow {
       case BinaryType => ArrowType.Binary.INSTANCE
       case DateType => ArrowType.Date.INSTANCE
       case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND)
-      case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
+      case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
     }
   }
 
@@ -109,6 +138,25 @@ object Arrow {
     }
     new Schema(arrowFields.toList.asJava)
   }
+
+  /**
+   * Write an ArrowPayload to a byte array
+   */
+  private[sql] def payloadToByteArray(payload: ArrowPayload, schema: StructType): Array[Byte] = {
+    val arrowSchema = ArrowConverters.schemaToArrowSchema(schema)
+    val out = new ByteArrayOutputStream()
+    val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema)
+    try {
+      payload.foreach(writer.writeRecordBatch)
+    } catch {
+      case e: Exception =>
+        throw e
+    } finally {
+      writer.close()
+      payload.foreach(_.close())
+    }
+    out.toByteArray
+  }
 }
 
 private[sql] trait ColumnWriter {
@@ -255,7 +303,7 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
 private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
     extends PrimitiveColumnWriter(allocator) {
   override protected val valueVector: NullableVarBinaryVector
-    = new NullableVarBinaryVector("UTF8StringValue", allocator)
+    = new NullableVarBinaryVector("BinaryValue", allocator)
   override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
 
   override def setNull(): Unit = valueMutator.setNull(count)
@@ -273,6 +321,7 @@ private[sql] class DateColumnWriter(allocator: BaseAllocator)
 
   override protected def setNull(): Unit = valueMutator.setNull(count)
   override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    // TODO: comment on diff btw value representations of date/timestamp
    valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000)
   }
 }
@@ -286,6 +335,7 @@ private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
   override protected def setNull(): Unit = valueMutator.setNull(count)
 
   override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    // TODO: use microsecond timestamp when ARROW-477 is resolved
    valueMutator.setSafe(count, row.getLong(ordinal) / 1000)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 7 additions & 24 deletions
@@ -17,17 +17,13 @@
 
 package org.apache.spark.sql
 
-import java.io.{ByteArrayOutputStream, CharArrayWriter}
-import java.nio.channels.Channels
+import java.io.CharArrayWriter
 
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
-import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.vector.file.ArrowWriter
-import org.apache.arrow.vector.schema.ArrowRecordBatch
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
@@ -2375,14 +2371,12 @@ class Dataset[T] private[sql](
   * @since 2.2.0
   */
  @DeveloperApi
-  def collectAsArrow(rootAllocator: Option[RootAllocator] = None): ArrowRecordBatch = {
-    val allocator = rootAllocator.getOrElse(new RootAllocator(Long.MaxValue))
+  def collectAsArrow(converter: Option[ArrowConverters] = None): ArrowPayload = {
+    val cnvtr = converter.getOrElse(new ArrowConverters)
    withNewExecutionId {
      try {
        val collectedRows = queryExecution.executedPlan.executeCollect()
-        val recordBatch = Arrow.internalRowsToArrowRecordBatch(
-          collectedRows, this.schema, allocator)
-        recordBatch
+        cnvtr.internalRowsToPayload(collectedRows, this.schema)
      } catch {
        case e: Exception =>
          throw e
@@ -2763,22 +2757,11 @@ class Dataset[T] private[sql](
   * Collect a Dataset as an ArrowRecordBatch, and serve the ArrowRecordBatch to PySpark.
   */
  private[sql] def collectAsArrowToPython(): Int = {
-    val recordBatch = collectAsArrow()
-    val arrowSchema = Arrow.schemaToArrowSchema(this.schema)
-    val out = new ByteArrayOutputStream()
-    try {
-      val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema)
-      writer.writeRecordBatch(recordBatch)
-      writer.close()
-    } catch {
-      case e: Exception =>
-        throw e
-    } finally {
-      recordBatch.close()
-    }
+    val payload = collectAsArrow()
+    val payloadBytes = ArrowConverters.payloadToByteArray(payload, this.schema)
 
    withNewExecutionId {
-      PythonRDD.serveIterator(Iterator(out.toByteArray), "serve-Arrow")
+      PythonRDD.serveIterator(Iterator(payloadBytes), "serve-Arrow")
    }
  }
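
Note: collectAsArrow now takes an Option[ArrowConverters] instead of an Option[RootAllocator], so a caller that keeps working with the returned batches can pass its own converter and consume them against the same allocator, which is how the test suite below uses it. A minimal, hypothetical sketch of that pattern (the object name, SparkSession setup, and data are assumptions, and the code must sit under org.apache.spark.sql because ArrowConverters and ArrowPayload are private[sql]):

package org.apache.spark.sql

// Hypothetical driver-side usage, sketched against this patch.
object CollectAsArrowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("arrow-sketch").getOrCreate()

    val converter = new ArrowConverters
    val df = spark.range(10).toDF("i")

    // Collect into an ArrowPayload whose batches were allocated from converter.allocator,
    // so the caller can build consuming vectors (e.g. a VectorSchemaRoot) on that allocator.
    val payload: ArrowPayload = df.collectAsArrow(Some(converter))
    payload.foreach { batch =>
      // ... read the batch here ..., then release its buffers.
      batch.close()
    }

    spark.stop()
  }
}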

sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala renamed to sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala

Lines changed: 16 additions & 16 deletions
@@ -19,15 +19,13 @@ package org.apache.spark.sql
 import java.io.File
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
-import java.util.{Locale, TimeZone}
+import java.util.Locale
 
-import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
 import org.apache.arrow.vector.file.json.JsonFileReader
 import org.apache.arrow.vector.util.Validator
 
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.unsafe.types.CalendarInterval
 
 
 // NOTE - nullable type can be declared as Option[*] or java.lang.*
@@ -38,18 +36,19 @@ private[sql] case class FloatData(i: Int, a_f: Float, b_f: Option[Float])
 private[sql] case class DoubleData(i: Int, a_d: Double, b_d: Option[Double])
 
 
-class ArrowSuite extends SharedSQLContext {
+class ArrowConvertersSuite extends SharedSQLContext {
   import testImplicits._
 
   private def testFile(fileName: String): String = {
     Thread.currentThread().getContextClassLoader.getResource(fileName).getFile
   }
 
   test("collect to arrow record batch") {
-    val arrowRecordBatch = indexData.collectAsArrow()
-    assert(arrowRecordBatch.getLength > 0)
-    assert(arrowRecordBatch.getNodes.size() > 0)
-    arrowRecordBatch.close()
+    val arrowPayload = indexData.collectAsArrow()
+    assert(arrowPayload.nonEmpty)
+    arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getLength > 0))
+    arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getNodes.size() > 0))
+    arrowPayload.foreach(arrowRecordBatch => arrowRecordBatch.close())
   }
 
   test("standard type conversion") {
@@ -124,8 +123,9 @@ class ArrowSuite extends SharedSQLContext {
   }
 
   test("empty frame collect") {
-    val emptyBatch = spark.emptyDataFrame.collectAsArrow()
-    assert(emptyBatch.getLength == 0)
+    val arrowPayload = spark.emptyDataFrame.collectAsArrow()
+    assert(arrowPayload.nonEmpty)
+    arrowPayload.foreach(emptyBatch => assert(emptyBatch.getLength == 0))
   }
 
   test("unsupported types") {
@@ -163,17 +163,17 @@ class ArrowSuite extends SharedSQLContext {
   private def collectAndValidate(df: DataFrame, arrowFile: String) {
     val jsonFilePath = testFile(arrowFile)
 
-    val allocator = new RootAllocator(Integer.MAX_VALUE)
-    val jsonReader = new JsonFileReader(new File(jsonFilePath), allocator)
+    val converter = new ArrowConverters
+    val jsonReader = new JsonFileReader(new File(jsonFilePath), converter.allocator)
 
-    val arrowSchema = Arrow.schemaToArrowSchema(df.schema)
+    val arrowSchema = ArrowConverters.schemaToArrowSchema(df.schema)
     val jsonSchema = jsonReader.start()
     Validator.compareSchemas(arrowSchema, jsonSchema)
 
-    val arrowRecordBatch = df.collectAsArrow(Some(allocator))
-    val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator)
+    val arrowPayload = df.collectAsArrow(Some(converter))
+    val arrowRoot = new VectorSchemaRoot(arrowSchema, converter.allocator)
     val vectorLoader = new VectorLoader(arrowRoot)
-    vectorLoader.load(arrowRecordBatch)
+    arrowPayload.foreach(vectorLoader.load)
     val jsonRoot = jsonReader.read()
 
     Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)
