Skip to content

Commit c3a00f2

Browse files
maropu authored and uzadude committed
[SPARK-17683][SQL] Support ArrayType in Literal.apply
## What changes were proposed in this pull request?

This PR adds pattern-matching entries for array data in `Literal.apply`.

## How was this patch tested?

Added tests in `LiteralExpressionSuite`.

Author: Takeshi YAMAMURO <[email protected]>

Closes apache#15257 from maropu/SPARK-17683.
1 parent 2d0978e commit c3a00f2

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,25 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import java.lang.{Boolean => JavaBoolean}
21+
import java.lang.{Byte => JavaByte}
22+
import java.lang.{Double => JavaDouble}
23+
import java.lang.{Float => JavaFloat}
24+
import java.lang.{Integer => JavaInteger}
25+
import java.lang.{Long => JavaLong}
26+
import java.lang.{Short => JavaShort}
27+
import java.math.{BigDecimal => JavaBigDecimal}
2028
import java.nio.charset.StandardCharsets
2129
import java.sql.{Date, Timestamp}
2230
import java.util
2331
import java.util.Objects
2432
import javax.xml.bind.DatatypeConverter
2533

34+
import scala.math.{BigDecimal, BigInt}
35+
2636
import org.json4s.JsonAST._
2737

38+
import org.apache.spark.sql.AnalysisException
2839
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
2940
import org.apache.spark.sql.catalyst.expressions.codegen._
3041
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -46,19 +57,63 @@ object Literal {
4657
case s: String => Literal(UTF8String.fromString(s), StringType)
4758
case b: Boolean => Literal(b, BooleanType)
4859
case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale))
49-
case d: java.math.BigDecimal =>
60+
case d: JavaBigDecimal =>
5061
Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale()))
5162
case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
5263
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
5364
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
5465
case a: Array[Byte] => Literal(a, BinaryType)
66+
case a: Array[_] =>
67+
val elementType = componentTypeToDataType(a.getClass.getComponentType())
68+
val dataType = ArrayType(elementType)
69+
val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
70+
Literal(convert(a), dataType)
5571
case i: CalendarInterval => Literal(i, CalendarIntervalType)
5672
case null => Literal(null, NullType)
5773
case v: Literal => v
5874
case _ =>
5975
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
6076
}
6177

78+
/**
 * Resolves the Spark SQL [[DataType]] that corresponds to a runtime class object.
 * Array element classes are only available as `Class` values at runtime, so the
 * resolution is done by matching on class objects here. Similar conversions live
 * in other files (e.g., HiveInspectors) and should eventually be merged into one.
 */
private[this] def componentTypeToDataType(clz: Class[_]): DataType = clz match {
  // Java primitive component types, e.g. the element class of an int[] literal.
  case JavaShort.TYPE => ShortType
  case JavaInteger.TYPE => IntegerType
  case JavaLong.TYPE => LongType
  case JavaDouble.TYPE => DoubleType
  case JavaByte.TYPE => ByteType
  case JavaFloat.TYPE => FloatType
  case JavaBoolean.TYPE => BooleanType

  // Byte arrays map to BinaryType; this must be checked before the generic
  // array case below so that Array[Array[Byte]] becomes ArrayType(BinaryType).
  case c if c == classOf[Array[Byte]] => BinaryType
  // Nested arrays become nested ArrayTypes.
  case c if c.isArray => ArrayType(componentTypeToDataType(c.getComponentType))

  // Boxed Java classes.
  case c if c == classOf[Date] => DateType
  case c if c == classOf[Timestamp] => TimestampType
  case c if c == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
  case c if c == classOf[JavaShort] => ShortType
  case c if c == classOf[JavaInteger] => IntegerType
  case c if c == classOf[JavaLong] => LongType
  case c if c == classOf[JavaDouble] => DoubleType
  case c if c == classOf[JavaByte] => ByteType
  case c if c == classOf[JavaFloat] => FloatType
  case c if c == classOf[JavaBoolean] => BooleanType

  // Scala classes.
  case c if c == classOf[String] => StringType
  case c if c == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT
  case c if c == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT
  case c if c == classOf[CalendarInterval] => CalendarIntervalType

  case _ => throw new AnalysisException(s"Unsupported component type $clz in arrays")
}
116+
62117
/**
63118
* Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object
64119
* into code generation.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets
2121

2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.sql.Row
24+
import org.apache.spark.sql.catalyst.CatalystTypeConverters
2425
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2526
import org.apache.spark.sql.types._
2627
import org.apache.spark.unsafe.types.CalendarInterval
@@ -43,6 +44,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
4344
checkEvaluation(Literal.create(null, TimestampType), null)
4445
checkEvaluation(Literal.create(null, CalendarIntervalType), null)
4546
checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
47+
checkEvaluation(Literal.create(null, ArrayType(StringType, true)), null)
4648
checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
4749
checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
4850
}
@@ -122,5 +124,28 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
122124
}
123125
}
124126

125-
// TODO(davies): add tests for ArrayType, MapType and StructType
127+
test("array") {
  // A Literal built from a Scala/Java array must evaluate to the same catalyst
  // value that CatalystTypeConverters produces for the expected ArrayType.
  def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = {
    val expected = CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a)
    checkEvaluation(Literal(a), expected)
  }
  checkArrayLiteral(Array(1, 2, 3), IntegerType)
  checkArrayLiteral(Array("a", "b", "c"), StringType)
  checkArrayLiteral(Array(1.0, 4.0), DoubleType)
  checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
    CalendarIntervalType)
}
140+
141+
test("unsupported types (map and struct) in literals") {
  // Literal.apply has no pattern for maps or structs, so it must fail fast
  // with a descriptive RuntimeException rather than producing a bad literal.
  def assertUnsupported(v: Any): Unit = {
    val e = intercept[RuntimeException](Literal(v))
    assert(e.getMessage.startsWith("Unsupported literal type"))
  }
  assertUnsupported(Map("key1" -> 1, "key2" -> 2))
  assertUnsupported(("mike", 29, 1.0))
}
126151
}

0 commit comments

Comments
 (0)