Skip to content

Commit f54f369

Browse files
fix bug of constant null value for ObjectInspector
1 parent 5f13759 commit f54f369

File tree

11 files changed

+135
-59
lines changed

11 files changed

+135
-59
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ private[hive] trait HiveInspectors {
8888
* @return convert the data into catalyst type
8989
*/
9090
def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
91+
case _ if data == null => null
9192
case hvoi: HiveVarcharObjectInspector =>
9293
if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
9394
case hdoi: HiveDecimalObjectInspector =>
@@ -254,46 +255,59 @@ private[hive] trait HiveInspectors {
254255
}
255256

256257
def toInspector(expr: Expression): ObjectInspector = expr match {
257-
case Literal(value: String, StringType) =>
258-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
259-
case Literal(value: Int, IntegerType) =>
260-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
261-
case Literal(value: Double, DoubleType) =>
262-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
263-
case Literal(value: Boolean, BooleanType) =>
264-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
265-
case Literal(value: Long, LongType) =>
266-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
267-
case Literal(value: Float, FloatType) =>
268-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
269-
case Literal(value: Short, ShortType) =>
270-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
271-
case Literal(value: Byte, ByteType) =>
272-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
273-
case Literal(value: Array[Byte], BinaryType) =>
274-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
275-
case Literal(value: java.sql.Date, DateType) =>
276-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
277-
case Literal(value: java.sql.Timestamp, TimestampType) =>
278-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
279-
case Literal(value: BigDecimal, DecimalType()) =>
280-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
281-
case Literal(value: Decimal, DecimalType()) =>
282-
HiveShim.getPrimitiveWritableConstantObjectInspector(value.toBigDecimal)
258+
case Literal(value, StringType) =>
259+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[String])
260+
case Literal(value, IntegerType) =>
261+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[Int])
262+
case Literal(value, DoubleType) =>
263+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[Double])
264+
case Literal(value, BooleanType) =>
265+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[Boolean])
266+
case Literal(value, LongType) =>
267+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[Long])
268+
case Literal(value, FloatType) =>
269+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[Float])
270+
case Literal(value, ShortType) =>
271+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[Short])
272+
case Literal(value, ByteType) =>
273+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[Byte])
274+
case Literal(value, BinaryType) =>
275+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[Array[Byte]])
276+
case Literal(value, DateType) =>
277+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[java.sql.Date])
278+
case Literal(value, TimestampType) =>
279+
HiveShim.getPrimitiveWritableConstantObjectInspector(value.asInstanceOf[java.sql.Timestamp])
280+
case Literal(value, DecimalType()) =>
281+
if (null == value) {
282+
HiveShim.getPrimitiveWritableConstantObjectInspector(
283+
null.asInstanceOf[BigDecimal])
284+
} else {
285+
HiveShim.getPrimitiveWritableConstantObjectInspector(
286+
value.asInstanceOf[Decimal].toBigDecimal)
287+
}
283288
case Literal(_, NullType) =>
284289
HiveShim.getPrimitiveNullWritableConstantObjectInspector
285-
case Literal(value: Seq[_], ArrayType(dt, _)) =>
290+
case Literal(value, ArrayType(dt, _)) =>
286291
val listObjectInspector = toInspector(dt)
287-
val list = new java.util.ArrayList[Object]()
288-
value.foreach(v => list.add(wrap(v, listObjectInspector)))
289-
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
290-
case Literal(map: Map[_, _], MapType(keyType, valueType, _)) =>
291-
val value = new java.util.HashMap[Object, Object]()
292+
if (value == null) {
293+
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
294+
} else {
295+
val list = new java.util.ArrayList[Object]()
296+
value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector)))
297+
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
298+
}
299+
case Literal(value, MapType(keyType, valueType, _)) =>
292300
val keyOI = toInspector(keyType)
293301
val valueOI = toInspector(valueType)
294-
map.foreach (entry => value.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)))
295-
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, value)
296-
case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
302+
if (value == null) {
303+
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, null)
304+
} else {
305+
val map = new java.util.HashMap[Object, Object]()
306+
value.asInstanceOf[Map[_, _]].foreach (entry => {
307+
map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI))
308+
})
309+
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map)
310+
}
297311
case _ => toInspector(expr.dataType)
298312
}
299313

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1970-01-01 08:00:00.001 NULL 1 NULL
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
There is no documentation for function 'if'
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
There is no documentation for function 'if'

sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1 1 1 1 NULL 2

sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
128 1.1 ABC 12.3

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,34 @@ case class TestData(a: Int, b: String)
3636
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
3737
*/
3838
class HiveQuerySuite extends HiveComparisonTest {
39+
createQueryTest("constant null testing",
40+
"""SELECT
41+
|IF(FALSE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL1,
42+
|IF(TRUE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL2,
43+
|IF(FALSE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL3,
44+
|IF(TRUE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL4,
45+
|IF(FALSE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL5,
46+
|IF(TRUE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL6,
47+
|IF(FALSE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL7,
48+
|IF(TRUE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL8,
49+
|IF(FALSE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL9,
50+
|IF(TRUE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL10,
51+
|IF(FALSE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL11,
52+
|IF(TRUE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL12,
53+
|IF(FALSE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL13,
54+
|IF(TRUE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL14,
55+
|IF(FALSE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL15,
56+
|IF(TRUE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL16,
57+
|IF(FALSE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL17,
58+
|IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18,
59+
|IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19,
60+
|IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20,
61+
|IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21,
62+
|IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22,
63+
|IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23,
64+
|IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24
65+
|FROM src LIMIT 1""".stripMargin)
66+
3967
createQueryTest("constant array",
4068
"""
4169
|SELECT sort_array(

sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,52 +59,57 @@ private[hive] object HiveShim {
5959

6060
def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector =
6161
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
62-
PrimitiveCategory.STRING, new hadoopIo.Text(value))
62+
PrimitiveCategory.STRING, if (value == null) null else new hadoopIo.Text(value))
6363

6464
def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector =
6565
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
66-
PrimitiveCategory.INT, new hadoopIo.IntWritable(value))
66+
PrimitiveCategory.INT, if (value == null) null else new hadoopIo.IntWritable(value))
6767

6868
def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector =
6969
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
70-
PrimitiveCategory.DOUBLE, new hiveIo.DoubleWritable(value))
70+
PrimitiveCategory.DOUBLE, if (value == null) null else new hiveIo.DoubleWritable(value))
7171

7272
def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector =
7373
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
74-
PrimitiveCategory.BOOLEAN, new hadoopIo.BooleanWritable(value))
74+
PrimitiveCategory.BOOLEAN, if (value == null) null else new hadoopIo.BooleanWritable(value))
7575

7676
def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector =
7777
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
78-
PrimitiveCategory.LONG, new hadoopIo.LongWritable(value))
78+
PrimitiveCategory.LONG, if (value == null) null else new hadoopIo.LongWritable(value))
7979

8080
def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector =
8181
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
82-
PrimitiveCategory.FLOAT, new hadoopIo.FloatWritable(value))
82+
PrimitiveCategory.FLOAT, if (value == null) null else new hadoopIo.FloatWritable(value))
8383

8484
def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector =
8585
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
86-
PrimitiveCategory.SHORT, new hiveIo.ShortWritable(value))
86+
PrimitiveCategory.SHORT, if (value == null) null else new hiveIo.ShortWritable(value))
8787

8888
def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector =
8989
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
90-
PrimitiveCategory.BYTE, new hiveIo.ByteWritable(value))
90+
PrimitiveCategory.BYTE, if (value == null) null else new hiveIo.ByteWritable(value))
9191

9292
def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector =
9393
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
94-
PrimitiveCategory.BINARY, new hadoopIo.BytesWritable(value))
94+
PrimitiveCategory.BINARY, if (value == null) null else new hadoopIo.BytesWritable(value))
9595

9696
def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector =
9797
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
98-
PrimitiveCategory.DATE, new hiveIo.DateWritable(value))
98+
PrimitiveCategory.DATE, if (value == null) null else new hiveIo.DateWritable(value))
9999

100100
def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector =
101101
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
102-
PrimitiveCategory.TIMESTAMP, new hiveIo.TimestampWritable(value))
102+
PrimitiveCategory.TIMESTAMP,
103+
if (value == null) null else new hiveIo.TimestampWritable(value))
103104

104105
def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector =
105106
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
106107
PrimitiveCategory.DECIMAL,
107-
new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying())))
108+
if (value == null) {
109+
null
110+
} else {
111+
new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying()))
112+
})
108113

109114
def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
110115
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(

0 commit comments

Comments
 (0)