Skip to content

Commit ff071e3

Browse files
chenghao-intel authored and marmbrus committed
[SPARK-4250] [SQL] Fix bug of constant null value mapping to ConstantObjectInspector
Author: Cheng Hao <[email protected]> Closes #3114 from chenghao-intel/constant_null_oi and squashes the following commits: e603bda [Cheng Hao] fix the bug of null value for primitive types 50a13ba [Cheng Hao] fix the timezone issue f54f369 [Cheng Hao] fix bug of constant null value for ObjectInspector (cherry picked from commit fa77783) Signed-off-by: Michael Armbrust <[email protected]>
1 parent 1ed1c68 commit ff071e3

File tree

11 files changed

+199
-86
lines changed

11 files changed

+199
-86
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ private[hive] trait HiveInspectors {
8888
* @return convert the data into catalyst type
8989
*/
9090
def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
91+
case _ if data == null => null
9192
case hvoi: HiveVarcharObjectInspector =>
9293
if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
9394
case hdoi: HiveDecimalObjectInspector =>
@@ -250,46 +251,53 @@ private[hive] trait HiveInspectors {
250251
}
251252

252253
def toInspector(expr: Expression): ObjectInspector = expr match {
253-
case Literal(value: String, StringType) =>
254-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
255-
case Literal(value: Int, IntegerType) =>
256-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
257-
case Literal(value: Double, DoubleType) =>
258-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
259-
case Literal(value: Boolean, BooleanType) =>
260-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
261-
case Literal(value: Long, LongType) =>
262-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
263-
case Literal(value: Float, FloatType) =>
264-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
265-
case Literal(value: Short, ShortType) =>
266-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
267-
case Literal(value: Byte, ByteType) =>
268-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
269-
case Literal(value: Array[Byte], BinaryType) =>
270-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
271-
case Literal(value: java.sql.Date, DateType) =>
272-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
273-
case Literal(value: java.sql.Timestamp, TimestampType) =>
274-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
275-
case Literal(value: BigDecimal, DecimalType()) =>
276-
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
277-
case Literal(value: Decimal, DecimalType()) =>
278-
HiveShim.getPrimitiveWritableConstantObjectInspector(value.toBigDecimal)
254+
case Literal(value, StringType) =>
255+
HiveShim.getStringWritableConstantObjectInspector(value)
256+
case Literal(value, IntegerType) =>
257+
HiveShim.getIntWritableConstantObjectInspector(value)
258+
case Literal(value, DoubleType) =>
259+
HiveShim.getDoubleWritableConstantObjectInspector(value)
260+
case Literal(value, BooleanType) =>
261+
HiveShim.getBooleanWritableConstantObjectInspector(value)
262+
case Literal(value, LongType) =>
263+
HiveShim.getLongWritableConstantObjectInspector(value)
264+
case Literal(value, FloatType) =>
265+
HiveShim.getFloatWritableConstantObjectInspector(value)
266+
case Literal(value, ShortType) =>
267+
HiveShim.getShortWritableConstantObjectInspector(value)
268+
case Literal(value, ByteType) =>
269+
HiveShim.getByteWritableConstantObjectInspector(value)
270+
case Literal(value, BinaryType) =>
271+
HiveShim.getBinaryWritableConstantObjectInspector(value)
272+
case Literal(value, DateType) =>
273+
HiveShim.getDateWritableConstantObjectInspector(value)
274+
case Literal(value, TimestampType) =>
275+
HiveShim.getTimestampWritableConstantObjectInspector(value)
276+
case Literal(value, DecimalType()) =>
277+
HiveShim.getDecimalWritableConstantObjectInspector(value)
279278
case Literal(_, NullType) =>
280279
HiveShim.getPrimitiveNullWritableConstantObjectInspector
281-
case Literal(value: Seq[_], ArrayType(dt, _)) =>
280+
case Literal(value, ArrayType(dt, _)) =>
282281
val listObjectInspector = toInspector(dt)
283-
val list = new java.util.ArrayList[Object]()
284-
value.foreach(v => list.add(wrap(v, listObjectInspector)))
285-
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
286-
case Literal(map: Map[_, _], MapType(keyType, valueType, _)) =>
287-
val value = new java.util.HashMap[Object, Object]()
282+
if (value == null) {
283+
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
284+
} else {
285+
val list = new java.util.ArrayList[Object]()
286+
value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector)))
287+
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
288+
}
289+
case Literal(value, MapType(keyType, valueType, _)) =>
288290
val keyOI = toInspector(keyType)
289291
val valueOI = toInspector(valueType)
290-
map.foreach (entry => value.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)))
291-
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, value)
292-
case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
292+
if (value == null) {
293+
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, null)
294+
} else {
295+
val map = new java.util.HashMap[Object, Object]()
296+
value.asInstanceOf[Map[_, _]].foreach (entry => {
297+
map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI))
298+
})
299+
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map)
300+
}
293301
case _ => toInspector(expr.dataType)
294302
}
295303

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
There is no documentation for function 'if'
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
There is no documentation for function 'if'

sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1 1 1 1 NULL 2

sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
128 1.1 ABC 12.3

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
package org.apache.spark.sql.hive.execution
1919

2020
import java.io.File
21+
import java.util.{Locale, TimeZone}
22+
23+
import org.scalatest.BeforeAndAfter
2124

2225
import scala.util.Try
2326

