Skip to content

Commit c9373c8

Browse files
committed
Support DecimalType.
1 parent 2379eeb commit c9373c8

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala

Lines changed: 31 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.execution
1919

2020
import java.io._
21+
import java.math.{BigDecimal, BigInteger}
2122
import java.nio.ByteBuffer
2223
import java.sql.Timestamp
2324

@@ -143,7 +144,6 @@ private[sql] object SparkSqlSerializer2 {
143144
case array: ArrayType => return false
144145
case map: MapType => return false
145146
case struct: StructType => return false
146-
case decimal: DecimalType => return false
147147
case _ =>
148148
}
149149
i += 1
@@ -223,6 +223,21 @@ private[sql] object SparkSqlSerializer2 {
223223
out.writeDouble(row.getDouble(i))
224224
}
225225

226+
case decimal: DecimalType =>
227+
if (row.isNullAt(i)) {
228+
out.writeByte(NULL)
229+
} else {
230+
out.writeByte(NOT_NULL)
231+
val value = row.apply(i).asInstanceOf[Decimal]
232+
val javaBigDecimal = value.toJavaBigDecimal
233+
// First, write out the unscaled value.
234+
val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
235+
out.writeInt(bytes.length)
236+
out.write(bytes)
237+
// Then, write out the scale.
238+
out.writeInt(javaBigDecimal.scale())
239+
}
240+
226241
case DateType =>
227242
if (row.isNullAt(i)) {
228243
out.writeByte(NULL)
@@ -334,6 +349,21 @@ private[sql] object SparkSqlSerializer2 {
334349
mutableRow.setDouble(i, in.readDouble())
335350
}
336351

352+
case decimal: DecimalType =>
353+
if (in.readByte() == NULL) {
354+
mutableRow.setNullAt(i)
355+
} else {
356+
// First, read in the unscaled value.
357+
val length = in.readInt()
358+
val bytes = new Array[Byte](length)
359+
in.readFully(bytes)
360+
val unscaledVal = new BigInteger(bytes)
361+
// Then, read the scale.
362+
val scale = in.readInt()
363+
// Finally, create the Decimal object and set it in the row.
364+
mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
365+
}
366+
337367
case DateType =>
338368
if (in.readByte() == NULL) {
339369
mutableRow.setNullAt(i)

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -53,12 +53,9 @@ class SparkSqlSerializer2DataTypeSuite extends FunSuite {
5353
checkSupported(TimestampType, isSupported = true)
5454
checkSupported(StringType, isSupported = true)
5555
checkSupported(BinaryType, isSupported = true)
56+
checkSupported(DecimalType(10, 5), isSupported = true)
57+
checkSupported(DecimalType.Unlimited, isSupported = true)
5658

57-
// Because at runtime we accept three kinds of Decimals
58-
// (Java BigDecimal, Scala BigDecimal, and Spark SQL's Decimal), we do not support DecimalType
59-
// right now. We will support it once we have fixed the internal type.
60-
checkSupported(DecimalType(10, 5), isSupported = false)
61-
checkSupported(DecimalType.Unlimited, isSupported = false)
6259
// For now, ArrayType, MapType, and StructType are not supported.
6360
checkSupported(ArrayType(DoubleType, true), isSupported = false)
6461
checkSupported(ArrayType(StringType, false), isSupported = false)
@@ -84,7 +81,8 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
8481
val supportedTypes =
8582
Seq(StringType, BinaryType, NullType, BooleanType,
8683
ByteType, ShortType, IntegerType, LongType,
87-
FloatType, DoubleType, DateType, TimestampType)
84+
FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
85+
DateType, TimestampType)
8886

8987
val fields = supportedTypes.zipWithIndex.map { case (dataType, index) =>
9088
StructField(s"col$index", dataType, true)
@@ -103,9 +101,11 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
103101
i.toByte,
104102
i.toShort,
105103
i,
106-
i.toLong,
104+
Long.MaxValue - i.toLong,
107105
(i + 0.25).toFloat,
108106
(i + 0.75),
107+
BigDecimal(Long.MaxValue.toString + ".12345"),
108+
new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
109109
new Date(i),
110110
new Timestamp(i))
111111
}
@@ -159,7 +159,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
159159
checkSerializer(df.queryExecution.executedPlan, serializerClass)
160160
checkAnswer(
161161
df,
162-
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
162+
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
163163
}
164164
}
165165

0 commit comments

Comments
 (0)