Commit 11e4e90

zero323 authored and liyichao committed
[SPARK-16931][PYTHON][SQL] Add Python wrapper for bucketBy
## What changes were proposed in this pull request?

Adds Python wrappers for `DataFrameWriter.bucketBy` and `DataFrameWriter.sortBy` ([SPARK-16931](https://issues.apache.org/jira/browse/SPARK-16931)).

## How was this patch tested?

Unit tests covering the new feature.

__Note__: Based on work of GregBowyer (f49b9a2)

CC HyukjinKwon

Author: zero323 <[email protected]>
Author: Greg Bowyer <[email protected]>

Closes apache#17077 from zero323/SPARK-16931.
1 parent 753d497 commit 11e4e90
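
To make the new API concrete, here is a minimal end-to-end usage sketch. It assumes an active `SparkSession` named `spark`; the table and column names are illustrative and not part of this commit:

    # Sketch: write a bucketed, sorted table with the new Python wrappers.
    df = spark.createDataFrame(
        [(2017, 1, "a"), (2017, 2, "b")],
        ["year", "month", "value"])

    # Bucket by (year, month) into 4 buckets and sort each bucket by value.
    # bucketBy/sortBy only apply in combination with saveAsTable.
    (df.write.format("parquet")
        .bucketBy(4, "year", "month")
        .sortBy("value")
        .mode("overwrite")
        .saveAsTable("bucketed_demo"))

    spark.table("bucketed_demo").show()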

File tree

2 files changed: +111 -0 lines changed

python/pyspark/sql/readwriter.py

Lines changed: 57 additions & 0 deletions
@@ -563,6 +563,63 @@ def partitionBy(self, *cols):
         self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
         return self
 
+    @since(2.3)
+    def bucketBy(self, numBuckets, col, *cols):
+        """Buckets the output by the given columns. If specified,
+        the output is laid out on the file system similar to Hive's bucketing scheme.
+
+        :param numBuckets: the number of buckets to save
+        :param col: a name of a column, or a list of names.
+        :param cols: additional names (optional). If `col` is a list it should be empty.
+
+        .. note:: Applicable for file-based data sources in combination with
+                  :py:meth:`DataFrameWriter.saveAsTable`.
+
+        >>> (df.write.format('parquet')
+        ...     .bucketBy(100, 'year', 'month')
+        ...     .mode("overwrite")
+        ...     .saveAsTable('bucketed_table'))
+        """
+        if not isinstance(numBuckets, int):
+            raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))
+
+        if isinstance(col, (list, tuple)):
+            if cols:
+                raise ValueError("col is a {0} but cols are not empty".format(type(col)))
+
+            col, cols = col[0], col[1:]
+
+        if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)):
+            raise TypeError("all names should be `str`")
+
+        self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols))
+        return self
+
+    @since(2.3)
+    def sortBy(self, col, *cols):
+        """Sorts the output in each bucket by the given columns on the file system.
+
+        :param col: a name of a column, or a list of names.
+        :param cols: additional names (optional). If `col` is a list it should be empty.
+
+        >>> (df.write.format('parquet')
+        ...     .bucketBy(100, 'year', 'month')
+        ...     .sortBy('day')
+        ...     .mode("overwrite")
+        ...     .saveAsTable('sorted_bucketed_table'))
+        """
+        if isinstance(col, (list, tuple)):
+            if cols:
+                raise ValueError("col is a {0} but cols are not empty".format(type(col)))
+
+            col, cols = col[0], col[1:]
+
+        if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)):
+            raise TypeError("all names should be `str`")
+
+        self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols))
+        return self
+
     @since(1.4)
     def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
         """Saves the contents of the :class:`DataFrame` to a data source.

python/pyspark/sql/tests.py

Lines changed: 54 additions & 0 deletions
@@ -211,6 +211,12 @@ def test_sqlcontext_reuses_sparksession(self):
         sqlContext2 = SQLContext(self.sc)
         self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)
 
+    def tearDown(self):
+        super(SQLTests, self).tearDown()
+
+        # tear down test_bucketed_write state
+        self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket")
+
     def test_row_should_be_read_only(self):
         row = Row(a=1, b=2)
         self.assertEqual(1, row.a)
@@ -2196,6 +2202,54 @@ def test_BinaryType_serialization(self):
         df = self.spark.createDataFrame(data, schema=schema)
         df.collect()
 
+    def test_bucketed_write(self):
+        data = [
+            (1, "foo", 3.0), (2, "foo", 5.0),
+            (3, "bar", -1.0), (4, "bar", 6.0),
+        ]
+        df = self.spark.createDataFrame(data, ["x", "y", "z"])
+
+        def count_bucketed_cols(names, table="pyspark_bucket"):
+            """Given a sequence of column names and a table name,
+            query the catalog and return the number of columns which
+            are used for bucketing
+            """
+            cols = self.spark.catalog.listColumns(table)
+            num = len([c for c in cols if c.name in names and c.isBucket])
+            return num
+
+        # Test write with one bucketing column
+        df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
+        self.assertEqual(count_bucketed_cols(["x"]), 1)
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with two bucketing columns
+        df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
+        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with bucket and sort
+        df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
+        self.assertEqual(count_bucketed_cols(["x"]), 1)
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with a list of columns
+        df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
+        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with bucket and sort with a list of columns
+        (df.write.bucketBy(2, "x")
+            .sortBy(["y", "z"])
+            .mode("overwrite").saveAsTable("pyspark_bucket"))
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with bucket and sort with multiple columns
+        (df.write.bucketBy(2, "x")
+            .sortBy("y", "z")
+            .mode("overwrite").saveAsTable("pyspark_bucket"))
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
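
The test's `count_bucketed_cols` helper relies on `pyspark.sql.catalog.Column` exposing an `isBucket` flag. A short sketch of inspecting that metadata interactively, assuming an active `SparkSession` named `spark` and the `bucketed_demo` table from the earlier example:

    # Sketch: inspect bucketing metadata through the catalog.
    for c in spark.catalog.listColumns("bucketed_demo"):
        # Each entry also carries dataType, nullable, and isPartition flags.
        print(c.name, c.dataType, "bucketed" if c.isBucket else "not bucketed")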
