Skip to content

Commit bf1878d

Browse files
committed
Adds analysis checks that reject non-aggregated attributes missing from GROUP BY in aggregation queries
1 parent 7a3f589 commit bf1878d

File tree

2 files changed

+57
-5
lines changed

2 files changed

+57
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
6363
typeCoercionRules ++
6464
extendedRules : _*),
6565
Batch("Check Analysis", Once,
66-
CheckResolution),
66+
CheckResolution,
67+
CheckAggregation),
6768
Batch("AnalysisOperators", fixedPoint,
6869
EliminateAnalysisOperators)
6970
)
@@ -88,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
8889
}
8990
}
9091

92+
/**
 * Checks that every output expression of an [[Aggregate]] operator is valid under grouping.
 *
 * An expression is accepted when it is an aggregate function, is one of the grouping
 * expressions, references no attributes at all, or is composed solely of such expressions.
 * A bare attribute is accepted only when it is itself one of the grouping expressions.
 *
 * Throws a [[TreeNodeException]] whose message starts with "Expression not in GROUP BY" for
 * the first aggregate expression that violates this; otherwise the plan is returned unchanged.
 */
object CheckAggregation extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transform {
      case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
        def isValidAggregateExpression(expr: Expression): Boolean = expr match {
          case _: AggregateExpression => true
          // An attribute must decide here. If a non-grouped attribute fell through to the
          // final case, `children.forall` on its empty child list would be vacuously true
          // and the attribute would be wrongly accepted.
          case e: Attribute => groupingExprs.contains(e)
          // A whole (non-attribute) expression that appears verbatim in GROUP BY.
          case e if groupingExprs.contains(e) => true
          // Expressions that reference no attributes do not depend on the grouping.
          case e if e.references.isEmpty => true
          // Anything else is valid only if every sub-expression is valid.
          case e => e.children.forall(isValidAggregateExpression)
        }

        aggregateExprs.foreach { e =>
          if (!isValidAggregateExpression(e)) {
            throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
          }
        }

        // This rule only validates; the operator itself is left untouched.
        aggregatePlan
    }
  }
}
117+
91118
/**
92119
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
93120
*/
@@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
204231
*/
205232
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
206233
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
207-
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
234+
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
208235
if aggregate.resolved && containsAggregate(havingCondition) => {
209236
val evaluatedCondition = Alias(havingCondition, "havingCondition")()
210237
val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
211-
238+
212239
Project(aggregate.output,
213240
Filter(evaluatedCondition.toAttribute,
214241
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
215242
}
216-
217243
}
218-
244+
219245
protected def containsAggregate(condition: Expression): Boolean =
220246
condition
221247
.collect { case ae: AggregateExpression => ae }

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql
1919

2020
import org.apache.spark.sql.catalyst.errors.TreeNodeException
2121
import org.apache.spark.sql.catalyst.expressions._
22+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2223
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
2324
import org.apache.spark.sql.test._
2425
import org.scalatest.BeforeAndAfterAll
@@ -694,4 +695,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
694695
checkAnswer(
695696
sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
696697
}
698+
699+
test("throw errors for non-aggregate attributes with aggregation") {
  // Runs `query` through analysis. When `isInvalidQuery` is true, analysis must fail with a
  // TreeNodeException whose message starts with "Expression not in GROUP BY"; otherwise
  // analysis must complete without throwing.
  def checkAggregation(query: String, isInvalidQuery: Boolean = true): Unit = {
    val parsedPlan = sql(query).queryExecution.logical

    if (!isInvalidQuery) {
      // Should not throw
      sql(query).queryExecution.analyzed
    } else {
      val error = intercept[TreeNodeException[LogicalPlan]] {
        sql(query).queryExecution.analyzed
      }
      assert(
        error.getMessage.startsWith("Expression not in GROUP BY"),
        "Non-aggregate attribute(s) not detected\n" + parsedPlan)
    }
  }

  // Each query is paired with whether the analyzer is expected to reject it.
  val cases = Seq(
    // No GROUP BY at all.
    "SELECT key, COUNT(*) FROM testData" -> true,
    "SELECT COUNT(key), COUNT(*) FROM testData" -> false,
    // GROUP BY on a plain attribute.
    "SELECT value, COUNT(*) FROM testData GROUP BY key" -> true,
    "SELECT COUNT(value), SUM(key) FROM testData GROUP BY key" -> false,
    // GROUP BY on a compound expression.
    "SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1" -> true,
    "SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1" -> false)

  cases.foreach { case (query, shouldBeRejected) =>
    checkAggregation(query, shouldBeRejected)
  }
}
697723
}

0 commit comments

Comments
 (0)