Skip to content

Commit 84d79ee

Browse files
chenghao-intelmarmbrus
authored andcommitted
[SPARK-4244] [SQL] Support Hive Generic UDFs with constant object inspector parameters
Query `SELECT named_struct(lower("AA"), "12", lower("Bb"), "13") FROM src LIMIT 1` will throw exception, some of the Hive Generic UDF/UDAF requires the input object inspector is `ConstantObjectInspector`, however, we won't get that before the expression optimization executed. (Constant Folding). This PR is a work around to fix this. (As ideally, the `output` of LogicalPlan should be identical before and after Optimization). Author: Cheng Hao <[email protected]> Closes apache#3109 from chenghao-intel/optimized and squashes the following commits: 487ff79 [Cheng Hao] rebase to the latest master & update the unittest
1 parent d39f2e9 commit 84d79ee

File tree

4 files changed

+17
-8
lines changed

4 files changed

+17
-8
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,8 @@ private[hive] trait HiveInspectors {
326326
})
327327
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map)
328328
}
329+
case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
330+
case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType))
329331
case _ => toInspector(expr.dataType)
330332
}
331333

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
2121

2222
import scala.collection.mutable.ArrayBuffer
2323

24-
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
24+
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector}
2525
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
2626
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
2727
import org.apache.hadoop.hive.ql.exec.{UDF, UDAF}
@@ -108,9 +108,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
108108
udfType != null && udfType.deterministic()
109109
}
110110

111-
override def foldable = {
112-
isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable)
113-
}
111+
override def foldable = isUDFDeterministic && children.forall(_.foldable)
114112

115113
// Create parameter converters
116114
@transient
@@ -154,17 +152,17 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
154152
protected lazy val argumentInspectors = children.map(toInspector)
155153

156154
@transient
157-
protected lazy val returnInspector = function.initialize(argumentInspectors.toArray)
155+
protected lazy val returnInspector =
156+
function.initializeAndFoldConstants(argumentInspectors.toArray)
158157

159158
@transient
160159
protected lazy val isUDFDeterministic = {
161160
val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
162161
(udfType != null && udfType.deterministic())
163162
}
164163

165-
override def foldable = {
166-
isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable)
167-
}
164+
override def foldable =
165+
isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
168166

169167
@transient
170168
protected lazy val deferedObjects =
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"aa":"10","aaaaaa":"11","aaaaaa":"12","bb12":"13","s14s14":"14"}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
5656
Locale.setDefault(originalLocale)
5757
}
5858

59+
createQueryTest("constant object inspector for generic udf",
60+
"""SELECT named_struct(
61+
lower("AA"), "10",
62+
repeat(lower("AA"), 3), "11",
63+
lower(repeat("AA", 3)), "12",
64+
printf("Bb%d", 12), "13",
65+
repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""")
66+
5967
createQueryTest("NaN to Decimal",
6068
"SELECT CAST(CAST('NaN' AS DOUBLE) AS DECIMAL(1,1)) FROM src LIMIT 1")
6169

0 commit comments

Comments
 (0)