17
17
18
18
package org .apache .spark .sql .catalyst .expressions
19
19
20
- class AttributeEquals (val a : Attribute ) {
20
+ protected class AttributeEquals (val a : Attribute ) {
21
21
override def hashCode () = a.exprId.hashCode()
22
22
override def equals (other : Any ) = other match {
23
23
case otherReference : AttributeEquals => a.exprId == otherReference.a.exprId
@@ -26,47 +26,71 @@ class AttributeEquals(val a: Attribute) {
26
26
}
27
27
28
28
object AttributeSet {
29
+ /** Constructs a new [[AttributeSet ]] given a sequence of [[Attribute Attributes ]]. */
29
30
def apply (baseSet : Seq [Attribute ]) = {
30
31
new AttributeSet (baseSet.map(new AttributeEquals (_)).toSet)
31
32
}
32
-
33
- // def apply(baseSet: Set[Attribute]) = {
34
- // new AttributeSet(baseSet.map(new AttributeEquals(_)))
35
- // }
36
33
}
37
34
38
- class AttributeSet (val baseSet : Set [AttributeEquals ]) extends Traversable [Attribute ] {
35
+ /**
36
+ * A Set designed to hold [[AttributeReference ]] objects, that performs equality checking using
37
+ * expression id instead of standard java equality. Using expression id means that these
38
+ * sets will correctly test for membership, even when the AttributeReferences in question differ
39
+ * cosmetically (e.g., the names have different capitalizations).
40
+ */
41
+ class AttributeSet protected (val baseSet : Set [AttributeEquals ]) extends Traversable [Attribute ] {
39
42
43
+ /** Returns true if the members of this AttributeSet and other are the same. */
40
44
override def equals (other : Any ) = other match {
41
45
case otherSet : AttributeSet => baseSet.map(_.a).forall(otherSet.contains)
42
46
case _ => false
43
47
}
44
48
49
+ /** Returns true if this set contains an Attribute with the same expression id as `elem` */
45
50
def contains (elem : NamedExpression ): Boolean =
46
51
baseSet.contains(new AttributeEquals (elem.toAttribute))
47
52
53
+ /** Returns a new [[AttributeSet ]] that contains `elem` in addition to the current elements. */
48
54
def + (elem : Attribute ): AttributeSet =
49
55
new AttributeSet (baseSet + new AttributeEquals (elem))
50
56
57
+ /** Returns a new [[AttributeSet ]] that does not contain `elem`. */
51
58
def - (elem : Attribute ): AttributeSet =
52
59
new AttributeSet (baseSet - new AttributeEquals (elem))
53
60
61
+ /** Returns an iterator containing all of the attributes in the set. */
54
62
def iterator : Iterator [Attribute ] = baseSet.map(_.a).iterator
55
63
64
+ /**
65
+ * Returns true if the [[Attribute Attributes ]] in this set are a subset of the Attributes in
66
+ * `other`.
67
+ */
56
68
def subsetOf (other : AttributeSet ) = baseSet.subsetOf(other.baseSet)
57
69
70
+ /**
71
+ * Returns a new [[AttributeSet ]] that does not contain any of the [[Attribute Attributes ]] found
72
+ * in `other`.
73
+ */
58
74
def -- (other : Traversable [NamedExpression ]) =
59
75
new AttributeSet (baseSet -- other.map(a => new AttributeEquals (a.toAttribute)))
60
76
77
+ /**
78
+ * Returns a new [[AttributeSet ]] that contains all of the [[Attribute Attributes ]] found
79
+ * in `other`.
80
+ */
61
81
def ++ (other : AttributeSet ) = new AttributeSet (baseSet ++ other.baseSet)
62
82
83
+ /**
84
+ * Returns a new [[AttributeSet ]] contain only the [[Attribute Attributes ]] where `f` evaluates to
85
+ * true.
86
+ */
63
87
override def filter (f : Attribute => Boolean ) = new AttributeSet (baseSet.filter(ae => f(ae.a)))
64
88
89
+ /**
90
+ * Returns a new [[AttributeSet ]] that only contains [[Attribute Attributes ]] that are found in
91
+ * `this` and `other`.
92
+ */
65
93
def intersect (other : AttributeSet ) = new AttributeSet (baseSet.intersect(other.baseSet))
66
94
67
- override def nonEmpty = baseSet.nonEmpty
68
-
69
- override def toSeq = baseSet.toSeq.map(_.a)
70
-
71
95
override def foreach [U ](f : (Attribute ) => U ): Unit = baseSet.map(_.a).foreach(f)
72
- }
96
+ }
0 commit comments