
Commit b55499a

JoshRosen authored and rxin committed
[SPARK-8932] Support copy() for UnsafeRows that do not use ObjectPools
We call Row.copy() in many places throughout SQL but UnsafeRow currently throws UnsupportedOperationException when copy() is called. Supporting copying when ObjectPool is used may be difficult, since we may need to handle deep-copying of objects in the pool. In addition, this copy() method needs to produce a self-contained row object which may be passed around / buffered by downstream code which does not understand the UnsafeRow format.

In the long run, we'll need to figure out how to handle the ObjectPool corner cases, but this may be unnecessary if other changes are made. Therefore, in order to unblock my sort patch (apache#6444) I propose that we support copy() for the cases where UnsafeRow does not use an ObjectPool and continue to throw UnsupportedOperationException when an ObjectPool is used.

This patch accomplishes this by modifying UnsafeRow so that it knows the size of the row's backing data in order to be able to copy it into a byte array.

Author: Josh Rosen <[email protected]>

Closes apache#7306 from JoshRosen/SPARK-8932 and squashes the following commits:

338e6bf [Josh Rosen] Support copy for UnsafeRows that do not use ObjectPools.
1 parent a290814 commit b55499a
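
To make the new contract concrete, here is a minimal sketch of the fixed-width path, closely mirroring the new test cases in UnsafeRowConverterSuite below (the row construction via SpecificMutableRow and the exact field values are illustrative assumptions, not part of this commit):

  import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeRow, UnsafeRowConverter}
  import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
  import org.apache.spark.unsafe.PlatformDependent

  // Build a simple fixed-width row and convert it into the UnsafeRow format.
  val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
  val converter = new UnsafeRowConverter(fieldTypes)
  val row = new SpecificMutableRow(fieldTypes)
  row.setLong(0, 0)
  row.setLong(1, 1)
  row.setInt(2, 2)

  // writeRow() and pointTo() now also take the row's size in bytes; no ObjectPool is involved here.
  val sizeRequired = converter.getSizeRequirement(row)
  val buffer = new Array[Long](sizeRequired / 8)
  converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)

  val unsafeRow = new UnsafeRow()
  unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)

  // copy() returns a self-contained row backed by its own byte array, so it is safe to
  // buffer or hand to downstream code that does not understand the UnsafeRow format.
  val rowCopy = unsafeRow.copy()
  unsafeRow.setLong(1, 3)          // mutate the original...
  assert(rowCopy.getLong(1) == 1)  // ...the copy is unaffected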

File tree

4 files changed, +87 -19 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java

Lines changed: 9 additions & 3 deletions
@@ -120,9 +120,11 @@ public UnsafeFixedWidthAggregationMap(
     this.bufferPool = new ObjectPool(initialCapacity);
 
     InternalRow initRow = initProjection.apply(emptyRow);
-    this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
+    int emptyBufferSize = bufferConverter.getSizeRequirement(initRow);
+    this.emptyBuffer = new byte[emptyBufferSize];
     int writtenLength = bufferConverter.writeRow(
-      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
+      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize,
+      bufferPool);
     assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
     // re-use the empty buffer only when there is no object saved in pool.
     reuseEmptyBuffer = bufferPool.size() == 0;
@@ -142,6 +144,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
       groupingKey,
       groupingKeyConversionScratchSpace,
       PlatformDependent.BYTE_ARRAY_OFFSET,
+      groupingKeySize,
       keyPool);
     assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
 
@@ -157,7 +160,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
       // There is some objects referenced by emptyBuffer, so generate a new one
       InternalRow initRow = initProjection.apply(emptyRow);
       bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
-        bufferPool);
+        groupingKeySize, bufferPool);
     }
     loc.putNewKey(
       groupingKeyConversionScratchSpace,
@@ -175,6 +178,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
       address.getBaseObject(),
       address.getBaseOffset(),
       bufferConverter.numFields(),
+      loc.getValueLength(),
       bufferPool
     );
     return currentBuffer;
@@ -214,12 +218,14 @@ public MapEntry next() {
         keyAddress.getBaseObject(),
         keyAddress.getBaseOffset(),
         keyConverter.numFields(),
+        loc.getKeyLength(),
         keyPool
       );
       entry.value.pointTo(
         valueAddress.getBaseObject(),
         valueAddress.getBaseOffset(),
         bufferConverter.numFields(),
+        loc.getValueLength(),
         bufferPool
       );
       return entry;

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 30 additions & 2 deletions
@@ -68,6 +68,9 @@ public final class UnsafeRow extends MutableRow {
   /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
   private int numFields;
 
+  /** The size of this row's backing data, in bytes) */
+  private int sizeInBytes;
+
   public int length() { return numFields; }
 
   /** The width of the null tracking bit set, in bytes */
@@ -95,14 +98,17 @@ public UnsafeRow() { }
    * @param baseObject the base object
    * @param baseOffset the offset within the base object
    * @param numFields the number of fields in this row
+   * @param sizeInBytes the size of this row's backing data, in bytes
    * @param pool the object pool to hold arbitrary objects
    */
-  public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
+  public void pointTo(
+      Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) {
     assert numFields >= 0 : "numFields should >= 0";
     this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
     this.numFields = numFields;
+    this.sizeInBytes = sizeInBytes;
     this.pool = pool;
   }
 
@@ -336,9 +342,31 @@ public double getDouble(int i) {
     }
   }
 
+  /**
+   * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
+   * byte array rather than referencing data stored in a data page.
+   * <p>
+   * This method is only supported on UnsafeRows that do not use ObjectPools.
+   */
   @Override
   public InternalRow copy() {
-    throw new UnsupportedOperationException();
+    if (pool != null) {
+      throw new UnsupportedOperationException(
+        "Copy is not supported for UnsafeRows that use object pools");
+    } else {
+      UnsafeRow rowCopy = new UnsafeRow();
+      final byte[] rowDataCopy = new byte[sizeInBytes];
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset,
+        rowDataCopy,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        sizeInBytes
+      );
+      rowCopy.pointTo(
+        rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null);
+      return rowCopy;
+    }
   }
 
   @Override

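For the fixed-width rows exercised in the test suite below, the sizeInBytes recorded by pointTo() has a simple shape: a word-aligned null bitset followed by one 8-byte slot per field, with variable-length values adding word-rounded space after that. The helper below is an informal sketch of that arithmetic (an assumption based on calculateBitSetWidthInBytes and the test assertions, not the actual implementation):

  // Rough size estimate for a purely fixed-width UnsafeRow: one null bit per field,
  // rounded up to whole 8-byte words, plus an 8-byte slot per field.
  def approxFixedWidthSizeInBytes(numFields: Int): Int = {
    val bitSetWidthInBytes = ((numFields + 63) / 64) * 8
    bitSetWidthInBytes + numFields * 8
  }

  // Matches the suite's expectation for a 3-field row: 8 + (3 * 8) == 32 bytes.
  assert(approxFixedWidthSizeInBytes(3) == 8 + (3 * 8))
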
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala

Lines changed: 8 additions & 2 deletions
@@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
    * @param row the row to convert
    * @param baseObject the base object of the destination address
    * @param baseOffset the base offset of the destination address
+   * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)`
    * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
    */
-  def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = {
-    unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool)
+  def writeRow(
+      row: InternalRow,
+      baseObject: Object,
+      baseOffset: Long,
+      rowLengthInBytes: Int,
+      pool: ObjectPool): Int = {
+    unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool)
 
     if (writers.length > 0) {
       // zero-out the bitset

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala

Lines changed: 40 additions & 12 deletions
@@ -44,19 +44,32 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (3 * 8))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
 
+    // We can copy UnsafeRows as long as they don't reference ObjectPools
+    val unsafeRowCopy = unsafeRow.copy()
+    assert(unsafeRowCopy.getLong(0) === 0)
+    assert(unsafeRowCopy.getLong(1) === 1)
+    assert(unsafeRowCopy.getInt(2) === 2)
+
     unsafeRow.setLong(1, 3)
     assert(unsafeRow.getLong(1) === 3)
     unsafeRow.setInt(2, 4)
     assert(unsafeRow.getInt(2) === 4)
+
+    // Mutating the original row should not have changed the copy
+    assert(unsafeRowCopy.getLong(0) === 0)
+    assert(unsafeRowCopy.getLong(1) === 1)
+    assert(unsafeRowCopy.getInt(2) === 2)
   }
 
   test("basic conversion with primitive, string and binary types") {
@@ -73,12 +86,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten = converter.writeRow(
+      row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
     val pool = new ObjectPool(10)
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     assert(unsafeRow.get(2) === "World".getBytes)
@@ -96,6 +111,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.update(2, "Hello World".getBytes)
     assert(unsafeRow.get(2) === "Hello World".getBytes)
     assert(pool.size === 2)
+
+    // We do not support copy() for UnsafeRows that reference ObjectPools
+    intercept[UnsupportedOperationException] {
+      unsafeRow.copy()
+    }
   }
 
   test("basic conversion with primitive, decimal and array") {
@@ -111,12 +131,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (8 * 3))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool)
     assert(numBytesWritten === sizeRequired)
     assert(pool.size === 2)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.get(1) === Decimal(1))
     assert(unsafeRow.get(2) === Array(2))
@@ -142,11 +164,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(sizeRequired === 8 + (8 * 4) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     // Date is represented as Int in unsafeRow
@@ -190,12 +214,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
     val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(
-      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val createdFromNull = new UnsafeRow()
     createdFromNull.pointTo(
-      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, null)
     for (i <- 0 to fieldTypes.length - 1) {
       assert(createdFromNull.isNullAt(i))
     }
@@ -233,10 +259,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val pool = new ObjectPool(1)
     val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
     converter.writeRow(
-      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, pool)
     val setToNullAfterCreation = new UnsafeRow()
     setToNullAfterCreation.pointTo(
-      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, pool)
 
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
     assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))

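As the tests above confirm, copy() remains unsupported when a row references an ObjectPool. Below is a hedged sketch of what a caller sees in that case, written as if it were another case in UnsafeRowConverterSuite (the suite's existing imports and helpers such as SpecificMutableRow and ObjectPool are assumed); the try/catch guard is purely illustrative, not part of this commit:

  val fieldTypes: Array[DataType] = Array(LongType, StringType)
  val converter = new UnsafeRowConverter(fieldTypes)
  val row = new SpecificMutableRow(fieldTypes)
  row.setLong(0, 0)
  row.setString(1, "Hello")

  val sizeRequired = converter.getSizeRequirement(row)
  val buffer = new Array[Long](sizeRequired / 8)
  converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)

  // A row pointed at an ObjectPool can update string/binary fields through the pool,
  // but copy() is disabled for it in this patch.
  val pool = new ObjectPool(10)
  val unsafeRow = new UnsafeRow()
  unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)

  try {
    unsafeRow.copy()
  } catch {
    case _: UnsupportedOperationException =>
      // Expected for pool-backed rows until the ObjectPool corner cases are handled.
  }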