
Commit bcc6668

Author: Davies Liu (committed)

add reader and writer API in Python

1 parent 3399055 commit bcc6668

File tree

5 files changed: +387, -37 lines


core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala

Lines changed: 9 additions & 2 deletions
@@ -50,8 +50,15 @@ private[spark] object PythonUtils {
   /**
    * Convert list of T into seq of T (for calling API with varargs)
    */
-  def toSeq[T](cols: JList[T]): Seq[T] = {
-    cols.toList.toSeq
+  def toSeq[T](vs: JList[T]): Seq[T] = {
+    vs.toList.toSeq
+  }
+
+  /**
+   * Convert list of T into array of T (for calling API with array)
+   */
+  def toArray[T](vs: JList[T]): Array[T] = {
+    vs.toArray().asInstanceOf[Array[T]]
   }

   /**
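
These two helpers let Python callers hand a java.util.List across the Py4J gateway to Scala APIs that expect a Seq (for varargs) or an Array. A minimal sketch of how they can be reached from the Python side is below; it assumes a running SparkContext and is illustrative only, not code from this commit.

# Illustrative only: exercising PythonUtils.toSeq / toArray through Py4J.
# Assumes a running SparkContext; PythonUtils is imported into the gateway's
# JVM view by PySpark's java_gateway setup.
from pyspark import SparkContext

sc = SparkContext(appName="python-utils-sketch")

jlist = sc._gateway.jvm.java.util.ArrayList()   # build a java.util.List
for v in ["a", "b", "c"]:
    jlist.add(v)

jseq = sc._jvm.PythonUtils.toSeq(jlist)      # Seq[T], for varargs-style Scala APIs
jarr = sc._jvm.PythonUtils.toArray(jlist)    # Array[T], for array-based Scala APIs

sc.stop()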

python/pyspark/sql/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
 from pyspark.sql.group import GroupedData
+from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter

 __all__ = [
     'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
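
Re-exporting the two classes here makes them reachable from the package root, e.g.:

# The new classes can now be imported directly from pyspark.sql.
from pyspark.sql import DataFrameReader, DataFrameWriter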

python/pyspark/sql/context.py

Lines changed: 13 additions & 0 deletions
@@ -31,6 +31,7 @@
 from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.readwriter import DataFrameReader

 try:
     import pandas
@@ -546,6 +547,18 @@ def clearCache(self):
         """Removes all cached tables from the in-memory cache. """
         self._ssql_ctx.clearCache()

+    @property
+    def read(self):
+        """
+        Returns a :class:`DataFrameReader` that can be used to read data
+        in as a :class:`DataFrame`.
+
+        ::note: Experimental
+
+        >>> sqlContext.read
+        """
+        return DataFrameReader(self)
+

 class HiveContext(SQLContext):
     """A variant of Spark SQL that integrates with data stored in Hive.

python/pyspark/sql/dataframe.py

Lines changed: 27 additions & 35 deletions
@@ -29,9 +29,10 @@
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.sql.types import *
 from pyspark.sql.types import _create_cls, _parse_datatype_json_string
 from pyspark.sql.column import Column, _to_seq, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter
+from pyspark.sql.types import *

 __all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"]

@@ -151,25 +152,6 @@ def insertInto(self, tableName, overwrite=False):
         """
         self._jdf.insertInto(tableName, overwrite)

-    def _java_save_mode(self, mode):
-        """Returns the Java save mode based on the Python save mode represented by a string.
-        """
-        jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode
-        jmode = jSaveMode.ErrorIfExists
-        mode = mode.lower()
-        if mode == "append":
-            jmode = jSaveMode.Append
-        elif mode == "overwrite":
-            jmode = jSaveMode.Overwrite
-        elif mode == "ignore":
-            jmode = jSaveMode.Ignore
-        elif mode == "error":
-            pass
-        else:
-            raise ValueError(
-                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.")
-        return jmode
-
     def saveAsTable(self, tableName, source=None, mode="error", **options):
         """Saves the contents of this :class:`DataFrame` to a data source as a table.

@@ -185,11 +167,7 @@ def saveAsTable(self, tableName, source=None, mode="error", **options):
         * `error`: Throw an exception if data already exists.
         * `ignore`: Silently ignore this operation if data already exists.
         """
-        if source is None:
-            source = self.sql_ctx.getConf("spark.sql.sources.default",
-                                          "org.apache.spark.sql.parquet")
-        jmode = self._java_save_mode(mode)
-        self._jdf.saveAsTable(tableName, source, jmode, options)
+        self.write.saveAsTable(tableName, source, mode, **options)

     def save(self, path=None, source=None, mode="error", **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
@@ -206,13 +184,17 @@ def save(self, path=None, source=None, mode="error", **options):
         * `error`: Throw an exception if data already exists.
         * `ignore`: Silently ignore this operation if data already exists.
         """
-        if path is not None:
-            options["path"] = path
-        if source is None:
-            source = self.sql_ctx.getConf("spark.sql.sources.default",
-                                          "org.apache.spark.sql.parquet")
-        jmode = self._java_save_mode(mode)
-        self._jdf.save(source, jmode, options)
+        return self.write.save(path, source, mode, **options)
+
+    @property
+    def write(self):
+        """
+        Interface for saving the content of the :class:`DataFrame` out
+        into external storage.
+
+        :return :class:`DataFrameWriter`
+        """
+        return DataFrameWriter(self)

     @property
     def schema(self):
@@ -411,9 +393,19 @@ def unpersist(self, blocking=True):
         self._jdf.unpersist(blocking)
         return self

-    # def coalesce(self, numPartitions, shuffle=False):
-    #     rdd = self._jdf.coalesce(numPartitions, shuffle, None)
-    #     return DataFrame(rdd, self.sql_ctx)
+    def coalesce(self, numPartitions):
+        """
+        Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions.
+
+        Similar to coalesce defined on an :class:`RDD`, this operation results in a
+        narrow dependency, e.g. if you go from 1000 partitions to 100 partitions,
+        there will not be a shuffle, instead each of the 100 new partitions will
+        claim 10 of the current partitions.
+
+        >>> df.coalesce(1).rdd.getNumPartitions()
+        1
+        """
+        return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)

     def repartition(self, numPartitions):
         """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions.
