
Commit 0da254f

chenghao-intel authored and liancheng committed
[SPARK-6734] [SQL] Add UDTF.close support in Generate
Some third-party UDTF extensions generate additional rows in the GenericUDTF.close() method, which is supported and documented by Hive: https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF

However, Spark SQL ignores GenericUDTF.close(), which causes bugs when porting jobs from Hive to Spark SQL.

Author: Cheng Hao <[email protected]>

Closes apache#5383 from chenghao-intel/udtf_close and squashes the following commits:

98b4e4b [Cheng Hao] Support UDTF.close
1 parent aa6ba3f commit 0da254f
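
For context, the Hive developer guide linked above illustrates this pattern with GenericUDTFCount2: a UDTF whose process() only counts input rows and whose close() forwards the final count twice. Below is a Scala rendering of that idea; it is an illustrative sketch, not the exact source packaged in this commit's TestUDTF.jar.

import java.util.{ArrayList => JArrayList}

import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}

// Scala port of the Hive wiki's GenericUDTFCount2 example: no rows are
// forwarded from process(); the total is forwarded twice from close().
class GenericUDTFCount2 extends GenericUDTF {
  private var count = 0
  private val forwardObj = new Array[AnyRef](1)

  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    val fieldNames = new JArrayList[String]()
    val fieldOIs = new JArrayList[ObjectInspector]()
    fieldNames.add("col1")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  // Called once per input row: accumulate only, emit nothing yet.
  override def process(args: Array[AnyRef]): Unit = {
    count += 1
  }

  // Called after the last row. Rows forwarded here were silently dropped by
  // Spark SQL before this commit.
  override def close(): Unit = {
    forwardObj(0) = Int.box(count)
    forward(forwardObj)
    forward(forwardObj)
  }
}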

File tree

7 files changed: +74 -13 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 6 additions & 0 deletions
@@ -56,6 +56,12 @@ abstract class Generator extends Expression {
 
   /** Should be implemented by child classes to perform specific Generators. */
   override def eval(input: Row): TraversableOnce[Row]
+
+  /**
+   * Notifies that there are no more rows to process. Clean-up code can run
+   * here, and additional rows can be emitted.
+   */
+  def terminate(): TraversableOnce[Row] = Nil
 }
 
 /**
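
To see what the new hook enables, here is a hypothetical catalyst Generator (not part of this commit) that accumulates state in eval() and emits a trailing row from terminate(). It assumes the Spark-1.x API shapes visible in the diff above (Row-based eval, elementTypes: Seq[(DataType, Boolean)]); the class name and per-partition mutable state are illustrative only.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
import org.apache.spark.sql.types.{DataType, LongType}

// Hypothetical: counts the rows it sees, emits the total at end of partition.
case class CountAll() extends Generator {
  private var count = 0L

  override def children: Seq[Expression] = Nil
  override def elementTypes: Seq[(DataType, Boolean)] = Seq((LongType, false))

  // Per-row callback: accumulate, emit nothing.
  override def eval(input: Row): TraversableOnce[Row] = { count += 1; Nil }

  // End-of-partition callback introduced by this commit.
  override def terminate(): TraversableOnce[Row] = Row(count) :: Nil
}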

sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala

Lines changed: 28 additions & 10 deletions
@@ -21,6 +21,18 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
 
+/**
+ * For lazy computation: ensures generator.terminate() is called only at the
+ * very end, after all input rows have been consumed.
+ * TODO: reuse CompletionIterator?
+ */
+private[execution] sealed case class LazyIterator(func: () => TraversableOnce[Row])
+  extends Iterator[Row] {
+
+  lazy val results = func().toIterator
+  override def hasNext: Boolean = results.hasNext
+  override def next(): Row = results.next()
+}
+
 /**
  * :: DeveloperApi ::
  * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the
@@ -47,27 +59,33 @@ case class Generate(
   val boundGenerator = BindReferences.bindReference(generator, child.output)
 
   protected override def doExecute(): RDD[Row] = {
+    // boundGenerator.terminate() should be triggered after all of the rows in the partition
     if (join) {
       child.execute().mapPartitions { iter =>
-        val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null))
-        // Used to produce rows with no matches when outer = true.
-        val outerProjection =
-          newProjection(child.output ++ nullValues, child.output)
-
-        val joinProjection = newProjection(output, output)
+        val generatorNullRow = Row.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null))
         val joinedRow = new JoinedRow
 
-        iter.flatMap {row =>
+        iter.flatMap { row =>
+          // always set the left side (the child's output) first
+          joinedRow.withLeft(row)
           val outputRows = boundGenerator.eval(row)
           if (outer && outputRows.isEmpty) {
-            outerProjection(row) :: Nil
+            joinedRow.withRight(generatorNullRow) :: Nil
           } else {
-            outputRows.map(or => joinProjection(joinedRow(row, or)))
+            outputRows.map(or => joinedRow.withRight(or))
           }
+        } ++ LazyIterator(() => boundGenerator.terminate()).map { row =>
+          // leave the left side as the last row of the child output,
+          // matching Hive's behavior
+          joinedRow.withRight(row)
         }
       }
     } else {
-      child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row)))
+      child.execute().mapPartitions { iter =>
+        iter.flatMap(row => boundGenerator.eval(row)) ++
+        LazyIterator(() => boundGenerator.terminate())
+      }
     }
   }
 }
+
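
The ordering guarantee comes from hiding the terminate() call behind a lazy val: Iterator.map and Iterator.++ are both lazy, so the suffix built from LazyIterator is not evaluated until the row iterator in front of it is exhausted. Writing boundGenerator.terminate().map(...) directly would instead invoke terminate() while the partition iterator is still being assembled, before any row is processed. A self-contained sketch of the mechanism (plain Scala, not Spark code):

// Minimal stand-in for the LazyIterator above, over plain Ints.
case class LazyIter[A](func: () => TraversableOnce[A]) extends Iterator[A] {
  lazy val results = func().toIterator  // func() runs on first hasNext/next
  override def hasNext: Boolean = results.hasNext
  override def next(): A = results.next()
}

object LazyIterDemo extends App {
  var terminated = false
  def terminate(): Seq[Int] = { terminated = true; Seq(-1) }

  // Neither .map on LazyIter nor ++ forces func() here.
  val out = Iterator(1, 2, 3) ++ LazyIter(() => terminate()).map(_ * 10)
  println(terminated)  // false: terminate() has not run yet
  println(out.toList)  // List(1, 2, 3, -10): suffix evaluated after 1, 2, 3
  println(terminated)  // true
}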

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 15 additions & 3 deletions
@@ -483,7 +483,11 @@ private[hive] case class HiveGenericUdtf(
   extends Generator with HiveInspectors {
 
   @transient
-  protected lazy val function: GenericUDTF = funcWrapper.createFunction()
+  protected lazy val function: GenericUDTF = {
+    val fun: GenericUDTF = funcWrapper.createFunction()
+    fun.setCollector(collector)
+    fun
+  }
 
   @transient
   protected lazy val inputInspectors = children.map(toInspector)
@@ -494,6 +498,9 @@ private[hive] case class HiveGenericUdtf(
   @transient
   protected lazy val udtInput = new Array[AnyRef](children.length)
 
+  @transient
+  protected lazy val collector = new UDTFCollector
+
   lazy val elementTypes = outputInspector.getAllStructFieldRefs.map {
     field => (inspectorToDataType(field.getFieldObjectInspector), true)
   }
@@ -502,8 +509,7 @@ private[hive] case class HiveGenericUdtf(
     outputInspector // Make sure initialized.
 
     val inputProjection = new InterpretedProjection(children)
-    val collector = new UDTFCollector
-    function.setCollector(collector)
+
     function.process(wrap(inputProjection(input), inputInspectors, udtInput))
     collector.collectRows()
   }
@@ -525,6 +531,12 @@ private[hive] case class HiveGenericUdtf(
     }
   }
 
+  override def terminate(): TraversableOnce[Row] = {
+    outputInspector // Make sure initialized.
+    function.close()
+    collector.collectRows()
+  }
+
   override def toString: String = {
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }
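
The eval()/terminate() flow relies on UDTFCollector, which is defined elsewhere in hiveUdfs.scala and unchanged by this commit: the UDTF pushes rows through forward(), Hive routes them to the registered Collector, and Spark drains the buffer after each process() call and, now, once more after close(). A rough sketch of that buffering shape, assuming only Hive's Collector interface:

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.hive.ql.udf.generic.Collector

// Illustrative buffering collector. The real UDTFCollector additionally
// converts each forwarded Hive object into a catalyst Row via the output
// inspector.
class BufferingCollector extends Collector {
  private val buffer = new ArrayBuffer[AnyRef]

  override def collect(input: AnyRef): Unit = buffer += input

  // Returns everything forwarded since the last drain and resets the buffer,
  // mirroring what collectRows() does after process() and close().
  def drain(): Seq[AnyRef] = {
    val rows = buffer.toList
    buffer.clear()
    rows
  }
}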
Binary file (1.3 KB) not shown. Presumably the TestUDTF.jar referenced by HiveQuerySuite below.
(New test output file; file name not shown in this view.)

@@ -0,0 +1,2 @@
+97 500
+97 500

(New test output file; file name not shown in this view.)

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+3
+3

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 21 additions & 0 deletions
@@ -20,6 +20,9 @@ package org.apache.spark.sql.hive.execution
 import java.io.File
 import java.util.{Locale, TimeZone}
 
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, StructObjectInspector, ObjectInspector}
 import org.scalatest.BeforeAndAfter
 
 import scala.util.Try
@@ -51,14 +54,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
     // Add Locale setting
     Locale.setDefault(Locale.US)
+    sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
+    // The function source code can be found at:
+    // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
+    sql(
+      """
+        |CREATE TEMPORARY FUNCTION udtf_count2
+        |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+      """.stripMargin)
   }
 
   override def afterAll() {
     TestHive.cacheTables = false
     TimeZone.setDefault(originalTimeZone)
     Locale.setDefault(originalLocale)
+    sql("DROP TEMPORARY FUNCTION udtf_count2")
   }
 
+  createQueryTest("Test UDTF.close in Lateral Views",
+    """
+      |SELECT key, cc
+      |FROM src LATERAL VIEW udtf_count2(value) dd AS cc
+    """.stripMargin, false) // false means: keep the temp function in the registry
+
+  createQueryTest("Test UDTF.close in SELECT",
+    "SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) table", false)
+
   test("SPARK-4908: concurrent hive native commands") {
     (1 to 100).par.map { _ =>
       sql("USE default")