@@ -28,14 +31,59 @@ import org.apache.spark.sql.catalyst.plans.logical.Project
2831
import org.apache.spark.sql.hive._
2932
import org.apache.spark.sql.hive.test.TestHive
3033
import org.apache.spark.sql.hive.test.TestHive._
31-
import org.apache.spark.sql.{Row, SchemaRDD}
34+
import org.apache.spark.sql.{SQLConf, Row, SchemaRDD}
3235

3336
case class TestData(a: Int, b: String)
3437

3538
/**
3639
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
3740
*/
38-
class HiveQuerySuite extends HiveComparisonTest {
41+
class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
42+
private val originalTimeZone = TimeZone.getDefault
43+
private val originalLocale = Locale.getDefault
44+
45+
override def beforeAll() {
46+
TestHive.cacheTables = true
47+
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
48+
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
49+
// Add Locale setting
50+
Locale.setDefault(Locale.US)
51+
}
52+
53+
override def afterAll() {
54+
TestHive.cacheTables = false
55+
TimeZone.setDefault(originalTimeZone)
56+
Locale.setDefault(originalLocale)
57+
}
58+
59+
createQueryTest("constant null testing",
60+
"""SELECT
61+
|IF(FALSE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL1,
62+
|IF(TRUE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL2,
63+
|IF(FALSE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL3,
64+
|IF(TRUE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL4,
65+
|IF(FALSE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL5,
66+
|IF(TRUE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL6,
67+
|IF(FALSE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL7,
68+
|IF(TRUE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL8,
69+
|IF(FALSE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL9,
70+
|IF(TRUE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL10,
71+
|IF(FALSE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL11,
72+
|IF(TRUE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL12,
73+
|IF(FALSE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL13,
74+
|IF(TRUE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL14,
75+
|IF(FALSE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL15,
76+
|IF(TRUE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL16,
77+
|IF(FALSE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL17,
78+
|IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18,
79+
|IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19,
80+
|IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20,
81+
|IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21,
82+
|IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22,
83+
|IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23,
84+
|IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24
85+
|FROM src LIMIT 1""".stripMargin)
86+
3987
createQueryTest("constant array",
4088
"""
4189
|SELECT sort_array(

sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,54 +57,74 @@ private[hive] object HiveShim {
5757
new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties)
5858
}
5959

60-
def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector =
60+
def getStringWritableConstantObjectInspector(value: Any): ObjectInspector =
6161
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
62-
PrimitiveCategory.STRING, new hadoopIo.Text(value))
62+
PrimitiveCategory.STRING,
63+
if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]))
6364

64-
def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector =
65+
def getIntWritableConstantObjectInspector(value: Any): ObjectInspector =
6566
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
66-
PrimitiveCategory.INT, new hadoopIo.IntWritable(value))
67+
PrimitiveCategory.INT,
68+
if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]))
6769

68-
def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector =
70+
def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector =
6971
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
70-
PrimitiveCategory.DOUBLE, new hiveIo.DoubleWritable(value))
72+
PrimitiveCategory.DOUBLE,
73+
if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]))
7174

72-
def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector =
75+
def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector =
7376
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
74-
PrimitiveCategory.BOOLEAN, new hadoopIo.BooleanWritable(value))
77+
PrimitiveCategory.BOOLEAN,
78+
if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]))
7579

76-
def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector =
80+
def getLongWritableConstantObjectInspector(value: Any): ObjectInspector =
7781
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
78-
PrimitiveCategory.LONG, new hadoopIo.LongWritable(value))
82+
PrimitiveCategory.LONG,
83+
if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]))
7984

80-
def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector =
85+
def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector =
8186
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
82-
PrimitiveCategory.FLOAT, new hadoopIo.FloatWritable(value))
87+
PrimitiveCategory.FLOAT,
88+
if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]))
8389

84-
def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector =
90+
def getShortWritableConstantObjectInspector(value: Any): ObjectInspector =
8591
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
86-
PrimitiveCategory.SHORT, new hiveIo.ShortWritable(value))
92+
PrimitiveCategory.SHORT,
93+
if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]))
8794

88-
def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector =
95+
def getByteWritableConstantObjectInspector(value: Any): ObjectInspector =
8996
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
90-
PrimitiveCategory.BYTE, new hiveIo.ByteWritable(value))
97+
PrimitiveCategory.BYTE,
98+
if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]))
9199

92-
def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector =
100+
def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector =
93101
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
94-
PrimitiveCategory.BINARY, new hadoopIo.BytesWritable(value))
102+
PrimitiveCategory.BINARY,
103+
if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]))
95104

96-
def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector =
105+
def getDateWritableConstantObjectInspector(value: Any): ObjectInspector =
97106
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
98-
PrimitiveCategory.DATE, new hiveIo.DateWritable(value))
107+
PrimitiveCategory.DATE,
108+
if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]))
99109

100-
def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector =
110+
def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector =
101111
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
102-
PrimitiveCategory.TIMESTAMP, new hiveIo.TimestampWritable(value))
103-
104-
def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector =
112+
PrimitiveCategory.TIMESTAMP,
113+
if (value == null) {
114+
null
115+
} else {
116+
new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
117+
})
118+
119+
def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector =
105120
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
106121
PrimitiveCategory.DECIMAL,
107-
new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying())))
122+
if (value == null) {
123+
null
124+
} else {
125+
new hiveIo.HiveDecimalWritable(
126+
HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
127+
})
108128

109129
def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
110130
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(

0 commit comments

Comments
 (0)