
Commit 447dea0

Author: Davies Liu (committed)
support binaryType in UnsafeRow
1 parent 54976e5 commit 447dea0

File tree: 3 files changed (+61, -34 lines)


sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 17 additions & 16 deletions

@@ -92,6 +92,7 @@ public static int calculateBitSetWidthInBytes(int numFields) {
    */
   public static final Set<DataType> readableFieldTypes;
 
+  // TODO: support DecimalType
   static {
     settableFieldTypes = Collections.unmodifiableSet(
       new HashSet<DataType>(
@@ -111,7 +112,8 @@ public static int calculateBitSetWidthInBytes(int numFields) {
     // We support get() on a superset of the types for which we support set():
     final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
       Arrays.asList(new DataType[]{
-        StringType
+        StringType,
+        BinaryType
       }));
     _readableFieldTypes.addAll(settableFieldTypes);
     readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
@@ -221,11 +223,6 @@ public void setFloat(int ordinal, float value) {
     PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
   }
 
-  @Override
-  public void setString(int ordinal, String value) {
-    throw new UnsupportedOperationException();
-  }
-
   @Override
   public int size() {
     return numFields;
@@ -249,6 +246,8 @@ public Object get(int i) {
       return null;
     } else if (dataType == StringType) {
       return getUTF8String(i);
+    } else if (dataType == BinaryType) {
+      return getBinary(i);
     } else {
       throw new UnsupportedOperationException();
     }
@@ -311,21 +310,23 @@ public double getDouble(int i) {
   }
 
   public UTF8String getUTF8String(int i) {
+    return UTF8String.fromBytes(getBinary(i));
+  }
+
+  public byte[] getBinary(int i) {
     assertIndexIsValid(i);
-    final UTF8String str = new UTF8String();
-    final long offsetToStringSize = getLong(i);
-    final int stringSizeInBytes =
-      (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
-    final byte[] strBytes = new byte[stringSizeInBytes];
+    final long offsetAndSize = getLong(i);
+    final int offset = (int)(offsetAndSize >> 32);
+    final int size = (int)(offsetAndSize & ((1L << 32) - 1));
+    final byte[] bytes = new byte[size];
     PlatformDependent.copyMemory(
       baseObject,
-      baseOffset + offsetToStringSize + 8,  // The `+ 8` is to skip past the size to get the data
-      strBytes,
+      baseOffset + offset,
+      bytes,
       PlatformDependent.BYTE_ARRAY_OFFSET,
-      stringSizeInBytes
+      size
     );
-    str.set(strBytes);
-    return str;
+    return bytes;
   }
 
   @Override
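Note on the UnsafeRow change above: getUTF8String and getBinary no longer follow a pointer to an 8-byte length prefix in the variable-length region. Each variable-length field now keeps a single long in its fixed-length slot, with the byte offset (relative to the row's base offset) in the upper 32 bits and the byte length in the lower 32 bits, so a read is one getLong plus one copyMemory. A minimal sketch of that packing and unpacking in plain Scala (the helper names are illustrative, not part of this commit):

  // Pack a field's relative byte offset and byte length into one 64-bit word,
  // matching what the column writers store via target.setLong(...).
  def packOffsetAndSize(offset: Int, size: Int): Long =
    (offset.toLong << 32) | (size.toLong & ((1L << 32) - 1))

  // Split the word back apart, mirroring UnsafeRow.getBinary above.
  def unpackOffsetAndSize(offsetAndSize: Long): (Int, Int) =
    ((offsetAndSize >> 32).toInt, (offsetAndSize & ((1L << 32) - 1)).toInt)

  // Example: a 5-byte value written 32 bytes past the row's base offset.
  val (offset, size) = unpackOffsetAndSize(packOffsetAndSize(32, 5))
  assert(offset == 32 && size == 5)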

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala

Lines changed: 36 additions & 10 deletions

@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -122,6 +120,7 @@ private object UnsafeColumnWriter {
     case FloatType => FloatUnsafeColumnWriter
     case DoubleType => DoubleUnsafeColumnWriter
     case StringType => StringUnsafeColumnWriter
+    case BinaryType => BinaryUnsafeColumnWriter
     case DateType => IntUnsafeColumnWriter
     case TimestampType => LongUnsafeColumnWriter
     case t =>
@@ -141,6 +140,7 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
 private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
 private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
 private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter
 
 private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
   // Primitives don't write to the variable-length region:
@@ -238,27 +238,53 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter
 private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
   def getSize(source: InternalRow, column: Int): Int = {
     val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length
-    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 
   override def write(
       source: InternalRow,
       target: UnsafeRow,
       column: Int,
       appendCursor: Int): Int = {
-    val value = source.get(column).asInstanceOf[UTF8String]
+    val value = source.get(column).asInstanceOf[UTF8String].getBytes
     val baseObject = target.getBaseObject
     val baseOffset = target.getBaseOffset
-    val numBytes = value.getBytes.length
-    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
+    val numBytes = value.length
     PlatformDependent.copyMemory(
-      value.getBytes,
+      value,
       PlatformDependent.BYTE_ARRAY_OFFSET,
       baseObject,
-      baseOffset + appendCursor + 8,
+      baseOffset + appendCursor,
       numBytes
     )
-    target.setLong(column, appendCursor)
-    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+    target.setLong(column, (appendCursor.toLong << 32) | numBytes.toLong)
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 }
+
+private class BinaryUnsafeColumnWriter private() extends UnsafeColumnWriter {
+  def getSize(source: InternalRow, column: Int): Int = {
+    val numBytes = source.getAs[Array[Byte]](column).length
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+  }
+
+  override def write(
+      source: InternalRow,
+      target: UnsafeRow,
+      column: Int,
+      appendCursor: Int): Int = {
+    val value = source.getAs[Array[Byte]](column)
+    val baseObject = target.getBaseObject
+    val baseOffset = target.getBaseOffset
+    val numBytes = value.length
+    PlatformDependent.copyMemory(
+      value,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      baseObject,
+      baseOffset + appendCursor,
+      numBytes
+    )
+    target.setLong(column, (appendCursor.toLong << 32) | numBytes.toLong)
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+  }
+}
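Note on the writer side above: because the length now travels in the fixed-length slot, getSize reserves only the word-aligned payload (the old 8-byte prefix term is gone), and write copies the bytes at appendCursor and then stores (appendCursor.toLong << 32) | numBytes.toLong for the column. A minimal sketch of the word rounding both getSize implementations delegate to, assuming ByteArrayMethods.roundNumberOfBytesToNearestWord rounds up to the next multiple of 8 (the helper below is illustrative, not part of this commit):

  // Round a byte count up to the next 8-byte word boundary.
  def roundToWord(numBytes: Int): Int = {
    val remainder = numBytes % 8
    if (remainder == 0) numBytes else numBytes + (8 - remainder)
  }

  assert(roundToWord(5) == 8)    // "Hello".getBytes -> one word
  assert(roundToWord(8) == 8)    // already word-aligned
  assert(roundToWord(13) == 16)  // a 13-byte binary payload -> two words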

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala

Lines changed: 8 additions & 8 deletions

@@ -23,8 +23,8 @@ import java.util.Arrays
 import org.scalatest.Matchers
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
 
@@ -52,19 +52,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.getInt(2) should be (2)
   }
 
-  test("basic conversion with primitive and string types") {
-    val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
+  test("basic conversion with primitive, string and binary types") {
+    val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
     val converter = new UnsafeRowConverter(fieldTypes)
 
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
     row.setString(1, "Hello")
-    row.setString(2, "World")
+    row.update(2, "World".getBytes)
 
     val sizeRequired: Int = converter.getSizeRequirement(row)
     sizeRequired should be (8 + (8 * 3) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
     numBytesWritten should be (sizeRequired)
@@ -73,7 +73,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
     unsafeRow.getLong(0) should be (0)
     unsafeRow.getString(1) should be ("Hello")
-    unsafeRow.getString(2) should be ("World")
+    unsafeRow.getBinary(2) should be ("World".getBytes)
   }
 
   test("basic conversion with primitive, string, date and timestamp types") {
@@ -88,7 +88,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
     val sizeRequired: Int = converter.getSizeRequirement(row)
     sizeRequired should be (8 + (8 * 4) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8))
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
     numBytesWritten should be (sizeRequired)
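As a worked check of the updated expectation in the first test (illustrative arithmetic, not part of this commit): the row needs one 8-byte null bit set word, three 8-byte fixed-length slots, and one word-rounded payload each for "Hello" and "World".

  // Expected size for Array(LongType, StringType, BinaryType)
  // holding (0L, "Hello", "World".getBytes):
  val bitSetBytes = 8       // one 64-bit null-tracking word covers 3 fields
  val fixedBytes = 8 * 3    // one 8-byte slot per field
  val helloBytes = 8        // "Hello" is 5 bytes, rounded up to a word
  val worldBytes = 8        // "World".getBytes is 5 bytes, rounded up to a word
  assert(bitSetBytes + fixedBytes + helloBytes + worldBytes == 48)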
