Skip to content

Commit 7662ec2

Browse files
chenghao-intelmarmbrus
authored andcommitted
[SPARK-5817] [SQL] Fix bug of udtf with column names
It's a bug while do query like: ```sql select d from (select explode(array(1,1)) d from src limit 1) t ``` And it will throws exception like: ``` org.apache.spark.sql.AnalysisException: cannot resolve 'd' given input columns _c0; line 1 pos 7 at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:48) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:45) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:50) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:249) at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$transformExpressionUp$1(QueryPlan.scala:103) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2$$anonfun$apply$2.apply(QueryPlan.scala:117) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47) at scala.collection.TraversableLike$class.map(TraversableLike.scala:244) at scala.collection.AbstractTraversable.map(Traversable.scala:105) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:116) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) ``` To solve the bug, it requires code refactoring for UDTF The major changes are about: * Simplifying the UDTF development, UDTF will manage the output attribute names any more, instead, the `logical.Generate` will handle that properly. * UDTF will be asked for the output schema (data types) during the logical plan analyzing. Author: Cheng Hao <[email protected]> Closes #4602 from chenghao-intel/explode_bug and squashes the following commits: c2a5132 [Cheng Hao] add back resolved for Alias 556e982 [Cheng Hao] revert the unncessary change 002c361 [Cheng Hao] change the rule of resolved for Generate 04ae500 [Cheng Hao] add qualifier only for generator output 5ee5d2c [Cheng Hao] prepend the new qualifier d2e8b43 [Cheng Hao] Update the code as feedback ca5e7f4 [Cheng Hao] shrink the commits
1 parent 2a24bf9 commit 7662ec2

26 files changed

