
Commit f33d550

gvramana authored and marmbrus committed
[SPARK-3891][SQL] Add array support to percentile, percentile_approx and constant inspectors support
Supported passing an array to the percentile and percentile_approx UDAFs.
To support percentile_approx, constant object inspectors are supported for GenericUDAF.
Constant folding support was added to the CreateArray expression.
Constant UDF expressions are no longer re-evaluated per row.

Author: Venkata Ramana G <ramana.gollamudihuawei.com>
Author: Venkata Ramana Gollamudi <[email protected]>

Closes #2802 from gvramana/percentile_array_support and squashes the following commits:

a0182e5 [Venkata Ramana Gollamudi] fixed review comment
a18f917 [Venkata Ramana Gollamudi] avoid constant udf expression re-evaluation - fixes failure due to return iterator and value type mismatch
c46db0f [Venkata Ramana Gollamudi] Removed TestHive reset
4d39105 [Venkata Ramana Gollamudi] Unified inspector creation, style check fixes
f37fd69 [Venkata Ramana Gollamudi] Fixed review comments
47f6365 [Venkata Ramana Gollamudi] fixed test
cb7c61e [Venkata Ramana Gollamudi] Supported ConstantInspector for UDAF; fixed HiveUdaf wrap object issue
7f94aff [Venkata Ramana Gollamudi] Added foldable support to CreateArray
1 parent 8d0d2a6 commit f33d550
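
With this change, both percentile and percentile_approx accept an array of percentile values in a single call, and constant arguments such as that array reach Hive through constant object inspectors. A minimal usage sketch, not part of the commit, assuming a Hive-enabled context such as TestHive and its src test table:

import org.apache.spark.sql.hive.test.TestHive._

// Exact percentiles over an INT column; one pass returns an array of values.
sql("SELECT percentile(key, array(0.25, 0.5, 0.75)) FROM src").collect()

// percentile_approx needs its percentile argument to be a constant, which is what
// the ConstantObjectInspector support for GenericUDAF makes possible.
sql("SELECT percentile_approx(key, array(0.9, 0.99)) FROM src").collect()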

File tree

3 files changed: +40 −12 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 3 additions & 1 deletion
@@ -113,7 +113,9 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio
  */
 case class CreateArray(children: Seq[Expression]) extends Expression {
   override type EvaluatedType = Any
-
+
+  override def foldable = !children.exists(!_.foldable)
+
   lazy val childTypes = children.map(_.dataType).distinct
 
   override lazy val resolved =
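
The added foldable override marks CreateArray as constant-foldable exactly when every child is foldable: !children.exists(!_.foldable) is the double-negated form of children.forall(_.foldable). A minimal self-contained sketch of that predicate, using a stand-in Expr type rather than Catalyst's:

object FoldableSketch extends App {
  // Stand-in for Catalyst expressions: only the foldable flag matters here.
  case class Expr(foldable: Boolean)

  case class CreateArrayLike(children: Seq[Expr]) {
    // Same predicate as the patch: foldable iff no child is non-foldable.
    def foldable: Boolean = !children.exists(!_.foldable)
  }

  // array(1, 1): all children are literals, so the whole array folds to a constant.
  assert(CreateArrayLike(Seq(Expr(true), Expr(true))).foldable)
  // array(key, 1): a non-constant column reference prevents folding.
  assert(!CreateArrayLike(Seq(Expr(false), Expr(true))).foldable)
}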

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 25 additions & 10 deletions
@@ -158,6 +158,11 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
   override def foldable =
     isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
 
+  @transient
+  protected def constantReturnValue = unwrap(
+    returnInspector.asInstanceOf[ConstantObjectInspector].getWritableConstantValue(),
+    returnInspector)
+
   @transient
   protected lazy val deferedObjects =
     argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject]
@@ -166,6 +171,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
 
   override def eval(input: Row): Any = {
     returnInspector // Make sure initialized.
+    if(foldable) return constantReturnValue
+
     var i = 0
     while (i < children.length) {
       val idx = i
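
These two hunks let a deterministic UDF whose return inspector is a ConstantObjectInspector be treated as a constant: the value is unwrapped from the inspector once, and eval short-circuits instead of re-invoking the Hive function for every row. A minimal sketch of that short-circuit, using hypothetical stand-ins for the inspector and the unwrap step rather than the real Hive/Spark types:

object ConstantEvalSketch extends App {
  // Hypothetical stand-ins: a real ConstantObjectInspector exposes a writable constant
  // value, and HiveInspectors.unwrap converts it to a Catalyst value.
  sealed trait Inspector
  case class ConstantInspector(constant: Any) extends Inspector
  def unwrap(value: Any): Any = value

  class UdfLikeEval(inspector: Inspector, deterministic: Boolean, evalRow: Seq[Any] => Any) {
    val foldable: Boolean = deterministic && inspector.isInstanceOf[ConstantInspector]

    // Computed lazily, once; mirrors constantReturnValue in the patch.
    lazy val constantReturnValue: Any =
      unwrap(inspector.asInstanceOf[ConstantInspector].constant)

    def eval(row: Seq[Any]): Any = {
      if (foldable) return constantReturnValue  // skip per-row UDF invocation
      evalRow(row)
    }
  }

  val constantUdf = new UdfLikeEval(ConstantInspector(42), deterministic = true,
    _ => sys.error("should not be invoked"))
  assert(constantUdf.eval(Seq(1, 2)) == 42)
}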
@@ -193,12 +200,13 @@ private[hive] case class HiveGenericUdaf(
 
   @transient
   protected lazy val objectInspector = {
-    resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
+    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
+    resolver.getEvaluator(parameterInfo)
       .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
   }
 
   @transient
-  protected lazy val inspectors = children.map(_.dataType).map(toInspector)
+  protected lazy val inspectors = children.map(toInspector)
 
   def dataType: DataType = inspectorToDataType(objectInspector)
 
@@ -223,12 +231,13 @@ private[hive] case class HiveUdaf(
 
   @transient
   protected lazy val objectInspector = {
-    resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
+    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
+    resolver.getEvaluator(parameterInfo)
       .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
   }
 
   @transient
-  protected lazy val inspectors = children.map(_.dataType).map(toInspector)
+  protected lazy val inspectors = children.map(toInspector)
 
   def dataType: DataType = inspectorToDataType(objectInspector)
 
@@ -261,7 +270,7 @@ private[hive] case class HiveGenericUdtf(
   protected lazy val function: GenericUDTF = funcWrapper.createFunction()
 
   @transient
-  protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
+  protected lazy val inputInspectors = children.map(toInspector)
 
   @transient
   protected lazy val outputInspector = function.initialize(inputInspectors.toArray)
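
Across HiveGenericUdaf, HiveUdaf and HiveGenericUdtf the inspectors are now built from the child expressions themselves (children.map(toInspector)) rather than from their DataTypes, so a foldable argument such as the percentile array can surface to Hive as a ConstantObjectInspector, and the two UDAF wrappers resolve their evaluator through a SimpleGenericUDAFParameterInfo built over those inspectors. A hedged sketch of that resolution step against the Hive API; the helper name is mine, and the two boolean flags simply mirror the (false, false) used in the patch:

import org.apache.hadoop.hive.ql.udf.generic.{
  AbstractGenericUDAFResolver, GenericUDAFEvaluator, SimpleGenericUDAFParameterInfo}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector

// Resolve a GenericUDAF evaluator from per-argument ObjectInspectors, as the patch does
// for HiveGenericUdaf, HiveUdaf and HiveUdafFunction.
def resolveEvaluator(
    resolver: AbstractGenericUDAFResolver,
    argumentInspectors: Array[ObjectInspector]): GenericUDAFEvaluator = {
  val parameterInfo = new SimpleGenericUDAFParameterInfo(argumentInspectors, false, false)
  resolver.getEvaluator(parameterInfo)
}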
@@ -334,10 +343,13 @@ private[hive] case class HiveUdafFunction(
     } else {
       funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }
-
-  private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
-
-  private val function = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray)
+
+  private val inspectors = exprs.map(toInspector).toArray
+
+  private val function = {
+    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
+    resolver.getEvaluator(parameterInfo)
+  }
 
   private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
 
@@ -350,9 +362,12 @@ private[hive] case class HiveUdafFunction(
   @transient
   val inputProjection = new InterpretedProjection(exprs)
 
+  @transient
+  protected lazy val cached = new Array[AnyRef](exprs.length)
+
   def update(input: Row): Unit = {
     val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
-    function.iterate(buffer, inputs)
+    function.iterate(buffer, wrap(inputs, inspectors, cached))
   }
 }
 
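In HiveUdafFunction, each input row is now passed through wrap(...) so the values reach function.iterate in the representation the Hive ObjectInspectors expect, and one preallocated cached array is reused across calls to update() instead of allocating a buffer per row. A minimal self-contained sketch of that buffer-reuse pattern, with a generic converter standing in for the real per-inspector wrapping:

object WrapCacheSketch extends App {
  // convert(value, i) stands in for wrapping `value` with the i-th ObjectInspector.
  final class RowWrapper(convert: (Any, Int) => AnyRef, arity: Int) {
    private val cached = new Array[AnyRef](arity)  // reused across update() calls

    def wrap(inputs: Array[AnyRef]): Array[AnyRef] = {
      var i = 0
      while (i < inputs.length) {
        cached(i) = convert(inputs(i), i)
        i += 1
      }
      cached
    }
  }

  val wrapper = new RowWrapper((v, _) => v.toString, arity = 2)
  // Both calls fill the same cached buffer; no per-row allocation.
  assert(wrapper.wrap(Array[AnyRef](Int.box(1), Int.box(2))).toSeq == Seq("1", "2"))
  assert(wrapper.wrap(Array[AnyRef](Int.box(3), Int.box(4))).toSeq == Seq("3", "4"))
}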

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala

Lines changed: 12 additions & 1 deletion
@@ -92,10 +92,21 @@ class HiveUdfSuite extends QueryTest {
   }
 
   test("SPARK-2693 udaf aggregates test") {
-    checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"),
+    checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"),
       sql("SELECT max(key) FROM src").collect().toSeq)
+
+    checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"),
+      sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq)
   }
 
+  test("Generic UDAF aggregates") {
+    checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"),
+      sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
+
+    checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"),
+      sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq)
+  }
+
   test("UDFIntegerToString") {
     val testData = TestHive.sparkContext.parallelize(
       IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil)
