Skip to content

Commit fc26b49

Browse files
committed
Add AttributeSet class, remove references from Expression.
1 parent ded6796 commit fc26b49

File tree

27 files changed

+119
-77
lines changed

27 files changed

+119
-77
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
132132
case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
133133
val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
134134
val resolved = unresolved.flatMap(child.resolveChildren)
135-
val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet
135+
val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })
136136

137137
val missingInProject = requiredAttributes -- p.output
138138
if (missingInProject.nonEmpty) {
@@ -152,8 +152,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
152152
)
153153

154154
logDebug(s"Grouping expressions: $groupingRelation")
155-
val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
156-
val missingInAggs = resolved -- a.outputSet
155+
val resolved = unresolved.flatMap(groupingRelation.resolve)
156+
val missingInAggs = resolved.filterNot(a.outputSet.contains)
157157
logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
158158
if (missingInAggs.nonEmpty) {
159159
// Add missing grouping exprs and then project them away after the sort.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ case class UnresolvedRelation(
3939
alias: Option[String] = None) extends LeafNode {
4040
override def output = Nil
4141
override lazy val resolved = false
42+
def reference = Set.empty
4243
}
4344

4445
/**
@@ -66,7 +67,6 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E
6667
override def dataType = throw new UnresolvedException(this, "dataType")
6768
override def foldable = throw new UnresolvedException(this, "foldable")
6869
override def nullable = throw new UnresolvedException(this, "nullable")
69-
override def references = children.flatMap(_.references).toSet
7070
override lazy val resolved = false
7171

7272
// Unresolved functions are transient at compile time and don't get evaluated during execution.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
/**
 * Wraps an [[Attribute]] so that equality and hashing are decided only by the
 * attribute's `exprId`, ignoring every other field of the attribute.
 *
 * @param a the wrapped attribute
 */
class AttributeEquals(val a: Attribute) {

  /** Hash of the wrapped attribute's `exprId`, kept consistent with [[equals]]. */
  override def hashCode(): Int = a.exprId.hashCode()

  /**
   * Returns true only when `other` is another [[AttributeEquals]] whose wrapped
   * attribute carries the same `exprId`; anything else is unequal.
   */
  override def equals(other: Any): Boolean = other match {
    case that: AttributeEquals => a.exprId == that.a.exprId
    case _ => false
  }
}
27+
28+
/** Factory methods for [[AttributeSet]]. */
object AttributeSet {

  /**
   * Builds an [[AttributeSet]] from a sequence of attributes. Each attribute is
   * wrapped in an [[AttributeEquals]] before being collected into a set, so
   * entries that share an `exprId` collapse into one.
   *
   * @param baseSet the attributes to include
   */
  def apply(baseSet: Seq[Attribute]): AttributeSet = {
    val wrapped = baseSet.map(attribute => new AttributeEquals(attribute))
    new AttributeSet(wrapped.toSet)
  }

  // NOTE(review): disabled overload taking a Set, kept as it was in the original.
  //  def apply(baseSet: Set[Attribute]) = {
  //    new AttributeSet(baseSet.map(new AttributeEquals(_)))
  //  }
}
37+
38+
/**
 * An immutable collection of [[Attribute]]s whose membership is decided by
 * `exprId` (through [[AttributeEquals]]) instead of the attribute's own `equals`.
 *
 * Every traversal (`iterator`, `foreach`, `toSeq`) walks `baseSet` directly, so
 * all of them observe the same element order for a given instance.
 *
 * @param baseSet the wrapped attributes backing this set
 */
class AttributeSet(val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] {

  /**
   * Two AttributeSets are equal when their backing sets are equal, i.e. when they
   * contain attributes with the same `exprId`s. Without this override the class
   * would fall back to reference identity, unlike the `Set[Attribute]` it replaces.
   */
  override def equals(other: Any): Boolean = other match {
    case that: AttributeSet => baseSet == that.baseSet
    case _ => false
  }

  /** Hash of the backing set, kept consistent with [[equals]]. */
  override def hashCode(): Int = baseSet.hashCode()

  /** Returns true if an attribute with the same `exprId` as `elem.toAttribute` is present. */
  def contains(elem: NamedExpression): Boolean =
    baseSet.contains(new AttributeEquals(elem.toAttribute))

  /** Returns a new set that also contains `elem`. */
  def +(elem: Attribute): AttributeSet =
    new AttributeSet(baseSet + new AttributeEquals(elem))

  /** Returns a new set without the attribute matching `elem`. */
  def -(elem: Attribute): AttributeSet =
    new AttributeSet(baseSet - new AttributeEquals(elem))

  /** Iterates over the attributes without building an intermediate collection. */
  def iterator: Iterator[Attribute] = baseSet.iterator.map(_.a)

  /** Returns true if every attribute of this set is also in `other`. */
  def subsetOf(other: AttributeSet): Boolean = baseSet.subsetOf(other.baseSet)

  /** Returns a new set with the attributes of every expression in `other` removed. */
  def --(other: Traversable[NamedExpression]): AttributeSet =
    new AttributeSet(baseSet -- other.map(expr => new AttributeEquals(expr.toAttribute)))

  /** Returns the union of this set and `other`. */
  def ++(other: AttributeSet): AttributeSet = new AttributeSet(baseSet ++ other.baseSet)

  /** Returns a new set holding only the attributes for which `f` is true. */
  override def filter(f: Attribute => Boolean): AttributeSet =
    new AttributeSet(baseSet.filter(wrapped => f(wrapped.a)))

  /** Returns the attributes present in both this set and `other`. */
  def intersect(other: AttributeSet): AttributeSet =
    new AttributeSet(baseSet.intersect(other.baseSet))

