[SQL] Improve error messages #4558

Closed · wants to merge 9 commits · changes from all commits
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType}
+import org.apache.spark.sql.types._

 /**
  * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -66,32 +66,82 @@ class Analyzer(catalog: Catalog,
       typeCoercionRules ++
       extendedRules : _*),
     Batch("Check Analysis", Once,
-      CheckResolution ::
-      CheckAggregation ::
-      Nil: _*),
-    Batch("AnalysisOperators", fixedPoint,
-      EliminateAnalysisOperators)
+      CheckResolution),
+    Batch("Remove SubQueries", fixedPoint,
+      EliminateSubQueries)
   )

   /**
    * Makes sure all attributes and logical plans have been resolved.
    */
   object CheckResolution extends Rule[LogicalPlan] {
+    def failAnalysis(msg: String) = { throw new AnalysisException(msg) }
+
     def apply(plan: LogicalPlan): LogicalPlan = {
-      plan.transformUp {
-        case p if p.expressions.exists(!_.resolved) =>
-          val missing = p.expressions.filterNot(_.resolved).map(_.prettyString).mkString(",")
-          val from = p.inputSet.map(_.name).mkString("{", ", ", "}")
-
-          throw new AnalysisException(s"Cannot resolve '$missing' given input columns $from")
-        case p if !p.resolved && p.childrenResolved =>
-          throw new AnalysisException(s"Unresolved operator in the query plan ${p.simpleString}")
-      } match {
-        // As a backstop, use the root node to check that the entire plan tree is resolved.
-        case p if !p.resolved =>
-          throw new AnalysisException(s"Unresolved operator in the query plan ${p.simpleString}")
-        case p => p
-      }
+      // We transform up and order the rules so as to catch the first possible failure instead
+      // of the result of cascading resolution failures.
+      plan.foreachUp {
+        case operator: LogicalPlan =>
+          operator transformExpressionsUp {
+            case a: Attribute if !a.resolved =>
+              val from = operator.inputSet.map(_.name).mkString(", ")
+              failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
+
+            case c: Cast if !c.resolved =>
+              failAnalysis(
+                s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
+
+            case b: BinaryExpression if !b.resolved =>
+              failAnalysis(
+                s"invalid expression ${b.prettyString} " +
+                  s"between ${b.left.simpleString} and ${b.right.simpleString}")
+          }
+
+          operator match {
+            case f: Filter if f.condition.dataType != BooleanType =>
+              failAnalysis(
+                s"filter expression '${f.condition.prettyString}' " +
+                  s"of type ${f.condition.dataType.simpleString} is not a boolean.")
+
+            case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
+              def checkValidAggregateExpression(expr: Expression): Unit = expr match {
+                case _: AggregateExpression => // OK
+                case e: Attribute if !groupingExprs.contains(e) =>
+                  failAnalysis(
+                    s"expression '${e.prettyString}' is neither present in the group by, " +
+                      s"nor is it an aggregate function. " +
+                      "Add to group by or wrap in first() if you don't care which value you get.")
+                case e if groupingExprs.contains(e) => // OK
+                case e if e.references.isEmpty => // OK
+                case e => e.children.foreach(checkValidAggregateExpression)
+              }
+
+              val cleaned = aggregateExprs.map(_.transform {
+                // Should trim aliases around `GetField`s. These aliases are introduced while
+                // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
+                // (Should we just turn `GetField` into a `NamedExpression`?)
+                case Alias(g, _) => g
+              })
+
+              cleaned.foreach(checkValidAggregateExpression)
+
+            case o if o.children.nonEmpty &&
+              !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
+              val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
+              val input = o.inputSet.map(_.prettyString).mkString(",")
+
+              failAnalysis(s"resolved attributes $missingAttributes missing from $input")
+
+            // Catch all
+            case o if !o.resolved =>
+              failAnalysis(
+                s"unresolved operator ${operator.simpleString}")
+
+            case _ => // Analysis successful!
+          }
+      }
+
+      plan
     }
   }
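The practical payoff of the rewrite is easiest to see on a concrete plan. Below is a minimal sketch of the new failure mode, assuming the `testRelation2` fixture and the `caseSensitiveAnalyze` helper from the AnalysisSuite changes later in this diff (nothing here is new API; the message text is taken from the rule above):

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

// Equivalent to: SELECT b FROM t GROUP BY a  -- 'b is neither grouped nor aggregated.
val badAgg = testRelation2.groupBy('a)('b)

try caseSensitiveAnalyze(badAgg) catch {
  case e: AnalysisException =>
    // Previously this surfaced as a generic TreeNodeException; now the message
    // names the offending expression, roughly:
    //   expression 'b' is neither present in the group by, nor is it an
    //   aggregate function. Add to group by or wrap in first() ...
    println(e.getMessage)
}
```

Because the checks run via `foreachUp`, the deepest failing node is reported first, so a single root cause produces one targeted message instead of a cascade of unresolved parents.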

@@ -192,45 +242,14 @@ class Analyzer(catalog: Catalog,
     }
   }

-  /**
-   * Checks for non-aggregated attributes with aggregation
-   */
-  object CheckAggregation extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = {
-      plan.transform {
-        case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
-          def isValidAggregateExpression(expr: Expression): Boolean = expr match {
-            case _: AggregateExpression => true
-            case e: Attribute => groupingExprs.contains(e)
-            case e if groupingExprs.contains(e) => true
-            case e if e.references.isEmpty => true
-            case e => e.children.forall(isValidAggregateExpression)
-          }
-
-          aggregateExprs.find { e =>
-            !isValidAggregateExpression(e.transform {
-              // Should trim aliases around `GetField`s. These aliases are introduced while
-              // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
-              // (Should we just turn `GetField` into a `NamedExpression`?)
-              case Alias(g: GetField, _) => g
-            })
-          }.foreach { e =>
-            throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
-          }
-
-          aggregatePlan
-      }
-    }
-  }
-
   /**
    * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
    */
   object ResolveRelations extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) =>
         i.copy(
-          table = EliminateAnalysisOperators(catalog.lookupRelation(tableIdentifier, alias)))
+          table = EliminateSubQueries(catalog.lookupRelation(tableIdentifier, alias)))
       case UnresolvedRelation(tableIdentifier, alias) =>
         catalog.lookupRelation(tableIdentifier, alias)
     }
@@ -477,7 +496,7 @@ class Analyzer(catalog: Catalog,
  * only required to provide scoping information for attributes and can be removed once analysis is
  * complete.
  */
-object EliminateAnalysisOperators extends Rule[LogicalPlan] {
+object EliminateSubQueries extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case Subquery(_, child) => child
   }
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -20,7 +20,11 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.analysis.Star

 protected class AttributeEquals(val a: Attribute) {
-  override def hashCode() = a.exprId.hashCode()
+  override def hashCode() = a match {
+    case ar: AttributeReference => ar.exprId.hashCode()
+    case a => a.hashCode()
+  }
+
   override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
     case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
     case (a1, a2) => a1 == a2
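The branch keeps `hashCode` consistent with the `equals` just below, which falls back to plain equality whenever either side is not an `AttributeReference` (for example the `PrettyAttribute` touched later in this diff, whose `exprId` is left as `???`). A minimal, self-contained sketch of the contract being preserved (illustrative names, not Spark code):

```scala
// Keys compare by id when both carry one, structurally otherwise; hashCode
// must branch the same way, or hash-based sets will misplace entries.
case class Id(n: Int)

class Key(val id: Option[Id], val name: String) {
  override def hashCode(): Int =
    id.map(_.hashCode()).getOrElse(name.hashCode())

  override def equals(other: Any): Boolean = other match {
    case k: Key => (id, k.id) match {
      case (Some(a), Some(b)) => a == b          // compare by id when both have one
      case (None, None)       => name == k.name  // structural fallback
      case _                  => false           // mixed kinds are never equal
    }
    case _ => false
  }
}
```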
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -218,7 +218,7 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E
   override def exprId: ExprId = ???
   override def eval(input: Row): EvaluatedType = ???
   override def nullable: Boolean = ???
-  override def dataType: DataType = ???
+  override def dataType: DataType = NullType
 }

 object VirtualColumn {
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.plans.logical

-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Attribute, Expression}

 /**
  * Transforms the input by forking and running the specified script.
@@ -32,7 +32,9 @@ case class ScriptTransformation(
     script: String,
     output: Seq[Attribute],
     child: LogicalPlan,
-    ioschema: ScriptInputOutputSchema) extends UnaryNode
+    ioschema: ScriptInputOutputSchema) extends UnaryNode {
+  override def references: AttributeSet = AttributeSet(input.flatMap(_.references))
+}

 /**
  * A placeholder for implementation specific input and output properties when passing data
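This override feeds the new missing-attribute check in `CheckResolution` above, which compares an operator's `references` against its `inputSet`; the reading here is that `ScriptTransformation` consumes exactly what its `input` expressions reference. A small, self-contained model of that relationship (illustrative, not Spark code):

```scala
// An operator passes the check only when every attribute it consumes is
// produced by its children.
case class Attr(name: String)

trait Op {
  def children: Seq[Op]
  def output: Set[Attr]      // attributes this operator produces
  def references: Set[Attr]  // attributes this operator consumes
  def inputSet: Set[Attr] = children.flatMap(_.output).toSet

  def checkInputs(): Unit = {
    val missing = references -- inputSet
    require(children.isEmpty || missing.isEmpty,
      s"resolved attributes $missing missing from $inputSet")
  }
}

case class Scan(output: Set[Attr]) extends Op {
  def children = Nil
  def references = Set.empty[Attr]
}

case class Transform(references: Set[Attr], child: Op) extends Op {
  def children = Seq(child)
  def output = references
}

Transform(Set(Attr("a")), Scan(Set(Attr("a")))).checkInputs()    // passes
// Transform(Set(Attr("b")), Scan(Set(Attr("a")))).checkInputs() // would fail
```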
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -46,6 +46,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
     children.foreach(_.foreach(f))
   }

+  /**
+   * Runs the given function recursively on [[children]] then on this node.
+   * @param f the function to be applied to each node in the tree.
+   */
+  def foreachUp(f: BaseType => Unit): Unit = {
+    children.foreach(_.foreachUp(f))
+    f(this)
+  }
+
   /**
    * Returns a Seq containing the result of applying the given function to each
    * node in this tree in a preorder traversal.
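`foreachUp` is the post-order counterpart of the `foreach` defined just above: children are visited before their parent. That ordering is what lets `CheckResolution` surface the deepest failure rather than a symptom at the root. A tiny, self-contained illustration (not Spark code):

```scala
// Pre-order (foreach) visits a node before its subtree; post-order (foreachUp)
// visits the subtree first, so the deepest nodes are seen before their ancestors.
case class Tree(label: String, children: Seq[Tree] = Nil) {
  def foreach(f: Tree => Unit): Unit = {
    f(this)
    children.foreach(_.foreach(f))
  }

  def foreachUp(f: Tree => Unit): Unit = {
    children.foreach(_.foreachUp(f))
    f(this)
  }
}

val t = Tree("root", Seq(Tree("left"), Tree("right")))
t.foreach(n => print(n.label + " "))   // prints: root left right
println()
t.foreachUp(n => print(n.label + " ")) // prints: left right root
```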
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.scalatest.{BeforeAndAfter, FunSuite}

 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Literal, Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._

@@ -108,24 +108,56 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
       testRelation)
   }

-  test("throw errors for unresolved attributes during analysis") {
-    val e = intercept[AnalysisException] {
-      caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
-    }
-    assert(e.getMessage().toLowerCase.contains("cannot resolve"))
-  }
+  def errorTest(
+      name: String,
+      plan: LogicalPlan,
+      errorMessages: Seq[String],
+      caseSensitive: Boolean = true) = {
+    test(name) {
+      val error = intercept[AnalysisException] {
+        if (caseSensitive) {
+          caseSensitiveAnalyze(plan)
+        } else {
+          caseInsensitiveAnalyze(plan)
+        }
+      }
+
+      errorMessages.foreach(m => assert(error.getMessage contains m))
+    }
+  }

-  test("throw errors for unresolved plans during analysis") {
-    case class UnresolvedTestPlan() extends LeafNode {
-      override lazy val resolved = false
-      override def output = Nil
-    }
-    val e = intercept[AnalysisException] {
-      caseSensitiveAnalyze(UnresolvedTestPlan())
-    }
-    assert(e.getMessage().toLowerCase.contains("unresolved"))
-  }
+  errorTest(
+    "unresolved attributes",
+    testRelation.select('abcd),
+    "cannot resolve" :: "abcd" :: Nil)
+
+  errorTest(
+    "bad casts",
+    testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
+    "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
+
+  errorTest(
+    "non-boolean filters",
+    testRelation.where(Literal(1)),
+    "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil)
+
+  errorTest(
+    "missing group by",
+    testRelation2.groupBy('a)('b),
+    "'b'" :: "group by" :: Nil)
+
+  case class UnresolvedTestPlan() extends LeafNode {
+    override lazy val resolved = false
+    override def output = Nil
+  }
+
+  errorTest(
+    "catch all unresolved plan",
+    UnresolvedTestPlan(),
+    "unresolved" :: Nil)

   test("divide should be casted into fractional types") {
     val testRelation2 = LocalRelation(
       AttributeReference("a", StringType)(),
@@ -134,18 +166,15 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
       AttributeReference("d", DecimalType.Unlimited)(),
       AttributeReference("e", ShortType)())

-    val expr0 = 'a / 2
-    val expr1 = 'a / 'b
-    val expr2 = 'a / 'c
-    val expr3 = 'a / 'd
-    val expr4 = 'e / 'e
-    val plan = caseInsensitiveAnalyze(Project(
-      Alias(expr0, s"Analyzer($expr0)")() ::
-      Alias(expr1, s"Analyzer($expr1)")() ::
-      Alias(expr2, s"Analyzer($expr2)")() ::
-      Alias(expr3, s"Analyzer($expr3)")() ::
-      Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2))
+    val plan = caseInsensitiveAnalyze(
+      testRelation2.select(
+        'a / Literal(2) as 'div1,
+        'a / 'b as 'div2,
+        'a / 'c as 'div3,
+        'a / 'd as 'div4,
+        'e / 'e as 'div5))
     val pl = plan.asInstanceOf[Project].projectList

     assert(pl(0).dataType == DoubleType)
     assert(pl(1).dataType == DoubleType)
     assert(pl(2).dataType == DoubleType)
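The rewritten tests build plans with the catalyst DSL (`'abcd`, `testRelation.select(...)`) instead of hand-assembling `Project` and `Alias` nodes. The symbol syntax rests on an implicit conversion; below is a simplified model of the mechanism (the real implicits live in `org.apache.spark.sql.catalyst.dsl`, and this sketch is not that code):

```scala
import scala.language.implicitConversions

object DslSketch {
  // A Scala symbol stands in for a column that has not been resolved yet.
  case class UnresolvedAttribute(name: String)

  implicit def symbolToAttribute(s: Symbol): UnresolvedAttribute =
    UnresolvedAttribute(s.name)

  // With the conversion in scope, 'abcd becomes UnresolvedAttribute("abcd"),
  // which is how testRelation.select('abcd) can build its Project node.
  val col: UnresolvedAttribute = 'abcd
}
```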
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.optimizer

-import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -30,7 +30,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("AnalysisNodes", Once,
-        EliminateAnalysisOperators) ::
+        EliminateSubQueries) ::
       Batch("Constant Folding", FixedPoint(50),
         NullPropagation,
         ConstantFolding,
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.optimizer

-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateAnalysisOperators}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateSubQueries}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -33,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("AnalysisNodes", Once,
-        EliminateAnalysisOperators) ::
+        EliminateSubQueries) ::
       Batch("ConstantFolding", Once,
         ConstantFolding,
         BooleanSimplification) :: Nil