Skip to content

Commit f39969f

Browse files
cloud-fannemccarthy
authored andcommitted
[SPARK-7133] [SQL] Implement struct, array, and map field accessor
It's the first step: generalize UnresolvedGetField to support all map, struct, and array TODO: add `apply` in Scala and `__getitem__` in Python, and unify the `getItem` and `getField` methods to one single API(or should we keep them for compatibility?). Author: Wenchen Fan <[email protected]> Closes apache#5744 from cloud-fan/generalize and squashes the following commits: 715c589 [Wenchen Fan] address comments 7ea5b31 [Wenchen Fan] fix python test 4f0833a [Wenchen Fan] add python test f515d69 [Wenchen Fan] add apply method and test cases 8df6199 [Wenchen Fan] fix python test 239730c [Wenchen Fan] fix test compile 2a70526 [Wenchen Fan] use _bin_op in dataframe.py 6bf72bc [Wenchen Fan] address comments 3f880c3 [Wenchen Fan] add java doc ab35ab5 [Wenchen Fan] fix python test b5961a9 [Wenchen Fan] fix style c9d85f5 [Wenchen Fan] generalize UnresolvedGetField to support all map, struct, and array
1 parent 462ea6f commit f39969f

File tree

16 files changed

+327
-191
lines changed

16 files changed

+327
-191
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,7 +1275,7 @@ def __init__(self, jc):
12751275

12761276
# container operators
12771277
__contains__ = _bin_op("contains")
1278-
__getitem__ = _bin_op("getItem")
1278+
__getitem__ = _bin_op("apply")
12791279

