@@ -2124,6 +2124,10 @@ def head(self, n=None):
2124
2124
return rs [0 ] if rs else None
2125
2125
return self .take (n )
2126
2126
2127
+ def first (self ):
2128
+ """ Return the first row. """
2129
+ return self .head ()
2130
+
2127
2131
def tail (self ):
2128
2132
raise NotImplemented
2129
2133
@@ -2159,7 +2163,7 @@ def select(self, *cols):
2159
2163
else :
2160
2164
cols = [c ._jc for c in cols ]
2161
2165
jcols = ListConverter ().convert (cols , self ._sc ._gateway ._gateway_client )
2162
- jdf = self ._jdf .select (self ._jdf . toColumnArray (jcols ))
2166
+ jdf = self ._jdf .select (self .sql_ctx . _sc . _jvm . Dsl . toColumns (jcols ))
2163
2167
return DataFrame (jdf , self .sql_ctx )
2164
2168
2165
2169
def filter (self , condition ):
@@ -2189,7 +2193,7 @@ def groupBy(self, *cols):
2189
2193
else :
2190
2194
cols = [c ._jc for c in cols ]
2191
2195
jcols = ListConverter ().convert (cols , self ._sc ._gateway ._gateway_client )
2192
- jdf = self ._jdf .groupBy (self ._jdf . toColumnArray (jcols ))
2196
+ jdf = self ._jdf .groupBy (self .sql_ctx . _sc . _jvm . Dsl . toColumns (jcols ))
2193
2197
return GroupedDataFrame (jdf , self .sql_ctx )
2194
2198
2195
2199
def agg (self , * exprs ):
@@ -2278,14 +2282,17 @@ def agg(self, *exprs):
2278
2282
:param exprs: list or aggregate columns or a map from column
2279
2283
name to agregate methods.
2280
2284
"""
2285
+ assert exprs , "exprs should not be empty"
2281
2286
if len (exprs ) == 1 and isinstance (exprs [0 ], dict ):
2282
2287
jmap = MapConverter ().convert (exprs [0 ],
2283
2288
self .sql_ctx ._sc ._gateway ._gateway_client )
2284
2289
jdf = self ._jdf .agg (jmap )
2285
2290
else :
2286
2291
# Columns
2287
- assert all (isinstance (c , Column ) for c in exprs ), "all exprs should be Columns"
2288
- jdf = self ._jdf .agg (* exprs )
2292
+ assert all (isinstance (c , Column ) for c in exprs ), "all exprs should be Column"
2293
+ jcols = ListConverter ().convert ([c ._jc for c in exprs [1 :]],
2294
+ self .sql_ctx ._sc ._gateway ._gateway_client )
2295
+ jdf = self ._jdf .agg (exprs [0 ]._jc , self .sql_ctx ._sc ._jvm .Dsl .toColumns (jcols ))
2289
2296
return DataFrame (jdf , self .sql_ctx )
2290
2297
2291
2298
@dfapi
@@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):
2347
2354
2348
2355
def _create_column_from_name (name ):
2349
2356
sc = SparkContext ._active_spark_context
2350
- return sc ._jvm .Column (name )
2357
+ return sc ._jvm .IncomputableColumn (name )
2351
2358
2352
2359
2353
2360
def _scalaMethod (name ):
@@ -2371,7 +2378,7 @@ def _(self):
2371
2378
return _
2372
2379
2373
2380
2374
- def _bin_op (name , pass_literal_through = False ):
2381
+ def _bin_op (name , pass_literal_through = True ):
2375
2382
""" Create a method for given binary operator
2376
2383
2377
2384
Keyword arguments:
@@ -2465,18 +2472,17 @@ def __init__(self, jc, jdf=None, sql_ctx=None):
2465
2472
# __getattr__ = _bin_op("getField")
2466
2473
2467
2474
# string methods
2468
- rlike = _bin_op ("rlike" , pass_literal_through = True )
2469
- like = _bin_op ("like" , pass_literal_through = True )
2470
- startswith = _bin_op ("startsWith" , pass_literal_through = True )
2471
- endswith = _bin_op ("endsWith" , pass_literal_through = True )
2475
+ rlike = _bin_op ("rlike" )
2476
+ like = _bin_op ("like" )
2477
+ startswith = _bin_op ("startsWith" )
2478
+ endswith = _bin_op ("endsWith" )
2472
2479
upper = _unary_op ("upper" )
2473
2480
lower = _unary_op ("lower" )
2474
2481
2475
2482
def substr (self , startPos , pos ):
2476
2483
if type (startPos ) != type (pos ):
2477
2484
raise TypeError ("Can not mix the type" )
2478
2485
if isinstance (startPos , (int , long )):
2479
-
2480
2486
jc = self ._jc .substr (startPos , pos )
2481
2487
elif isinstance (startPos , Column ):
2482
2488
jc = self ._jc .substr (startPos ._jc , pos ._jc )
@@ -2507,30 +2513,53 @@ def cast(self, dataType):
2507
2513
return Column (self ._jc .cast (jdt ), self ._jdf , self .sql_ctx )
2508
2514
2509
2515
2516
+ def _to_java_column (col ):
2517
+ if isinstance (col , Column ):
2518
+ jcol = col ._jc
2519
+ else :
2520
+ jcol = _create_column_from_name (col )
2521
+ return jcol
2522
+
2523
+
2510
2524
def _aggregate_func (name ):
2511
2525
""" Create a function for aggregator by name"""
2512
2526
def _ (col ):
2513
2527
sc = SparkContext ._active_spark_context
2514
- if isinstance (col , Column ):
2515
- jcol = col ._jc
2516
- else :
2517
- jcol = _create_column_from_name (col )
2518
- jc = getattr (sc ._jvm .org .apache .spark .sql .Dsl , name )(jcol )
2528
+ jc = getattr (sc ._jvm .Dsl , name )(_to_java_column (col ))
2519
2529
return Column (jc )
2530
+
2520
2531
return staticmethod (_ )
2521
2532
2522
2533
2523
2534
class Aggregator (object ):
2524
2535
"""
2525
2536
A collections of builtin aggregators
2526
2537
"""
2527
- max = _aggregate_func ("max" )
2528
- min = _aggregate_func ("min" )
2529
- avg = mean = _aggregate_func ("mean" )
2530
- sum = _aggregate_func ("sum" )
2531
- first = _aggregate_func ("first" )
2532
- last = _aggregate_func ("last" )
2533
- count = _aggregate_func ("count" )
2538
+ AGGS = [
2539
+ 'lit' , 'col' , 'column' , 'upper' , 'lower' , 'sqrt' , 'abs' ,
2540
+ 'min' , 'max' , 'first' , 'last' , 'count' , 'avg' , 'mean' , 'sum' , 'sumDistinct' ,
2541
+ ]
2542
+ for _name in AGGS :
2543
+ locals ()[_name ] = _aggregate_func (_name )
2544
+ del _name
2545
+
2546
+ @staticmethod
2547
+ def countDistinct (col , * cols ):
2548
+ sc = SparkContext ._active_spark_context
2549
+ jcols = ListConverter ().convert ([_to_java_column (c ) for c in cols ],
2550
+ sc ._gateway ._gateway_client )
2551
+ jc = sc ._jvm .Dsl .countDistinct (_to_java_column (col ),
2552
+ sc ._jvm .Dsl .toColumns (jcols ))
2553
+ return Column (jc )
2554
+
2555
+ @staticmethod
2556
+ def approxCountDistinct (col , rsd = None ):
2557
+ sc = SparkContext ._active_spark_context
2558
+ if rsd is None :
2559
+ jc = sc ._jvm .Dsl .approxCountDistinct (_to_java_column (col ))
2560
+ else :
2561
+ jc = sc ._jvm .Dsl .approxCountDistinct (_to_java_column (col ), rsd )
2562
+ return Column (jc )
2534
2563
2535
2564
2536
2565
def _test ():
0 commit comments