
Commit 4db713d

Adds in-memory column type for fixed-precision decimals
1 parent e06c7df · commit 4db713d

File tree

11 files changed: +163 −76 lines


sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala

Lines changed: 25 additions & 18 deletions

@@ -21,7 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder}
 
 import org.apache.spark.sql.catalyst.expressions.MutableRow
 import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
-import org.apache.spark.sql.types.{BinaryType, DataType, NativeType}
+import org.apache.spark.sql.types._
 
 /**
  * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -89,6 +89,9 @@ private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
 private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, FLOAT)
 
+private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
+  extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
+
 private[sql] class StringColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, STRING)
@@ -107,24 +110,28 @@ private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
   with NullableColumnAccessor
 
 private[sql] object ColumnAccessor {
-  def apply(buffer: ByteBuffer): ColumnAccessor = {
+  def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = {
     val dup = buffer.duplicate().order(ByteOrder.nativeOrder)
-    // The first 4 bytes in the buffer indicate the column type.
-    val columnTypeId = dup.getInt()
-
-    columnTypeId match {
-      case INT.typeId => new IntColumnAccessor(dup)
-      case LONG.typeId => new LongColumnAccessor(dup)
-      case FLOAT.typeId => new FloatColumnAccessor(dup)
-      case DOUBLE.typeId => new DoubleColumnAccessor(dup)
-      case BOOLEAN.typeId => new BooleanColumnAccessor(dup)
-      case BYTE.typeId => new ByteColumnAccessor(dup)
-      case SHORT.typeId => new ShortColumnAccessor(dup)
-      case STRING.typeId => new StringColumnAccessor(dup)
-      case DATE.typeId => new DateColumnAccessor(dup)
-      case TIMESTAMP.typeId => new TimestampColumnAccessor(dup)
-      case BINARY.typeId => new BinaryColumnAccessor(dup)
-      case GENERIC.typeId => new GenericColumnAccessor(dup)
+
+    // The first 4 bytes in the buffer indicate the column type. This field is not used now,
+    // because we always know the data type of the column ahead of time.
+    dup.getInt()
+
+    dataType match {
+      case IntegerType => new IntColumnAccessor(dup)
+      case LongType => new LongColumnAccessor(dup)
+      case FloatType => new FloatColumnAccessor(dup)
+      case DoubleType => new DoubleColumnAccessor(dup)
+      case BooleanType => new BooleanColumnAccessor(dup)
+      case ByteType => new ByteColumnAccessor(dup)
+      case ShortType => new ShortColumnAccessor(dup)
+      case StringType => new StringColumnAccessor(dup)
+      case BinaryType => new BinaryColumnAccessor(dup)
+      case DateType => new DateColumnAccessor(dup)
+      case TimestampType => new TimestampColumnAccessor(dup)
+      case DecimalType.Fixed(precision, scale) =>
+        new FixedDecimalColumnAccessor(dup, precision, scale)
+      case _ => new GenericColumnAccessor(dup)
     }
   }
 }
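The payoff of the new signature is that precision and scale, which a flat integer type id cannot carry, now reach the accessor through the DecimalType.Fixed extractor. A minimal sketch of that dispatch, assuming code inside org.apache.spark.sql (the extractor is private[sql]); this is illustrative only, not part of the commit:

    import org.apache.spark.sql.types._

    // Only decimals with an explicit precision and scale match Fixed; an
    // unbounded DecimalType falls through, which is why the accessor match
    // keeps a GenericColumnAccessor fallback.
    def decimalKind(dt: DecimalType): String = dt match {
      case DecimalType.Fixed(precision, scale) => s"fixed decimal($precision, $scale)"
      case _ => "unlimited precision, handled generically"
    }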

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala

Lines changed: 23 additions & 16 deletions

@@ -106,6 +106,13 @@ private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleCol
 
 private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
 
+private[sql] class FixedDecimalColumnBuilder(
+    precision: Int,
+    scale: Int)
+  extends NativeColumnBuilder(
+    new FixedDecimalColumnStats,
+    FIXED_DECIMAL(precision, scale))
+
 private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
 
 private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE)
@@ -139,25 +146,25 @@ private[sql] object ColumnBuilder {
   }
 
   def apply(
-      typeId: Int,
+      dataType: DataType,
       initialSize: Int = 0,
       columnName: String = "",
       useCompression: Boolean = false): ColumnBuilder = {
-
-    val builder = (typeId match {
-      case INT.typeId => new IntColumnBuilder
-      case LONG.typeId => new LongColumnBuilder
-      case FLOAT.typeId => new FloatColumnBuilder
-      case DOUBLE.typeId => new DoubleColumnBuilder
-      case BOOLEAN.typeId => new BooleanColumnBuilder
-      case BYTE.typeId => new ByteColumnBuilder
-      case SHORT.typeId => new ShortColumnBuilder
-      case STRING.typeId => new StringColumnBuilder
-      case BINARY.typeId => new BinaryColumnBuilder
-      case GENERIC.typeId => new GenericColumnBuilder
-      case DATE.typeId => new DateColumnBuilder
-      case TIMESTAMP.typeId => new TimestampColumnBuilder
-    }).asInstanceOf[ColumnBuilder]
+    val builder: ColumnBuilder = dataType match {
+      case IntegerType => new IntColumnBuilder
+      case LongType => new LongColumnBuilder
+      case DoubleType => new DoubleColumnBuilder
+      case BooleanType => new BooleanColumnBuilder
+      case ByteType => new ByteColumnBuilder
+      case ShortType => new ShortColumnBuilder
+      case StringType => new StringColumnBuilder
+      case BinaryType => new BinaryColumnBuilder
+      case DateType => new DateColumnBuilder
+      case TimestampType => new TimestampColumnBuilder
+      case DecimalType.Fixed(precision, scale) =>
+        new FixedDecimalColumnBuilder(precision, scale)
+      case _ => new GenericColumnBuilder
+    }
 
     builder.initialize(initialSize, columnName, useCompression)
     builder
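Call sites now select a builder by DataType rather than by type id. A hedged sketch of the new call shape (the column name and size hint are made-up values; these classes are private[sql], so real callers live under org.apache.spark.sql):

    import org.apache.spark.sql.types.{DecimalType, PrecisionInfo}

    // apply() initializes the builder before returning it, and builders grow
    // their buffers on demand, so the size hint is only a starting capacity.
    val decimalBuilder = ColumnBuilder(
      DecimalType(Some(PrecisionInfo(20, 10))),  // picks FixedDecimalColumnBuilder
      initialSize = 1024,
      columnName = "price",
      useCompression = false)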

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala

Lines changed: 17 additions & 0 deletions

@@ -181,6 +181,23 @@ private[sql] class FloatColumnStats extends ColumnStats {
   def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
+private[sql] class FixedDecimalColumnStats extends ColumnStats {
+  protected var upper: Decimal = null
+  protected var lower: Decimal = null
+
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row(ordinal).asInstanceOf[Decimal]
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+      sizeInBytes += FIXED_DECIMAL.defaultSize
+    }
+  }
+
+  override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+}
+
 private[sql] class IntColumnStats extends ColumnStats {
   protected var upper = Int.MinValue
   protected var lower = Int.MaxValue
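To make the collected shape concrete, a small sketch of feeding rows through the new stats class (a sketch only; GenericMutableRow is the catalyst mutable row used elsewhere in these tests):

    import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
    import org.apache.spark.sql.types.Decimal

    val stats = new FixedDecimalColumnStats
    val row = new GenericMutableRow(1)
    Seq(Decimal(3, 20, 10), Decimal(1, 20, 10), Decimal(2, 20, 10)).foreach { v =>
      row(0) = v
      stats.gatherStats(row, 0)
    }
    // collectedStatistics is Row(lower, upper, nullCount, count, sizeInBytes):
    // here Row(Decimal(1, 20, 10), Decimal(3, 20, 10), 0, 3, 24), 3 values at 8 bytes each.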

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala

Lines changed: 41 additions & 13 deletions

@@ -373,6 +373,33 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) {
   }
 }
 
+private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
+  extends NativeColumnType(
+    DecimalType(Some(PrecisionInfo(precision, scale))),
+    10,
+    FIXED_DECIMAL.defaultSize) {
+
+  override def extract(buffer: ByteBuffer): Decimal = {
+    Decimal(buffer.getLong(), precision, scale)
+  }
+
+  override def append(v: Decimal, buffer: ByteBuffer): Unit = {
+    buffer.putLong(v.toUnscaledLong)
+  }
+
+  override def getField(row: Row, ordinal: Int): Decimal = {
+    row(ordinal).asInstanceOf[Decimal]
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
+    row(ordinal) = value
+  }
+}
+
+private[sql] object FIXED_DECIMAL {
+  val defaultSize = 8
+}
+
 private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
   typeId: Int,
   defaultSize: Int)
@@ -394,7 +421,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
   }
 }
 
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) {
+private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) {
   override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
     row(ordinal) = value
   }
@@ -405,7 +432,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16)
 // Used to process generic objects (all types other than those listed above). Objects should be
 // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
 // byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
+private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
   override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
     row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
   }
@@ -416,18 +443,19 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
 private[sql] object ColumnType {
   def apply(dataType: DataType): ColumnType[_, _] = {
     dataType match {
-      case IntegerType => INT
-      case LongType => LONG
-      case FloatType => FLOAT
-      case DoubleType => DOUBLE
-      case BooleanType => BOOLEAN
-      case ByteType => BYTE
-      case ShortType => SHORT
-      case StringType => STRING
-      case BinaryType => BINARY
-      case DateType => DATE
+      case IntegerType   => INT
+      case LongType      => LONG
+      case FloatType     => FLOAT
+      case DoubleType    => DOUBLE
+      case BooleanType   => BOOLEAN
+      case ByteType      => BYTE
+      case ShortType     => SHORT
+      case StringType    => STRING
+      case BinaryType    => BINARY
+      case DateType      => DATE
       case TimestampType => TIMESTAMP
-      case _ => GENERIC
+      case DecimalType.Fixed(precision, scale) => FIXED_DECIMAL(precision, scale)
+      case _ => GENERIC
     }
   }
 }
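The core of the commit is the encoding above: a fixed-precision decimal is flattened to its unscaled long, so a value costs 8 bytes instead of a serialized BigDecimal. A hedged round-trip sketch using only members defined in this file; it assumes the unscaled value fits in a Long:

    import java.nio.{ByteBuffer, ByteOrder}
    import org.apache.spark.sql.types.Decimal

    val columnType = FIXED_DECIMAL(20, 10)
    val buffer = ByteBuffer.allocate(columnType.defaultSize).order(ByteOrder.nativeOrder)

    // append() stores only the unscaled long: Decimal(12345, 20, 10) is the
    // value 12345 * 10^-10.
    columnType.append(Decimal(12345L, 20, 10), buffer)
    buffer.rewind()

    // extract() rebuilds the Decimal from that long plus the precision and
    // scale carried by the column type itself.
    val decoded = columnType.extract(buffer)  // == Decimal(12345, 20, 10)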

sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala

Lines changed: 5 additions & 3 deletions

@@ -113,7 +113,7 @@ private[sql] case class InMemoryRelation(
     val columnBuilders = output.map { attribute =>
       val columnType = ColumnType(attribute.dataType)
       val initialBufferSize = columnType.defaultSize * batchSize
-      ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression)
+      ColumnBuilder(attribute.dataType, initialBufferSize, attribute.name, useCompression)
     }.toArray
 
     var rowCount = 0
@@ -274,8 +274,10 @@ private[sql] case class InMemoryColumnarTableScan(
     def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = {
       val rows = cacheBatches.flatMap { cachedBatch =>
         // Build column accessors
-        val columnAccessors = requestedColumnIndices.map { batch =>
-          ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch)))
+        val columnAccessors = requestedColumnIndices.map { batchColumnIndex =>
+          ColumnAccessor(
+            relation.output(batchColumnIndex).dataType,
+            ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex)))
         }
 
         // Extract rows via column accessors
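Putting both halves together, a hedged end-to-end sketch of the path these two hunks wire up (assumes the org.apache.spark.sql.columnar package; appendFrom, build, hasNext, and extractTo are the pre-existing columnar API, not part of this diff):

    import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
    import org.apache.spark.sql.types.{Decimal, DecimalType, PrecisionInfo}

    val dt = DecimalType(Some(PrecisionInfo(20, 10)))

    // Build a one-column batch of four decimals, as InMemoryRelation does.
    val builder = ColumnBuilder(dt)
    val in = new GenericMutableRow(1)
    (1 to 4).foreach { i =>
      in(0) = Decimal(i, 20, 10)
      builder.appendFrom(in, 0)
    }

    // Read the batch back by data type, as InMemoryColumnarTableScan does.
    val accessor = ColumnAccessor(dt, builder.build())
    val out = new GenericMutableRow(1)
    while (accessor.hasNext) {
      accessor.extractTo(out, 0)  // out(0) is a Decimal decoded from 8 bytes
    }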

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ class ColumnStatsSuite extends FunSuite {
   testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0))
   testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
   testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
+  testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(20, 10), Row(null, null, 0))
   testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
   testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0))
   testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala

Lines changed: 22 additions & 12 deletions

@@ -33,8 +33,9 @@ class ColumnTypeSuite extends FunSuite with Logging {
 
   test("defaultSize") {
     val checks = Map(
-      INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1,
-      STRING -> 8, DATE -> 4, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16)
+      INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
+      FIXED_DECIMAL(20, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12,
+      BINARY -> 16, GENERIC -> 16)
 
     checks.foreach { case (columnType, expectedSize) =>
       assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -56,15 +57,16 @@ class ColumnTypeSuite extends FunSuite with Logging {
      }
    }
 
-    checkActualSize(INT, Int.MaxValue, 4)
-    checkActualSize(SHORT, Short.MaxValue, 2)
-    checkActualSize(LONG, Long.MaxValue, 8)
-    checkActualSize(BYTE, Byte.MaxValue, 1)
-    checkActualSize(DOUBLE, Double.MaxValue, 8)
-    checkActualSize(FLOAT, Float.MaxValue, 4)
-    checkActualSize(BOOLEAN, true, 1)
-    checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
-    checkActualSize(DATE, 0, 4)
+    checkActualSize(INT, Int.MaxValue, 4)
+    checkActualSize(SHORT, Short.MaxValue, 2)
+    checkActualSize(LONG, Long.MaxValue, 8)
+    checkActualSize(BYTE, Byte.MaxValue, 1)
+    checkActualSize(DOUBLE, Double.MaxValue, 8)
+    checkActualSize(FLOAT, Float.MaxValue, 4)
+    checkActualSize(FIXED_DECIMAL(20, 10), Decimal(0, 20, 10), 8)
+    checkActualSize(BOOLEAN, true, 1)
+    checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
+    checkActualSize(DATE, 0, 4)
     checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
 
     val binary = Array.fill[Byte](4)(0: Byte)
@@ -93,12 +95,20 @@ class ColumnTypeSuite extends FunSuite with Logging {
 
   testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
 
+  testNativeColumnType[DecimalType](
+    FIXED_DECIMAL(20, 10),
+    (buffer: ByteBuffer, decimal: Decimal) => {
+      buffer.putLong(decimal.toUnscaledLong)
+    },
+    (buffer: ByteBuffer) => {
+      Decimal(buffer.getLong(), 20, 10)
+    })
+
   testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
 
   testNativeColumnType[StringType.type](
    STRING,
    (buffer: ByteBuffer, string: String) => {
-
      val bytes = string.getBytes("utf-8")
      buffer.putInt(bytes.length)
      buffer.put(bytes)

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala

Lines changed: 12 additions & 11 deletions

@@ -24,7 +24,7 @@ import scala.util.Random
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{DataType, NativeType}
+import org.apache.spark.sql.types.{Decimal, DataType, NativeType}
 
 object ColumnarTestUtils {
   def makeNullRow(length: Int) = {
@@ -41,16 +41,17 @@ object ColumnarTestUtils {
    }
 
    (columnType match {
-      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
-      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
-      case INT => Random.nextInt()
-      case LONG => Random.nextLong()
-      case FLOAT => Random.nextFloat()
-      case DOUBLE => Random.nextDouble()
-      case STRING => Random.nextString(Random.nextInt(32))
-      case BOOLEAN => Random.nextBoolean()
-      case BINARY => randomBytes(Random.nextInt(32))
-      case DATE => Random.nextInt()
+      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
+      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
+      case INT => Random.nextInt()
+      case LONG => Random.nextLong()
+      case FLOAT => Random.nextFloat()
+      case DOUBLE => Random.nextDouble()
+      case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
+      case STRING => Random.nextString(Random.nextInt(32))
+      case BOOLEAN => Random.nextBoolean()
+      case BINARY => randomBytes(Random.nextInt(32))
+      case DATE => Random.nextInt()
      case TIMESTAMP =>
        val timestamp = new Timestamp(Random.nextLong())
        timestamp.setNanos(Random.nextInt(999999999))

sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala

Lines changed: 13 additions & 1 deletion

@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.columnar
 
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.{QueryTest, TestData}
 import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
@@ -117,4 +117,16 @@ class InMemoryColumnarQuerySuite extends QueryTest {
     complexData.count()
     complexData.unpersist()
   }
+
+  test("decimal type") {
+    (1 to 10)
+      .map(i => Tuple1(Decimal(i, 20, 10)))
+      .toDF("dec")
+      .cache()
+      .registerTempTable("test_fixed_decimal")
+
+    checkAnswer(
+      sql("SELECT * FROM test_fixed_decimal"),
+      (1 to 10).map(i => Row(Decimal(i, 20, 10).toJavaBigDecimal)))
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala

Lines changed: 2 additions & 1 deletion

@@ -42,7 +42,8 @@ class NullableColumnAccessorSuite extends FunSuite {
   import ColumnarTestUtils._
 
   Seq(
-    INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP
+    INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(20, 10), BINARY, GENERIC,
+    DATE, TIMESTAMP
   ).foreach {
     testNullableColumnAccessor(_)
   }