12801280
# bitwise operators
12811281
bitwiseOR = _bin_op("bitwiseOR")
@@ -1308,19 +1308,19 @@ def getField(self, name):
13081308
>>> from pyspark.sql import Row
13091309
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
13101310
>>> df.select(df.r.getField("b")).show()
1311-
+---+
1312-
|r.b|
1313-
+---+
1314-
| b|
1315-
+---+
1311+
+----+
1312+
|r[b]|
1313+
+----+
1314+
| b|
1315+
+----+
13161316
>>> df.select(df.r.a).show()
1317-
+---+
1318-
|r.a|
1319-
+---+
1320-
| 1|
1321-
+---+
1317+
+----+
1318+
|r[a]|
1319+
+----+
1320+
| 1|
1321+
+----+
13221322
"""
1323-
return Column(self._jc.getField(name))
1323+
return self[name]
13241324

13251325
def __getattr__(self, item):
13261326
if item.startswith("__"):

python/pyspark/sql/tests.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,13 @@ def test_access_nested_types(self):
519519
self.assertEqual("v", df.select(df.d["k"]).first()[0])
520520
self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
521521

522+
def test_field_accessor(self):
523+
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
524+
self.assertEqual(1, df.select(df.l[0]).first()[0])
525+
self.assertEqual(1, df.select(df.r["a"]).first()[0])
526+
self.assertEqual("b", df.select(df.r["b"]).first()[0])
527+
self.assertEqual("v", df.select(df.d["k"]).first()[0])
528+
522529
def test_infer_long_type(self):
523530
longrow = [Row(f1='a', f2=100000000000000)]
524531
df = self.sc.parallelize(longrow).toDF()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,9 +375,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
375375
protected lazy val primary: PackratParser[Expression] =
376376
( literal
377377
| expression ~ ("[" ~> expression <~ "]") ^^
378-
{ case base ~ ordinal => GetItem(base, ordinal) }
378+
{ case base ~ ordinal => UnresolvedExtractValue(base, ordinal) }
379379
| (expression <~ ".") ~ ident ^^
380-
{ case base ~ fieldName => UnresolvedGetField(base, fieldName) }
380+
{ case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) }
381381
| cast
382382
| "(" ~> expression <~ ")"
383383
| function

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,8 @@ class Analyzer(
348348
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
349349
logDebug(s"Resolving $u to $result")
350350
result
351-
case UnresolvedGetField(child, fieldName) if child.resolved =>
352-
GetField(child, fieldName, resolver)
351+
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
352+
ExtractValue(child, fieldExpr, resolver)
353353
}
354354
}
355355

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,17 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
184184
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
185185
}
186186

187-
case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
187+
/**
188+
* Extracts a value or values from an Expression
189+
*
190+
* @param child The expression to extract value from,
191+
* can be Map, Array, Struct or array of Structs.
192+
* @param extraction The expression to describe the extraction,
193+
* can be key of Map, index of Array, field name of Struct.
194+
*/
195+
case class UnresolvedExtractValue(child: Expression, extraction: Expression)
196+
extends UnaryExpression {
197+
188198
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
189199
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
190200
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -193,5 +203,5 @@ case class UnresolvedGetField(child: Expression, fieldName: String) extends Unar
193203
override def eval(input: Row = null): EvaluatedType =
194204
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
195205

196-
override def toString: String = s"$child.$fieldName"
206+
override def toString: String = s"$child[$extraction]"
197207
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
2222
import scala.language.implicitConversions
2323
import scala.reflect.runtime.universe.{TypeTag, typeTag}
2424

25-
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute}
25+
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute}
2626
import org.apache.spark.sql.catalyst.expressions._
2727
import org.apache.spark.sql.catalyst.plans.logical._
2828
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -100,8 +100,9 @@ package object dsl {
100100
def isNull: Predicate = IsNull(expr)
101101
def isNotNull: Predicate = IsNotNull(expr)
102102

103-
def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal)
104-
def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName)
103+
def getItem(ordinal: Expression): UnresolvedExtractValue = UnresolvedExtractValue(expr, ordinal)
104+
def getField(fieldName: String): UnresolvedExtractValue =
105+
UnresolvedExtractValue(expr, Literal(fieldName))
105106

106107
def cast(to: DataType): Expression = Cast(expr, to)
107108

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import scala.collection.Map
21+
22+
import org.apache.spark.sql.AnalysisException
23+
import org.apache.spark.sql.catalyst.analysis._
24+
import org.apache.spark.sql.types._
25+
26+
object ExtractValue {
27+
/**
28+
* Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`,
29+
* depend on the type of `child` and `extraction`.
30+
*
31+
* `child` | `extraction` | concrete `ExtractValue`
32+
* ----------------------------------------------------------------
33+
* Struct | Literal String | GetStructField
34+
* Array[Struct] | Literal String | GetArrayStructFields
35+
* Array | Integral type | GetArrayItem
36+
* Map | Any type | GetMapValue
37+
*/
38+
def apply(
39+
child: Expression,
40+
extraction: Expression,
41+
resolver: Resolver): ExtractValue = {
42+
43+
(child.dataType, extraction) match {
44+
case (StructType(fields), Literal(fieldName, StringType)) =>
45+
val ordinal = findField(fields, fieldName.toString, resolver)
46+
GetStructField(child, fields(ordinal), ordinal)
47+
case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
48+
val ordinal = findField(fields, fieldName.toString, resolver)
49+
GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
50+
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
51+
GetArrayItem(child, extraction)
52+
case (_: MapType, _) =>
53+
GetMapValue(child, extraction)
54+
case (otherType, _) =>
55+
val errorMsg = otherType match {
56+
case StructType(_) | ArrayType(StructType(_), _) =>
57+
s"Field name should be String Literal, but it's $extraction"
58+
case _: ArrayType =>
59+
s"Array index should be integral type, but it's ${extraction.dataType}"
60+
case other =>
61+
s"Can't extract value from $child"
62+
}
63+
throw new AnalysisException(errorMsg)
64+
}
65+
}
66+
67+
def unapply(g: ExtractValue): Option[(Expression, Expression)] = {
68+
g match {
69+
case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal))
70+
case _ => Some((g.child, null))
71+
}
72+
}
73+
74+
/**
75+
* Find the ordinal of StructField, report error if no desired field or over one
76+
* desired fields are found.
77+
*/
78+
private def findField(fields: Array[StructField], fieldName: String, resolver: Resolver): Int = {
79+
val checkField = (f: StructField) => resolver(f.name, fieldName)
80+
val ordinal = fields.indexWhere(checkField)
81+
if (ordinal == -1) {
82+
throw new AnalysisException(
83+
s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
84+
} else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
85+
throw new AnalysisException(
86+
s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
87+
} else {
88+
ordinal
89+
}
90+
}
91+
}
92+
93+
trait ExtractValue extends UnaryExpression {
94+
self: Product =>
95+
96+
type EvaluatedType = Any
97+
}
98+
99+
/**
100+
* Returns the value of fields in the Struct `child`.
101+
*/
102+
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
103+
extends ExtractValue {
104+
105+
override def dataType: DataType = field.dataType
106+
override def nullable: Boolean = child.nullable || field.nullable
107+
override def foldable: Boolean = child.foldable
108+
override def toString: String = s"$child.${field.name}"
109+
110+
override def eval(input: Row): Any = {
111+
val baseValue = child.eval(input).asInstanceOf[Row]
112+
if (baseValue == null) null else baseValue(ordinal)
113+
}
114+
}
115+
116+
/**
117+
* Returns the array of value of fields in the Array of Struct `child`.
118+
*/
119+
case class GetArrayStructFields(
120+
child: Expression,
121+
field: StructField,
122+
ordinal: Int,
123+
containsNull: Boolean) extends ExtractValue {
124+
125+
override def dataType: DataType = ArrayType(field.dataType, containsNull)
126+
override def nullable: Boolean = child.nullable
127+
override def foldable: Boolean = child.foldable
128+
override def toString: String = s"$child.${field.name}"
129+
130+
override def eval(input: Row): Any = {
131+
val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
132+
if (baseValue == null) null else {
133+
baseValue.map { row =>
134+
if (row == null) null else row(ordinal)
135+
}
136+
}
137+
}
138+
}
139+
140+
abstract class ExtractValueWithOrdinal extends ExtractValue {
141+
self: Product =>
142+
143+
def ordinal: Expression
144+
145+
/** `Null` is returned for invalid ordinals. */
146+
override def nullable: Boolean = true
147+
override def foldable: Boolean = child.foldable && ordinal.foldable
148+
override def toString: String = s"$child[$ordinal]"
149+
override def children: Seq[Expression] = child :: ordinal :: Nil
150+
151+
override def eval(input: Row): Any = {
152+
val value = child.eval(input)
153+
if (value == null) {
154+
null
155+
} else {
156+
val o = ordinal.eval(input)
157+
if (o == null) {
158+
null
159+
} else {
160+
evalNotNull(value, o)
161+
}
162+
}
163+
}
164+
165+
protected def evalNotNull(value: Any, ordinal: Any): Any
166+
}
167+
168+
/**
169+
* Returns the field at `ordinal` in the Array `child`
170+
*/
171+
case class GetArrayItem(child: Expression, ordinal: Expression)
172+
extends ExtractValueWithOrdinal {
173+
174+
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
175+
176+
override lazy val resolved = childrenResolved &&
177+
child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
178+
179+
protected def evalNotNull(value: Any, ordinal: Any) = {
180+
// TODO: consider using Array[_] for ArrayType child to avoid
181+
// boxing of primitives
182+
val baseValue = value.asInstanceOf[Seq[_]]
183+
val index = ordinal.asInstanceOf[Int]
184+
if (index >= baseValue.size || index < 0) {
185+
null
186+
} else {
187+
baseValue(index)
188+
}
189+
}
190+
}
191+
192+
/**
193+
* Returns the value of key `ordinal` in Map `child`
194+
*/
195+
case class GetMapValue(child: Expression, ordinal: Expression)
196+
extends ExtractValueWithOrdinal {
197+
198+
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
199+
200+
override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
201+
202+
protected def evalNotNull(value: Any, ordinal: Any) = {
203+
val baseValue = value.asInstanceOf[Map[Any, _]]
204+
baseValue.get(ordinal).orNull
205+
}
206+
}

0 commit comments

Comments
 (0)