+207
-145
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.util.collection.OpenHashSet
2121
import org.apache.spark.sql.AnalysisException
22-
import org.apache.spark.sql.catalyst.errors.TreeNodeException
2322
import org.apache.spark.sql.catalyst.expressions._
2423
import org.apache.spark.sql.catalyst.plans.logical._
2524
import org.apache.spark.sql.catalyst.rules._
@@ -59,6 +58,7 @@ class Analyzer(
5958
ResolveReferences ::
6059
ResolveGroupingAnalytics ::
6160
ResolveSortReferences ::
61+
ResolveGenerate ::
6262
ImplicitGenerate ::
6363
ResolveFunctions ::
6464
GlobalAggregates ::
@@ -474,8 +474,59 @@ class Analyzer(
474474
*/
475475
object ImplicitGenerate extends Rule[LogicalPlan] {
476476
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
477-
case Project(Seq(Alias(g: Generator, _)), child) =>
478-
Generate(g, join = false, outer = false, None, child)
477+
case Project(Seq(Alias(g: Generator, name)), child) =>
478+
Generate(g, join = false, outer = false,
479+
qualifier = None, UnresolvedAttribute(name) :: Nil, child)
480+
case Project(Seq(MultiAlias(g: Generator, names)), child) =>
481+
Generate(g, join = false, outer = false,
482+
qualifier = None, names.map(UnresolvedAttribute(_)), child)
483+
}
484+
}
485+
486+
/**
487+
* Resolve the Generate, if the output names specified, we will take them, otherwise
488+
* we will try to provide the default names, which follow the same rule with Hive.
489+
*/
490+
object ResolveGenerate extends Rule[LogicalPlan] {
491+
// Construct the output attributes for the generator,
492+
// The output attribute names can be either specified or
493+
// auto generated.
494+
private def makeGeneratorOutput(
495+
generator: Generator,
496+
generatorOutput: Seq[Attribute]): Seq[Attribute] = {
497+
val elementTypes = generator.elementTypes
498+
499+
if (generatorOutput.length == elementTypes.length) {
500+
generatorOutput.zip(elementTypes).map {
501+
case (a, (t, nullable)) if !a.resolved =>
502+
AttributeReference(a.name, t, nullable)()
503+
case (a, _) => a
504+
}
505+
} else if (generatorOutput.length == 0) {
506+
elementTypes.zipWithIndex.map {
507+
// keep the default column names as Hive does _c0, _c1, _cN
508+
case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
509+
}
510+
} else {
511+
throw new AnalysisException(
512+
s"""
513+
|The number of aliases supplied in the AS clause does not match
514+
|the number of columns output by the UDTF expected
515+
|${elementTypes.size} aliases but got ${generatorOutput.size}
516+
""".stripMargin)
517+
}
518+
}
519+
520+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
521+
case p: Generate if !p.child.resolved || !p.generator.resolved => p
522+
case p: Generate if p.resolved == false =>
523+
// if the generator output names are not specified, we will use the default ones.
524+
Generate(
525+
p.generator,
526+
join = p.join,
527+
outer = p.outer,
528+
p.qualifier,
529+
makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
479530
}
480531
}
481532
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ trait CheckAnalysis {
3838
throw new AnalysisException(msg)
3939
}
4040

41+
def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
42+
exprs.flatMap(_.collect {
43+
case e: Generator => true
44+
}).length >= 1
45+
}
46+
4147
def checkAnalysis(plan: LogicalPlan): Unit = {
4248
// We transform up and order the rules so as to catch the first possible failure instead
4349
// of the result of cascading resolution failures.
@@ -110,6 +116,12 @@ trait CheckAnalysis {
110116
failAnalysis(
111117
s"unresolved operator ${operator.simpleString}")
112118

119+
case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
120+
failAnalysis(
121+
s"""Only a single table generating function is allowed in a SELECT clause, found:
122+
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
123+
124+
113125
case _ => // Analysis successful!
114126
}
115127
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,13 @@ package object dsl {
284284
seed: Int = (math.random * 1000).toInt): LogicalPlan =
285285
Sample(fraction, withReplacement, seed, logicalPlan)
286286

287+
// TODO specify the output column names
287288
def generate(
288289
generator: Generator,
289290
join: Boolean = false,
290291
outer: Boolean = false,
291292
alias: Option[String] = None): LogicalPlan =
292-
Generate(generator, join, outer, None, logicalPlan)
293+
Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)
293294

294295
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
295296
InsertIntoTable(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 10 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,47 +42,30 @@ abstract class Generator extends Expression {
4242

4343
override type EvaluatedType = TraversableOnce[Row]
4444

45-
override lazy val dataType =
46-
ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
45+
// TODO ideally we should return the type of ArrayType(StructType),
46+
// however, we don't keep the output field names in the Generator.
47+
override def dataType: DataType = throw new UnsupportedOperationException
4748

4849
override def nullable: Boolean = false
4950

5051
/**
51-
* Should be overridden by specific generators. Called only once for each instance to ensure
52-
* that rule application does not change the output schema of a generator.
52+
* The output element data types in structure of Seq[(DataType, Nullable)]
53+
* TODO we probably need to add more information like metadata etc.
5354
*/
54-
protected def makeOutput(): Seq[Attribute]
55-
56-
private var _output: Seq[Attribute] = null
57-
58-
def output: Seq[Attribute] = {
59-
if (_output == null) {
60-
_output = makeOutput()
61-
}
62-
_output
63-
}
55+
def elementTypes: Seq[(DataType, Boolean)]
6456

6557
/** Should be implemented by child classes to perform specific Generators. */
6658
override def eval(input: Row): TraversableOnce[Row]
67-
68-
/** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
69-
override def makeCopy(newArgs: Array[AnyRef]): this.type = {
70-
val copy = super.makeCopy(newArgs)
71-
copy._output = _output
72-
copy
73-
}
7459
}
7560

7661
/**
7762
* A generator that produces its output using the provided lambda function.
7863
*/
7964
case class UserDefinedGenerator(
80-
schema: Seq[Attribute],
65+
elementTypes: Seq[(DataType, Boolean)],
8166
function: Row => TraversableOnce[Row],
8267
children: Seq[Expression])
83-
extends Generator{
84-
85-
override protected def makeOutput(): Seq[Attribute] = schema
68+
extends Generator {
8669

8770
override def eval(input: Row): TraversableOnce[Row] = {
8871
// TODO(davies): improve this
@@ -98,30 +81,18 @@ case class UserDefinedGenerator(
9881
/**
9982
* Given an input array produces a sequence of rows for each value in the array.
10083
*/
101-
case class Explode(attributeNames: Seq[String], child: Expression)
84+
case class Explode(child: Expression)
10285
extends Generator with trees.UnaryNode[Expression] {
10386

10487
override lazy val resolved =
10588
child.resolved &&
10689
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
10790

108-
private lazy val elementTypes = child.dataType match {
91+
override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
10992
case ArrayType(et, containsNull) => (et, containsNull) :: Nil
11093
case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil
11194
}
11295

113-
// TODO: Move this pattern into Generator.
114-
protected def makeOutput() =
115-
if (attributeNames.size == elementTypes.size) {
116-
attributeNames.zip(elementTypes).map {
117-
case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
118-
}
119-
} else {
120-
elementTypes.zipWithIndex.map {
121-
case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)()
122-
}
123-
}
124-
12596
override def eval(input: Row): TraversableOnce[Row] = {
12697
child.dataType match {
12798
case ArrayType(_, _) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)(
112112
extends NamedExpression with trees.UnaryNode[Expression] {
113113

114114
override type EvaluatedType = Any
115+
// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
116+
override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]
115117

116118
override def eval(input: Row): Any = child.eval(input)
117119

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -482,16 +482,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
482482
object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
483483

484484
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
485-
case filter @ Filter(condition,
486-
generate @ Generate(generator, join, outer, alias, grandChild)) =>
485+
case filter @ Filter(condition, g: Generate) =>
487486
// Predicates that reference attributes produced by the `Generate` operator cannot
488487
// be pushed below the operator.
489488
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
490-
conjunct => conjunct.references subsetOf grandChild.outputSet
489+
conjunct => conjunct.references subsetOf g.child.outputSet
491490
}
492491
if (pushDown.nonEmpty) {
493492
val pushDownPredicate = pushDown.reduce(And)
494-
val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
493+
val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
494+
g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
495495
stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
496496
} else {
497497
filter

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,34 +40,43 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
4040
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
4141
* programming with one important additional feature, which allows the input rows to be joined with
4242
* their output.
43+
* @param generator the generator expression
4344
* @param join when true, each output row is implicitly joined with the input tuple that produced
4445
* it.
4546
* @param outer when true, each input row will be output at least once, even if the output of the
4647
* given `generator` is empty. `outer` has no effect when `join` is false.
47-
* @param alias when set, this string is applied to the schema of the output of the transformation
48-
* as a qualifier.
48+
* @param qualifier Qualifier for the attributes of generator(UDTF)
49+
* @param generatorOutput The output schema of the Generator.
50+
* @param child Children logical plan node
4951
*/
5052
case class Generate(
5153
generator: Generator,
5254
join: Boolean,
5355
outer: Boolean,
54-
alias: Option[String],
56+
qualifier: Option[String],
57+
generatorOutput: Seq[Attribute],
5558
child: LogicalPlan)
5659
extends UnaryNode {
5760

58-
protected def generatorOutput: Seq[Attribute] = {
59-
val output = alias
60-
.map(a => generator.output.map(_.withQualifiers(a :: Nil)))
61-
.getOrElse(generator.output)
62-
if (join && outer) {
63-
output.map(_.withNullability(true))
64-
} else {
65-
output
66-
}
61+
override lazy val resolved: Boolean = {
62+
generator.resolved &&
63+
childrenResolved &&
64+
generator.elementTypes.length == generatorOutput.length &&
65+
!generatorOutput.exists(!_.resolved)
6766
}
6867

69-
override def output: Seq[Attribute] =
70-
if (join) child.output ++ generatorOutput else generatorOutput
68+
// we don't want the gOutput to be taken as part of the expressions
69+
// as that will cause exceptions like unresolved attributes etc.
70+
override def expressions: Seq[Expression] = generator :: Nil
71+
72+
def output: Seq[Attribute] = {
73+
val qualified = qualifier.map(q =>
74+
// prepend the new qualifier to the existed one
75+
generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers))
76+
).getOrElse(generatorOutput)
77+
78+
if (join) child.output ++ qualified else qualified
79+
}
7180
}
7281

7382
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
9090

9191
assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)
9292

93-
val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)())
93+
val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
9494
assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)
9595

9696
assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest {
454454
test("generate: predicate referenced no generated column") {
455455
val originalQuery = {
456456
testRelationWithArrayType
457-
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
457+
.generate(Explode('c_arr), true, false, Some("arr"))
458458
.where(('b >= 5) && ('a > 6))
459459
}
460460
val optimized = Optimize(originalQuery.analyze)
461461
val correctAnswer = {
462462
testRelationWithArrayType
463463
.where(('b >= 5) && ('a > 6))
464-
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
464+
.generate(Explode('c_arr), true, false, Some("arr")).analyze
465465
}
466466

467467
comparePlans(optimized, correctAnswer)
468468
}
469469

470470
test("generate: part of conjuncts referenced generated column") {
471-
val generator = Explode(Seq("c"), 'c_arr)
471+
val generator = Explode('c_arr)
472472
val originalQuery = {
473473
testRelationWithArrayType
474474
.generate(generator, true, false, Some("arr"))
@@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest {
499499
test("generate: all conjuncts referenced generated column") {
500500
val originalQuery = {
501501
testRelationWithArrayType
502-
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
502+
.generate(Explode('c_arr), true, false, Some("arr"))
503503
.where(('c > 6) || ('b > 5)).analyze
504504
}
505505
val optimized = Optimize(originalQuery)

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
3434
import org.apache.spark.rdd.RDD
3535
import org.apache.spark.storage.StorageLevel
3636
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
37-
import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
37+
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
3838
import org.apache.spark.sql.catalyst.expressions._
3939
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
4040
import org.apache.spark.sql.catalyst.plans.logical._
@@ -711,12 +711,16 @@ class DataFrame private[sql](
711711
*/
712712
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
713713
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
714-
val attributes = schema.toAttributes
714+
715+
val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) }
716+
val names = schema.toAttributes.map(_.name)
717+
715718
val rowFunction =
716719
f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
717-
val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
720+
val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
718721

719-
Generate(generator, join = true, outer = false, None, logicalPlan)
722+
Generate(generator, join = true, outer = false,
723+
qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
720724
}
721725

722726
/**
@@ -733,12 +737,17 @@ class DataFrame private[sql](
733737
: DataFrame = {
734738
val dataType = ScalaReflection.schemaFor[B].dataType
735739
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
740+
// TODO handle the metadata?
741+
val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) }
742+
val names = attributes.map(_.name)
743+
736744
def rowFunction(row: Row): TraversableOnce[Row] = {
737745
f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
738746
}
739-
val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
747+
val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
740748

741-
Generate(generator, join = true, outer = false, None, logicalPlan)
749+
Generate(generator, join = true, outer = false,
750+
qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
742751
}
743752

744753
/////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)