Commit c9d85f5

generalize UnresolvedGetField to support all map, struct, and array
1 parent cd1d411 commit c9d85f5

13 files changed: +229, -166 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 2 deletions
@@ -375,9 +375,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
   protected lazy val primary: PackratParser[Expression] =
     ( literal
     | expression ~ ("[" ~> expression <~ "]") ^^
-      { case base ~ ordinal => GetItem(base, ordinal) }
+      { case base ~ ordinal => UnresolvedGetField(base, ordinal) }
     | (expression <~ ".") ~ ident ^^
-      { case base ~ fieldName => UnresolvedGetField(base, fieldName) }
+      { case base ~ fieldName => UnresolvedGetField(base, Literal(fieldName)) }
     | cast
     | "(" ~> expression <~ ")"
     | function
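
Note: after this change both access syntaxes build the same unresolved node; the dotted field name is wrapped in a Literal so the analyzer can treat it as an ordinary expression. A rough standalone sketch of the resulting tree shapes (simplified stand-ins, not the Catalyst classes):

    // Simplified AST, only to illustrate what the two parser rules now produce.
    object ParserSketch extends App {
      sealed trait Expr
      case class Attr(name: String) extends Expr                         // stands in for an unresolved attribute
      case class Lit(value: Any) extends Expr                            // stands in for Literal
      case class UnresolvedGetField(child: Expr, fieldExpr: Expr) extends Expr

      val byOrdinal = UnresolvedGetField(Attr("a"), Lit(0))     // what "a[0]" now parses to
      val byName    = UnresolvedGetField(Attr("a"), Lit("b"))   // what "a.b" now parses to

      println(byOrdinal)
      println(byName)
    }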

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
@@ -311,8 +311,8 @@ class Analyzer(
         withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
         logDebug(s"Resolving $u to $result")
         result
-      case UnresolvedGetField(child, fieldName) if child.resolved =>
-        GetField(child, fieldName, resolver)
+      case UnresolvedGetField(child, fieldExpr) if child.resolved =>
+        GetField(child, fieldExpr, resolver)
     }
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 2 additions & 2 deletions
@@ -184,7 +184,7 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
   override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
 }
 
-case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
+case class UnresolvedGetField(child: Expression, fieldExpr: Expression) extends UnaryExpression {
   override def dataType: DataType = throw new UnresolvedException(this, "dataType")
   override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -193,5 +193,5 @@ case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
   override def eval(input: Row = null): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 
-  override def toString: String = s"$child.$fieldName"
+  override def toString: String = s"$child.getField($fieldExpr)"
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 2 deletions
@@ -100,8 +100,8 @@ package object dsl {
     def isNull: Predicate = IsNull(expr)
     def isNotNull: Predicate = IsNotNull(expr)
 
-    def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal)
-    def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName)
+    def getItem(ordinal: Expression): UnresolvedGetField = UnresolvedGetField(expr, ordinal)
+    def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, Literal(fieldName))
 
     def cast(to: DataType): Expression = Cast(expr, to)
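
With both DSL helpers returning UnresolvedGetField, resolution is deferred to the analyzer instead of happening at construction time. A rough usage sketch, assuming Spark's catalyst module on the classpath and its usual DSL conversions (the imports are an assumption, not part of this diff):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.Literal

    // Both calls now build an UnresolvedGetField; the analyzer later picks the
    // concrete extractor based on the child's data type.
    val byOrdinal = 'a.getItem(Literal(0))   // previously built a GetItem directly
    val byName    = 'a.getField("b")         // field name wrapped as Literal("b")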

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.Map
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.types._
+
+object GetField {
+  /**
+   * Returns the resolved `GetField`, choosing the concrete `GetField` implementation
+   * based on the types of `child` and `fieldExpr`.
+   */
+  def apply(
+      child: Expression,
+      fieldExpr: Expression,
+      resolver: Resolver): GetField = {
+
+    (child.dataType, fieldExpr) match {
+      case (StructType(fields), Literal(fieldName, StringType)) =>
+        val ordinal = findField(fields, fieldName.toString, resolver)
+        SimpleStructGetField(child, fields(ordinal), ordinal)
+      case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
+        val ordinal = findField(fields, fieldName.toString, resolver)
+        ArrayStructGetField(child, fields(ordinal), ordinal, containsNull)
+      case (_: ArrayType, _) if fieldExpr.dataType.isInstanceOf[IntegralType] =>
+        ArrayOrdinalGetField(child, fieldExpr)
+      case (_: MapType, _) =>
+        MapOrdinalGetField(child, fieldExpr)
+      case (otherType, _) =>
+        throw new AnalysisException(
+          s"GetField is not valid on child of type $otherType with fieldExpr of type ${fieldExpr.dataType}")
+    }
+  }
+
+  def unapply(g: GetField): Option[(Expression, Expression)] = {
+    g match {
+      case _: StructGetField => Some((g.child, null))
+      case o: OrdinalGetField => Some((o.child, o.ordinal))
+      case _ => None
+    }
+  }
+
+  /**
+   * Finds the ordinal of the StructField, reporting an error if no matching field
+   * is found or if more than one field matches.
+   */
+  private def findField(fields: Array[StructField], fieldName: String, resolver: Resolver): Int = {
+    val checkField = (f: StructField) => resolver(f.name, fieldName)
+    val ordinal = fields.indexWhere(checkField)
+    if (ordinal == -1) {
+      throw new AnalysisException(
+        s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+    } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+      throw new AnalysisException(
+        s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+    } else {
+      ordinal
+    }
+  }
+}
+
+trait GetField extends UnaryExpression {
+  self: Product =>
+
+  type EvaluatedType = Any
+}
+
+abstract class StructGetField extends GetField {
+  self: Product =>
+
+  def field: StructField
+
+  override def foldable: Boolean = child.foldable
+  override def toString: String = s"$child.${field.name}"
+}
+
+abstract class OrdinalGetField extends GetField {
+  self: Product =>
+
+  def ordinal: Expression
+
+  /** `Null` is returned for invalid ordinals. */
+  override def nullable: Boolean = true
+  override def foldable: Boolean = child.foldable && ordinal.foldable
+  override def toString: String = s"$child[$ordinal]"
+  override def children: Seq[Expression] = child :: ordinal :: Nil
+
+  override def eval(input: Row): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      null
+    } else {
+      val o = ordinal.eval(input)
+      if (o == null) {
+        null
+      } else {
+        evalNotNull(value, o)
+      }
+    }
+  }
+
+  protected def evalNotNull(value: Any, ordinal: Any): Any
+}
+
+/**
+ * Returns the value of the field in the Struct `child`.
+ */
+case class SimpleStructGetField(child: Expression, field: StructField, ordinal: Int)
+  extends StructGetField {
+
+  override def dataType: DataType = field.dataType
+  override def nullable: Boolean = child.nullable || field.nullable
+
+  override def eval(input: Row): Any = {
+    val baseValue = child.eval(input).asInstanceOf[Row]
+    if (baseValue == null) null else baseValue(ordinal)
+  }
+}
+
+/**
+ * Returns an array containing the value of the field in each Struct of the Array `child`.
+ */
+case class ArrayStructGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
+  extends StructGetField {
+
+  override def dataType: DataType = ArrayType(field.dataType, containsNull)
+  override def nullable: Boolean = child.nullable
+
+  override def eval(input: Row): Any = {
+    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
+    if (baseValue == null) null else {
+      baseValue.map { row =>
+        if (row == null) null else row(ordinal)
+      }
+    }
+  }
+}
+
+/**
+ * Returns the element at `ordinal` in the Array `child`.
+ */
+case class ArrayOrdinalGetField(child: Expression, ordinal: Expression)
+  extends OrdinalGetField {
+
+  override def dataType = child.dataType.asInstanceOf[ArrayType].elementType
+
+  override lazy val resolved = childrenResolved &&
+    child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
+
+  protected def evalNotNull(value: Any, ordinal: Any) = {
+    // TODO: consider using Array[_] for ArrayType child to avoid
+    // boxing of primitives
+    val baseValue = value.asInstanceOf[Seq[_]]
+    val index = ordinal.asInstanceOf[Int]
+    if (index >= baseValue.size || index < 0) {
+      null
+    } else {
+      baseValue(index)
+    }
+  }
+}
+
+/**
+ * Returns the value of the key `ordinal` in the Map `child`.
+ */
+case class MapOrdinalGetField(child: Expression, ordinal: Expression)
+  extends OrdinalGetField {
+
+  override def dataType = child.dataType.asInstanceOf[MapType].valueType
+
+  override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
+
+  protected def evalNotNull(value: Any, ordinal: Any) = {
+    val baseValue = value.asInstanceOf[Map[Any, _]]
+    baseValue.get(ordinal).orNull
+  }
+}
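
The factory above picks a concrete extractor from the pair (child data type, field expression): struct plus string literal, array-of-struct plus string literal, array plus integral ordinal, or map plus any key; anything else is an AnalysisException. A self-contained sketch of that dispatch with simplified stand-in types (not the Catalyst classes):

    object GetFieldDispatchSketch extends App {
      sealed trait DataType
      case object StringType extends DataType
      case object IntegerType extends DataType
      case class StructType(fieldNames: Seq[String]) extends DataType
      case class ArrayType(elementType: DataType) extends DataType
      case class MapType(keyType: DataType, valueType: DataType) extends DataType

      // Which concrete extractor the factory would build for a given combination.
      def dispatch(childType: DataType, fieldType: DataType, fieldIsStringLiteral: Boolean): String =
        (childType, fieldType) match {
          case (_: StructType, StringType) if fieldIsStringLiteral            => "SimpleStructGetField"
          case (ArrayType(_: StructType), StringType) if fieldIsStringLiteral => "ArrayStructGetField"
          case (_: ArrayType, IntegerType)                                    => "ArrayOrdinalGetField"
          case (_: MapType, _)                                                => "MapOrdinalGetField"
          case (other, _) => sys.error(s"GetField is not valid on child of type $other")
        }

      println(dispatch(StructType(Seq("b")), StringType, fieldIsStringLiteral = true))              // SimpleStructGetField
      println(dispatch(ArrayType(StructType(Seq("b"))), StringType, fieldIsStringLiteral = true))   // ArrayStructGetField
      println(dispatch(ArrayType(IntegerType), IntegerType, fieldIsStringLiteral = false))          // ArrayOrdinalGetField
      println(dispatch(MapType(StringType, IntegerType), StringType, fieldIsStringLiteral = false)) // MapOrdinalGetField
    }

The ordinal-based extractors also override `resolved`, so an array access stays unresolved until its ordinal is known to be integral, and a map access until the child is known to be a MapType.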

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 0 additions & 131 deletions
@@ -17,139 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.collection.Map
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.types._
 
-/**
- * Returns the item at `ordinal` in the Array `child` or the Key `ordinal` in Map `child`.
- */
-case class GetItem(child: Expression, ordinal: Expression) extends Expression {
-  type EvaluatedType = Any
-
-  val children: Seq[Expression] = child :: ordinal :: Nil
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = true
-  override def foldable: Boolean = child.foldable && ordinal.foldable
-
-  override def dataType: DataType = child.dataType match {
-    case ArrayType(dt, _) => dt
-    case MapType(_, vt, _) => vt
-  }
-  override lazy val resolved =
-    childrenResolved &&
-    (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
-
-  override def toString: String = s"$child[$ordinal]"
-
-  override def eval(input: Row): Any = {
-    val value = child.eval(input)
-    if (value == null) {
-      null
-    } else {
-      val key = ordinal.eval(input)
-      if (key == null) {
-        null
-      } else {
-        if (child.dataType.isInstanceOf[ArrayType]) {
-          // TODO: consider using Array[_] for ArrayType child to avoid
-          // boxing of primitives
-          val baseValue = value.asInstanceOf[Seq[_]]
-          val o = key.asInstanceOf[Int]
-          if (o >= baseValue.size || o < 0) {
-            null
-          } else {
-            baseValue(o)
-          }
-        } else {
-          val baseValue = value.asInstanceOf[Map[Any, _]]
-          baseValue.get(key).orNull
-        }
-      }
-    }
-  }
-}
-
-
-trait GetField extends UnaryExpression {
-  self: Product =>
-
-  type EvaluatedType = Any
-  override def foldable: Boolean = child.foldable
-  override def toString: String = s"$child.${field.name}"
-
-  def field: StructField
-}
-
-object GetField {
-  /**
-   * Returns the resolved `GetField`, and report error if no desired field or over one
-   * desired fields are found.
-   */
-  def apply(
-      expr: Expression,
-      fieldName: String,
-      resolver: Resolver): GetField = {
-    def findField(fields: Array[StructField]): Int = {
-      val checkField = (f: StructField) => resolver(f.name, fieldName)
-      val ordinal = fields.indexWhere(checkField)
-      if (ordinal == -1) {
-        throw new AnalysisException(
-          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
-      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
-        throw new AnalysisException(
-          s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
-      } else {
-        ordinal
-      }
-    }
-    expr.dataType match {
-      case StructType(fields) =>
-        val ordinal = findField(fields)
-        StructGetField(expr, fields(ordinal), ordinal)
-      case ArrayType(StructType(fields), containsNull) =>
-        val ordinal = findField(fields)
-        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
-      case otherType =>
-        throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
-    }
-  }
-}
-
-/**
- * Returns the value of fields in the Struct `child`.
- */
-case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
-
-  override def dataType: DataType = field.dataType
-  override def nullable: Boolean = child.nullable || field.nullable
-
-  override def eval(input: Row): Any = {
-    val baseValue = child.eval(input).asInstanceOf[Row]
-    if (baseValue == null) null else baseValue(ordinal)
-  }
-}
-
-/**
- * Returns the array of value of fields in the Array of Struct `child`.
- */
-case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
-  extends GetField {
-
-  override def dataType: DataType = ArrayType(field.dataType, containsNull)
-  override def nullable: Boolean = child.nullable
-
-  override def eval(input: Row): Any = {
-    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
-    if (baseValue == null) null else {
-      baseValue.map { row =>
-        if (row == null) null else row(ordinal)
-      }
-    }
-  }
-}
 
 /**
  * Returns an Array containing the evaluation of all children expressions.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 4 deletions
@@ -227,10 +227,8 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
       case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
      case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
-      case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
-      case e @ StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
-      case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
+      case e @ GetField(Literal(null, _), _) => Literal.create(null, e.dataType)
+      case e @ GetField(_, Literal(null, _)) => Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
       case e @ Count(expr) if !expr.nullable => Count(Literal(1))