Skip to content

Commit 554403f

Browse files
committed
[SQL] Improve DataFrame API error reporting
1. Throw UnsupportedOperationException if a Column is not computable. 2. Perform eager analysis on DataFrame so we can catch errors when they happen (not when an action is run). Author: Reynold Xin <[email protected]> Author: Davies Liu <[email protected]> Closes #4296 from rxin/col-computability and squashes the following commits: 6527b86 [Reynold Xin] Merge pull request #8 from davies/col-computability fd92bc7 [Reynold Xin] Merge branch 'master' into col-computability f79034c [Davies Liu] fix python tests 5afe1ff [Reynold Xin] Fix scala test. 17f6bae [Reynold Xin] Various fixes. b932e86 [Reynold Xin] Added eager analysis for error reporting. e6f00b8 [Reynold Xin] [SQL][API] ComputableColumn vs IncomputableColumn
1 parent eccb9fb commit 554403f

File tree

20 files changed

+896
-381
lines changed

20 files changed

+896
-381
lines changed

python/pyspark/sql.py

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,6 +2124,10 @@ def head(self, n=None):
21242124
return rs[0] if rs else None
21252125
return self.take(n)
21262126

2127+
def first(self):
2128+
""" Return the first row. """
2129+
return self.head()
2130+
21272131
def tail(self):
21282132
raise NotImplemented
21292133

@@ -2159,7 +2163,7 @@ def select(self, *cols):
21592163
else:
21602164
cols = [c._jc for c in cols]
21612165
jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
2162-
jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
2166+
jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
21632167
return DataFrame(jdf, self.sql_ctx)
21642168

21652169
def filter(self, condition):
@@ -2189,7 +2193,7 @@ def groupBy(self, *cols):
21892193
else:
21902194
cols = [c._jc for c in cols]
21912195
jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
2192-
jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
2196+
jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
21932197
return GroupedDataFrame(jdf, self.sql_ctx)
21942198

21952199
def agg(self, *exprs):
@@ -2278,14 +2282,17 @@ def agg(self, *exprs):
22782282
:param exprs: list or aggregate columns or a map from column
22792283
name to agregate methods.
22802284
"""
2285+
assert exprs, "exprs should not be empty"
22812286
if len(exprs) == 1 and isinstance(exprs[0], dict):
22822287
jmap = MapConverter().convert(exprs[0],
22832288
self.sql_ctx._sc._gateway._gateway_client)
22842289
jdf = self._jdf.agg(jmap)
22852290
else:
22862291
# Columns
2287-
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
2288-
jdf = self._jdf.agg(*exprs)
2292+
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
2293+
jcols = ListConverter().convert([c._jc for c in exprs[1:]],
2294+
self.sql_ctx._sc._gateway._gateway_client)
2295+
jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
22892296
return DataFrame(jdf, self.sql_ctx)
22902297

22912298
@dfapi
@@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):
23472354

23482355
def _create_column_from_name(name):
23492356
sc = SparkContext._active_spark_context
2350-
return sc._jvm.Column(name)
2357+
return sc._jvm.IncomputableColumn(name)
23512358

23522359

23532360
def _scalaMethod(name):
@@ -2371,7 +2378,7 @@ def _(self):
23712378
return _
23722379

23732380

2374-
def _bin_op(name, pass_literal_through=False):
2381+
def _bin_op(name, pass_literal_through=True):
23752382
""" Create a method for given binary operator
23762383
23772384
Keyword arguments:
@@ -2465,18 +2472,17 @@ def __init__(self, jc, jdf=None, sql_ctx=None):
24652472
# __getattr__ = _bin_op("getField")
24662473

24672474
# string methods
2468-
rlike = _bin_op("rlike", pass_literal_through=True)
2469-
like = _bin_op("like", pass_literal_through=True)
2470-
startswith = _bin_op("startsWith", pass_literal_through=True)
2471-
endswith = _bin_op("endsWith", pass_literal_through=True)
2475+
rlike = _bin_op("rlike")
2476+
like = _bin_op("like")
2477+
startswith = _bin_op("startsWith")
2478+
endswith = _bin_op("endsWith")
24722479
upper = _unary_op("upper")
24732480
lower = _unary_op("lower")
24742481

24752482
def substr(self, startPos, pos):
24762483
if type(startPos) != type(pos):
24772484
raise TypeError("Can not mix the type")
24782485
if isinstance(startPos, (int, long)):
2479-
24802486
jc = self._jc.substr(startPos, pos)
24812487
elif isinstance(startPos, Column):
24822488
jc = self._jc.substr(startPos._jc, pos._jc)
@@ -2507,30 +2513,53 @@ def cast(self, dataType):
25072513
return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
25082514

25092515

2516+
def _to_java_column(col):
2517+
if isinstance(col, Column):
2518+
jcol = col._jc
2519+
else:
2520+
jcol = _create_column_from_name(col)
2521+
return jcol
2522+
2523+
25102524
def _aggregate_func(name):
25112525
""" Create a function for aggregator by name"""
25122526
def _(col):
25132527
sc = SparkContext._active_spark_context
2514-
if isinstance(col, Column):
2515-
jcol = col._jc
2516-
else:
2517-
jcol = _create_column_from_name(col)
2518-
jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol)
2528+
jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
25192529
return Column(jc)
2530+
25202531
return staticmethod(_)
25212532

25222533

25232534
class Aggregator(object):
25242535
"""
25252536
A collections of builtin aggregators
25262537
"""
2527-
max = _aggregate_func("max")
2528-
min = _aggregate_func("min")
2529-
avg = mean = _aggregate_func("mean")
2530-
sum = _aggregate_func("sum")
2531-
first = _aggregate_func("first")
2532-
last = _aggregate_func("last")
2533-
count = _aggregate_func("count")
2538+
AGGS = [
2539+
'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
2540+
'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
2541+
]
2542+
for _name in AGGS:
2543+
locals()[_name] = _aggregate_func(_name)
2544+
del _name
2545+
2546+
@staticmethod
2547+
def countDistinct(col, *cols):
2548+
sc = SparkContext._active_spark_context
2549+
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
2550+
sc._gateway._gateway_client)
2551+
jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
2552+
sc._jvm.Dsl.toColumns(jcols))
2553+
return Column(jc)
2554+
2555+
@staticmethod
2556+
def approxCountDistinct(col, rsd=None):
2557+
sc = SparkContext._active_spark_context
2558+
if rsd is None:
2559+
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
2560+
else:
2561+
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
2562+
return Column(jc)
25342563

25352564

25362565
def _test():

python/pyspark/tests.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,9 +1029,11 @@ def test_aggregator(self):
10291029
g = df.groupBy()
10301030
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
10311031
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
1032-
# TODO(davies): fix aggregators
1032+
10331033
from pyspark.sql import Aggregator as Agg
1034-
# self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
1034+
self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
1035+
self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0])
1036+
self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0])
10351037

10361038
def test_help_command(self):
10371039
# Regression test for SPARK-5464

sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
822822
* have a name matching the given name, `null` will be returned.
823823
*/
824824
def apply(name: String): StructField = {
825-
nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist."))
825+
nameToField.getOrElse(name,
826+
throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
826827
}
827828

828829
/**

0 commit comments

Comments
 (0)