
Commit fadf588

Changed RootAllocator param to Option in collectAsArrow; added more tests and cleanup. Closes apache#20.
1 parent: 08ef4c4

5 files changed: +131 additions, -31 deletions


sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 2 additions & 2 deletions

@@ -2375,8 +2375,8 @@ class Dataset[T] private[sql](
    * @since 2.2.0
    */
   @DeveloperApi
-  def collectAsArrow(
-      allocator: RootAllocator = new RootAllocator(Long.MaxValue)): ArrowRecordBatch = {
+  def collectAsArrow(rootAllocator: Option[RootAllocator] = None): ArrowRecordBatch = {
+    val allocator = rootAllocator.getOrElse(new RootAllocator(Long.MaxValue))
     withNewExecutionId {
       try {
         val collectedRows = queryExecution.executedPlan.executeCollect()
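
For reference, a minimal caller-side sketch of the new signature (hypothetical DataFrame and variable names, not part of the commit): with the parameter now an Option defaulting to None, callers can omit the allocator entirely, while tests can pass their own RootAllocator so the returned batch is tracked by an allocator they control.

  // Sketch only; assumes a SparkSession `spark` and the Arrow allocator import below.
  import org.apache.arrow.memory.RootAllocator

  val df = spark.range(10).toDF("id")

  // Default: collectAsArrow falls back to new RootAllocator(Long.MaxValue) internally.
  val batch1 = df.collectAsArrow()

  // Explicit: reuse a caller-managed allocator, as the updated ArrowSuite does.
  val allocator = new RootAllocator(Long.MaxValue)
  val batch2 = df.collectAsArrow(Some(allocator))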
sql/core/src/test/resources/test-data/arrow/allNulls-ints.json

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+{
+  "schema": {
+    "fields": [
+      {
+        "name": "a",
+        "type": {"name": "int", "isSigned": true, "bitWidth": 32},
+        "nullable": true,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 32}
+          ]
+        }
+      }
+    ]
+  },
+
+  "batches": [
+    {
+      "count": 4,
+      "columns": [
+        {
+          "name": "a",
+          "count": 4,
+          "VALIDITY": [0, 0, 0, 0],
+          "DATA": [0, 0, 0, 0]
+        }
+      ]
+    }
+  ]
+}
sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+{
+  "schema": {
+    "fields": [
+      {
+        "name": "i",
+        "type": {"name": "int", "isSigned": true, "bitWidth": 32},
+        "nullable": false,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 8}
+          ]
+        }
+      },
+      {
+        "name": "NaN_f",
+        "type": {"name": "floatingpoint", "precision": "SINGLE"},
+        "nullable": false,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 32}
+          ]
+        }
+      },
+      {
+        "name": "NaN_d",
+        "type": {"name": "floatingpoint", "precision": "DOUBLE"},
+        "nullable": false,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 32}
+          ]
+        }
+      }
+    ]
+  },
+
+  "batches": [
+    {
+      "count": 2,
+      "columns": [
+        {
+          "name": "i",
+          "count": 2,
+          "VALIDITY": [1, 1],
+          "DATA": [1, 2]
+        },
+        {
+          "name": "NaN_f",
+          "count": 2,
+          "VALIDITY": [1, 1],
+          "DATA": [1.2, "NaN"]
+        },
+        {
+          "name": "NaN_d",
+          "count": 2,
+          "VALIDITY": [1, 1],
+          "DATA": ["NaN", 1.23]
+        }
+      ]
+    }
+  ]
+}

sql/core/src/test/resources/test-data/arrow/timestampData.json

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
   "schema": {
     "fields": [
       {
-        "name": "a_timestamp",
+        "name": "c_timestamp",
         "type": {"name": "timestamp", "unit": "MILLISECOND"},
         "nullable": true,
         "children": [],
@@ -21,7 +21,7 @@
       "count": 2,
       "columns": [
        {
-          "name": "a_timestamp",
+          "name": "c_timestamp",
           "count": 2,
           "VALIDITY": [1, 1],
           "DATA": [1365383415567, 1365426610789]

sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala

Lines changed: 27 additions & 27 deletions

@@ -27,6 +27,7 @@ import org.apache.arrow.vector.file.json.JsonFileReader
 import org.apache.arrow.vector.util.Validator
 
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.unsafe.types.CalendarInterval
 
 
 // NOTE - nullable type can be declared as Option[*] or java.lang.*
@@ -88,25 +89,16 @@ class ArrowSuite extends SharedSQLContext
   test("string type conversion") {
     collectAndValidate(upperCaseData, "test-data/arrow/uppercase-strings.json")
     collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json")
+    val nullStringsColOnly = nullStrings.select(nullStrings.columns(1))
+    collectAndValidate(nullStringsColOnly, "test-data/arrow/null-strings.json")
   }
 
   ignore("date conversion") {
-    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
-    val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
-    val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
-    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567").getTime)
-    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789").getTime)
-    val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2))
-      .toDF("a_date", "b_string", "c_timestamp")
     collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json")
   }
 
   test("timestamp conversion") {
-    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
-    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime)
-    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime)
-    val dateTimeData = Seq((ts1), (ts2)).toDF("a_timestamp")
-    collectAndValidate(dateTimeData, "test-data/arrow/timestampData.json")
+    collectAndValidate(dateTimeData.select($"c_timestamp"), "test-data/arrow/timestampData.json")
   }
 
   // Arrow json reader doesn't support binary data
@@ -120,24 +112,15 @@ class ArrowSuite extends SharedSQLContext
 
   test("mapped type conversion") { }
 
-  test("other type conversion") {
-    // half-precision
-    // byte type, or binary
-    // allNulls
+  test("floating-point NaN") {
+    val nanData = Seq((1, 1.2F, Double.NaN), (2, Float.NaN, 1.23)).toDF("i", "NaN_f", "NaN_d")
+    collectAndValidate(nanData, "test-data/arrow/nanData-floating_point.json")
   }
 
-  test("floating-point NaN") { }
-
-  test("other null conversion") { }
-
   test("convert int column with null to arrow") {
     collectAndValidate(nullInts, "test-data/arrow/null-ints.json")
     collectAndValidate(testData3, "test-data/arrow/null-ints-mixed.json")
-  }
-
-  test("convert string column with null to arrow") {
-    val nullStringsColOnly = nullStrings.select(nullStrings.columns(1))
-    collectAndValidate(nullStringsColOnly, "test-data/arrow/null-strings.json")
+    collectAndValidate(allNulls, "test-data/arrow/allNulls-ints.json")
   }
 
   test("empty frame collect") {
@@ -146,7 +129,14 @@
   }
 
   test("unsupported types") {
-    intercept[UnsupportedOperationException] {
+    def runUnsupported(block: => Unit): Unit = {
+      val msg = intercept[UnsupportedOperationException] {
+        block
+      }
+      assert(msg.getMessage.contains("Unsupported data type"))
+    }
+
+    runUnsupported {
       collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json")
     }
   }
@@ -180,7 +170,7 @@ class ArrowSuite extends SharedSQLContext
     val jsonSchema = jsonReader.start()
    Validator.compareSchemas(arrowSchema, jsonSchema)
 
-    val arrowRecordBatch = df.collectAsArrow(allocator)
+    val arrowRecordBatch = df.collectAsArrow(Some(allocator))
     val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator)
     val vectorLoader = new VectorLoader(arrowRoot)
     vectorLoader.load(arrowRecordBatch)
@@ -240,4 +230,14 @@
       DoubleData(5, 0.0001, None) ::
       DoubleData(6, 20000.0, Some(3.3)) :: Nil).toDF()
   }
+
+  protected lazy val dateTimeData: DataFrame = {
+    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
+    val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime)
+    val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime)
+    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime)
+    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime)
+    Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2))
+      .toDF("a_date", "b_string", "c_timestamp")
+  }
 }
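
Piecing together the hunks above, the suite's collectAndValidate helper appears to follow roughly this flow. The sketch below is a condensed reconstruction, not the actual helper: the method name validateAgainstJson is hypothetical, the JsonFileReader constructor arguments are assumed, and how the Arrow schema is derived from the DataFrame (plus the final vector-level comparison) is not visible in this diff and is elided.

  // Condensed sketch of the validation flow visible in the hunks above.
  import java.io.File

  import org.apache.arrow.memory.RootAllocator
  import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
  import org.apache.arrow.vector.file.json.JsonFileReader
  import org.apache.arrow.vector.types.pojo.Schema
  import org.apache.arrow.vector.util.Validator
  import org.apache.spark.sql.DataFrame

  def validateAgainstJson(df: DataFrame, arrowSchema: Schema, jsonFile: File): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)
    val jsonReader = new JsonFileReader(jsonFile, allocator)  // assumed constructor args
    val jsonSchema = jsonReader.start()
    Validator.compareSchemas(arrowSchema, jsonSchema)

    // New call shape from this commit: pass the test's allocator explicitly.
    val arrowRecordBatch = df.collectAsArrow(Some(allocator))
    val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator)
    val vectorLoader = new VectorLoader(arrowRoot)
    vectorLoader.load(arrowRecordBatch)
    // ... comparison of the loaded vectors against the JSON batch is elided here.
  }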
