Skip to content

Commit f515d69

Browse files
committed
add apply method and test cases
1 parent 8df6199 commit f515d69

File tree

4 files changed

+44
-3
lines changed

4 files changed

+44
-3
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1275,7 +1275,7 @@ def __init__(self, jc):
12751275

12761276
# container operators
12771277
__contains__ = _bin_op("contains")
1278-
__getitem__ = _bin_op("getItem")
1278+
__getitem__ = _bin_op("apply")
12791279

12801280
# bitwise operators
12811281
bitwiseOR = _bin_op("bitwiseOR")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,31 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
941941
checkEvaluation(resolveGetField('c.struct(typeS).at(2).getField("a")), "aa", row)
942942
}
943943

944+
test("error message of GetField") {
945+
val structType = StructType(StructField("a", StringType, true) :: Nil)
946+
val arrayStructType = ArrayType(structType)
947+
val arrayType = ArrayType(StringType)
948+
val otherType = StringType
949+
950+
def checkErrorMessage(
951+
childDataType: DataType,
952+
fieldDataType: DataType,
953+
errorMesage: String): Unit = {
954+
val e = intercept[org.apache.spark.sql.AnalysisException] {
955+
GetField(
956+
Literal.create(null, childDataType),
957+
Literal.create(null, fieldDataType),
958+
_ == _)
959+
}
960+
assert(e.getMessage().contains(errorMesage))
961+
}
962+
963+
checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
964+
checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
965+
checkErrorMessage(arrayType, StringType, "Array index should be integral type")
966+
checkErrorMessage(otherType, StringType, "Can't get field on")
967+
}
968+
944969
test("arithmetic") {
945970
val row = create_row(1, 2, 3, null)
946971
val c1 = 'a.int.at(0)

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
6767

6868
override def hashCode: Int = this.expr.hashCode
6969

70+
/**
71+
* An expression that gets an item at a position out of an [[ArrayType]],
72+
* or gets a value by key in a [[MapType]],
73+
* or gets a field by name in a [[StructType]],
74+
* or gets an array of fields by name in an array of [[StructType]].
75+
*
76+
* @group expr_ops
77+
*/
78+
def apply(field: Any): Column = UnresolvedGetField(expr, Literal(field))
79+
7080
/**
7181
* Unary minus, i.e. negate the expression.
7282
* {{{

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ class DataFrameSuite extends QueryTest {
449449
testData.collect().map { case Row(key: Int, value: String) =>
450450
Row(key, value, key + 1)
451451
}.toSeq)
452-
assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
452+
assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
453453
}
454454

455455
test("replace column using withColumn") {
@@ -484,7 +484,7 @@ class DataFrameSuite extends QueryTest {
484484
testData.collect().map { case Row(key: Int, value: String) =>
485485
Row(key, value, key + 1)
486486
}.toSeq)
487-
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
487+
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
488488
}
489489

490490
test("randomSplit") {
@@ -593,4 +593,10 @@ class DataFrameSuite extends QueryTest {
593593
Row(new java.math.BigDecimal(2.0)))
594594
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
595595
}
596+
597+
test("SPARK-7133: Implement struct, array, and map field accessor") {
598+
assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
599+
assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
600+
assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
601+
}
596602
}

0 commit comments

Comments
 (0)