  /** Answered from the backing set directly instead of through a traversal. */
  override def nonEmpty: Boolean = baseSet.nonEmpty

  /** Returns the attributes as a sequence, in the backing set's iteration order. */
  override def toSeq: Seq[Attribute] = baseSet.toSeq.map(_.a)

  /** Applies `f` to every attribute; unwraps in place rather than mapping to a temporary set. */
  override def foreach[U](f: (Attribute) => U): Unit = baseSet.foreach(wrapped => f(wrapped.a))
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
3232

3333
type EvaluatedType = Any
3434

35-
override def references = Set.empty
36-
3735
override def toString = s"input[$ordinal]"
3836

3937
override def eval(input: Row): Any = input(ordinal)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] {
4141
*/
4242
def foldable: Boolean = false
4343
def nullable: Boolean
44-
def references: Set[Attribute]
44+
def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
4545

4646
/** Returns the result of evaluating this expression on a given input Row */
4747
def eval(input: Row = null): EvaluatedType
@@ -230,8 +230,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
230230

231231
override def foldable = left.foldable && right.foldable
232232

233-
override def references = left.references ++ right.references
234-
235233
override def toString = s"($left $symbol $right)"
236234
}
237235

@@ -242,5 +240,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
242240
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
243241
self: Product =>
244242

245-
override def references = child.references
243+
246244
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.types.DoubleType
2424
case object Rand extends LeafExpression {
2525
override def dataType = DoubleType
2626
override def nullable = false
27-
override def references = Set.empty
2827

2928
private[this] lazy val rand = new Random
3029

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
2424

2525
type EvaluatedType = Any
2626

27-
def references = children.flatMap(_.references).toSet
2827
def nullable = true
2928

3029
/** This method has been generated by this script

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ case object Descending extends SortDirection
3131
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
3232
with trees.UnaryNode[Expression] {
3333

34-
override def references = child.references
3534
override def dataType = child.dataType
3635
override def nullable = child.nullable
3736

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression {
3535
type EvaluatedType = DynamicRow
3636

3737
def nullable = false
38-
def references = children.toSet
38+
3939
def dataType = DynamicType
4040

4141
override def eval(input: Row): DynamicRow = input match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ abstract class AggregateFunction
7878

7979
/** Base should return the generic aggregate expression that this function is computing */
8080
val base: AggregateExpression
81-
override def references = base.references
81+
8282
override def nullable = base.nullable
8383
override def dataType = base.dataType
8484

@@ -89,7 +89,7 @@ abstract class AggregateFunction
8989
}
9090

9191
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
92-
override def references = child.references
92+
9393
override def nullable = true
9494
override def dataType = child.dataType
9595
override def toString = s"MIN($child)"
@@ -119,7 +119,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
119119
}
120120

121121
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
122-
override def references = child.references
122+
123123
override def nullable = true
124124
override def dataType = child.dataType
125125
override def toString = s"MAX($child)"
@@ -149,7 +149,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
149149
}
150150

151151
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
152-
override def references = child.references
152+
153153
override def nullable = false
154154
override def dataType = LongType
155155
override def toString = s"COUNT($child)"
@@ -166,7 +166,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
166166
def this() = this(null)
167167

168168
override def children = expressions
169-
override def references = expressions.flatMap(_.references).toSet
169+
170170
override def nullable = false
171171
override def dataType = LongType
172172
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
@@ -184,7 +184,6 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
184184
def this() = this(null)
185185

186186
override def children = expressions
187-
override def references = expressions.flatMap(_.references).toSet
188187
override def nullable = false
189188
override def dataType = ArrayType(expressions.head.dataType)
190189
override def toString = s"AddToHashSet(${expressions.mkString(",")})"
@@ -219,7 +218,6 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression
219218
def this() = this(null)
220219

221220
override def children = inputSet :: Nil
222-
override def references = inputSet.references
223221
override def nullable = false
224222
override def dataType = LongType
225223
override def toString = s"CombineAndCount($inputSet)"
@@ -248,7 +246,7 @@ case class CombineSetsAndCountFunction(
248246

249247
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
250248
extends AggregateExpression with trees.UnaryNode[Expression] {
251-
override def references = child.references
249+
252250
override def nullable = false
253251
override def dataType = child.dataType
254252
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -257,7 +255,7 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
257255

258256
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
259257
extends AggregateExpression with trees.UnaryNode[Expression] {
260-
override def references = child.references
258+
261259
override def nullable = false
262260
override def dataType = LongType
263261
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -266,7 +264,7 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
266264

267265
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
268266
extends PartialAggregate with trees.UnaryNode[Expression] {
269-
override def references = child.references
267+
270268
override def nullable = false
271269
override def dataType = LongType
272270
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -284,7 +282,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
284282
}
285283

286284
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
287-
override def references = child.references
285+
288286
override def nullable = false
289287
override def dataType = DoubleType
290288
override def toString = s"AVG($child)"
@@ -304,7 +302,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
304302
}
305303

306304
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
307-
override def references = child.references
305+
308306
override def nullable = false
309307
override def dataType = child.dataType
310308
override def toString = s"SUM($child)"
@@ -322,7 +320,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
322320
case class SumDistinct(child: Expression)
323321
extends AggregateExpression with trees.UnaryNode[Expression] {
324322

325-
override def references = child.references
323+
326324
override def nullable = false
327325
override def dataType = child.dataType
328326
override def toString = s"SUM(DISTINCT $child)"
@@ -331,7 +329,6 @@ case class SumDistinct(child: Expression)
331329
}
332330

333331
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
334-
override def references = child.references
335332
override def nullable = true
336333
override def dataType = child.dataType
337334
override def toString = s"FIRST($child)"

0 commit comments

Comments
 (0)