@@ -44,19 +44,32 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
44
44
val sizeRequired : Int = converter.getSizeRequirement(row)
45
45
assert(sizeRequired === 8 + (3 * 8 ))
46
46
val buffer : Array [Long ] = new Array [Long ](sizeRequired / 8 )
47
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent .LONG_ARRAY_OFFSET , null )
47
+ val numBytesWritten =
48
+ converter.writeRow(row, buffer, PlatformDependent .LONG_ARRAY_OFFSET , sizeRequired, null )
48
49
assert(numBytesWritten === sizeRequired)
49
50
50
51
val unsafeRow = new UnsafeRow ()
51
- unsafeRow.pointTo(buffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, null )
52
+ unsafeRow.pointTo(
53
+ buffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, sizeRequired, null )
52
54
assert(unsafeRow.getLong(0 ) === 0 )
53
55
assert(unsafeRow.getLong(1 ) === 1 )
54
56
assert(unsafeRow.getInt(2 ) === 2 )
55
57
58
+ // We can copy UnsafeRows as long as they don't reference ObjectPools
59
+ val unsafeRowCopy = unsafeRow.copy()
60
+ assert(unsafeRowCopy.getLong(0 ) === 0 )
61
+ assert(unsafeRowCopy.getLong(1 ) === 1 )
62
+ assert(unsafeRowCopy.getInt(2 ) === 2 )
63
+
56
64
unsafeRow.setLong(1 , 3 )
57
65
assert(unsafeRow.getLong(1 ) === 3 )
58
66
unsafeRow.setInt(2 , 4 )
59
67
assert(unsafeRow.getInt(2 ) === 4 )
68
+
69
+ // Mutating the original row should not have changed the copy
70
+ assert(unsafeRowCopy.getLong(0 ) === 0 )
71
+ assert(unsafeRowCopy.getLong(1 ) === 1 )
72
+ assert(unsafeRowCopy.getInt(2 ) === 2 )
60
73
}
61
74
62
75
test(" basic conversion with primitive, string and binary types" ) {
@@ -73,12 +86,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
73
86
ByteArrayMethods .roundNumberOfBytesToNearestWord(" Hello" .getBytes.length) +
74
87
ByteArrayMethods .roundNumberOfBytesToNearestWord(" World" .getBytes.length))
75
88
val buffer : Array [Long ] = new Array [Long ](sizeRequired / 8 )
76
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent .LONG_ARRAY_OFFSET , null )
89
+ val numBytesWritten = converter.writeRow(
90
+ row, buffer, PlatformDependent .LONG_ARRAY_OFFSET , sizeRequired, null )
77
91
assert(numBytesWritten === sizeRequired)
78
92
79
93
val unsafeRow = new UnsafeRow ()
80
94
val pool = new ObjectPool (10 )
81
- unsafeRow.pointTo(buffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, pool)
95
+ unsafeRow.pointTo(
96
+ buffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, sizeRequired, pool)
82
97
assert(unsafeRow.getLong(0 ) === 0 )
83
98
assert(unsafeRow.getString(1 ) === " Hello" )
84
99
assert(unsafeRow.get(2 ) === " World" .getBytes)
@@ -96,6 +111,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
96
111
unsafeRow.update(2 , " Hello World" .getBytes)
97
112
assert(unsafeRow.get(2 ) === " Hello World" .getBytes)
98
113
assert(pool.size === 2 )
114
+
115
+ // We do not support copy() for UnsafeRows that reference ObjectPools
116
+ intercept[UnsupportedOperationException ] {
117
+ unsafeRow.copy()
118
+ }
99
119
}
100
120
101
121
test(" basic conversion with primitive, decimal and array" ) {
@@ -111,12 +131,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
111
131
val sizeRequired : Int = converter.getSizeRequirement(row)
112
132
assert(sizeRequired === 8 + (8 * 3 ))
113
133
val buffer : Array [Long ] = new Array [Long ](sizeRequired / 8 )
114
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent .LONG_ARRAY_OFFSET , pool)
134
+ val numBytesWritten =
135
+ converter.writeRow(row, buffer, PlatformDependent .LONG_ARRAY_OFFSET , sizeRequired, pool)
115
136
assert(numBytesWritten === sizeRequired)
116
137
assert(pool.size === 2 )
117
138
118
139
val unsafeRow = new UnsafeRow ()
119
- unsafeRow.pointTo(buffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, pool)
140
+ unsafeRow.pointTo(
141
+ buffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, sizeRequired, pool)
120
142
assert(unsafeRow.getLong(0 ) === 0 )
121
143
assert(unsafeRow.get(1 ) === Decimal (1 ))
122
144
assert(unsafeRow.get(2 ) === Array (2 ))
@@ -142,11 +164,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
142
164
assert(sizeRequired === 8 + (8 * 4 ) +
143
165
ByteArrayMethods .roundNumberOfBytesToNearestWord(" Hello" .getBytes.length))
144
166
val buffer : Array [Long ] = new Array [Long ](sizeRequired / 8 )
145
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent .LONG_ARRAY_OFFSET , null )
167
+ val numBytesWritten =
168
+ converter.writeRow(row, buffer, PlatformDependent .LONG_ARRAY_OFFSET , sizeRequired, null )
146
169
assert(numBytesWritten === sizeRequired)
147
170
148
171
val unsafeRow = new UnsafeRow ()
149
- unsafeRow.pointTo(buffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, null )
172
+ unsafeRow.pointTo(
173
+ buffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, sizeRequired, null )
150
174
assert(unsafeRow.getLong(0 ) === 0 )
151
175
assert(unsafeRow.getString(1 ) === " Hello" )
152
176
// Date is represented as Int in unsafeRow
@@ -190,12 +214,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
190
214
val sizeRequired : Int = converter.getSizeRequirement(rowWithAllNullColumns)
191
215
val createdFromNullBuffer : Array [Long ] = new Array [Long ](sizeRequired / 8 )
192
216
val numBytesWritten = converter.writeRow(
193
- rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent .LONG_ARRAY_OFFSET , null )
217
+ rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent .LONG_ARRAY_OFFSET ,
218
+ sizeRequired, null )
194
219
assert(numBytesWritten === sizeRequired)
195
220
196
221
val createdFromNull = new UnsafeRow ()
197
222
createdFromNull.pointTo(
198
- createdFromNullBuffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, null )
223
+ createdFromNullBuffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length,
224
+ sizeRequired, null )
199
225
for (i <- 0 to fieldTypes.length - 1 ) {
200
226
assert(createdFromNull.isNullAt(i))
201
227
}
@@ -233,10 +259,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
233
259
val pool = new ObjectPool (1 )
234
260
val setToNullAfterCreationBuffer : Array [Long ] = new Array [Long ](sizeRequired / 8 + 2 )
235
261
converter.writeRow(
236
- rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent .LONG_ARRAY_OFFSET , pool)
262
+ rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent .LONG_ARRAY_OFFSET ,
263
+ sizeRequired, pool)
237
264
val setToNullAfterCreation = new UnsafeRow ()
238
265
setToNullAfterCreation.pointTo(
239
- setToNullAfterCreationBuffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, pool)
266
+ setToNullAfterCreationBuffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length,
267
+ sizeRequired, pool)
240
268
241
269
assert(setToNullAfterCreation.isNullAt(0 ) === rowWithNoNullColumns.isNullAt(0 ))
242
270
assert(setToNullAfterCreation.getBoolean(1 ) === rowWithNoNullColumns.getBoolean(1 ))
0 commit comments