
[SPARK-7738] [SQL] [PySpark] add reader and writer API in Python #6238


Closed · wants to merge 6 commits
@@ -50,8 +50,15 @@ private[spark] object PythonUtils {
/**
* Convert list of T into seq of T (for calling API with varargs)
*/
def toSeq[T](cols: JList[T]): Seq[T] = {
cols.toList.toSeq
def toSeq[T](vs: JList[T]): Seq[T] = {
Contributor:

what does vs mean?

Contributor Author:

value -> v, values => vs :)

vs.toList.toSeq
}

/**
* Convert list of T into array of T (for calling API with array)
*/
def toArray[T](vs: JList[T]): Array[T] = {
vs.toArray().asInstanceOf[Array[T]]
}

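For context, here is a minimal sketch of how Python code can reach these JVM helpers through the Py4J gateway. This is an illustration only: the explicit ListConverter and the fully qualified PythonUtils path are assumptions, not necessarily the exact bridging code used elsewhere in this PR.

from py4j.java_collections import ListConverter

def _to_java_seq(sc, values):
    """Illustrative helper: turn a Python list into a scala.collection.Seq."""
    # Convert the Python list into a java.util.List that Py4J can hand to the JVM.
    jlist = ListConverter().convert(values, sc._gateway._gateway_client)
    # PythonUtils.toSeq then converts it into a Scala Seq (toArray works the same way).
    return sc._jvm.org.apache.spark.api.python.PythonUtils.toSeq(jlist)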
1 change: 1 addition & 0 deletions python/pyspark/sql/__init__.py
@@ -58,6 +58,7 @@
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
from pyspark.sql.group import GroupedData
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter

__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
28 changes: 15 additions & 13 deletions python/pyspark/sql/context.py
@@ -31,6 +31,7 @@
from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader

try:
import pandas
@@ -437,19 +438,7 @@ def load(self, path=None, source=None, schema=None, **options):

Optionally, a schema can be provided as the schema of the returned DataFrame.
"""
if path is not None:
options["path"] = path
if source is None:
source = self.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
if schema is None:
df = self._ssql_ctx.load(source, options)
else:
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
df = self._ssql_ctx.load(source, scala_datatype, options)
return DataFrame(df, self)
return self.read.load(path, source, schema, **options)
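In other words, SQLContext.load becomes a thin wrapper over the new reader. Assuming a JSON file at a placeholder path, the two calls in this sketch should be equivalent:

# Existing entry point, still supported:
df1 = sqlContext.load("examples/people.json", source="json")

# The same read expressed through the new reader API:
df2 = sqlContext.read.load("examples/people.json", "json")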

def createExternalTable(self, tableName, path=None, source=None,
schema=None, **options):
@@ -547,6 +536,19 @@ def clearCache(self):
"""Removes all cached tables from the in-memory cache. """
self._ssql_ctx.clearCache()

@property
def read(self):
"""
Returns a :class:`DataFrameReader` that can be used to read data
in as a :class:`DataFrame`.

.. note:: Experimental

>>> sqlContext.read
<pyspark.sql.readwriter.DataFrameReader object at ...>
"""
return DataFrameReader(self)


class HiveContext(SQLContext):
"""A variant of Spark SQL that integrates with data stored in Hive.
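A rough sketch of using the new reader entry point end to end; the file paths are placeholders, and passing the source name and schema positionally simply mirrors the delegation in load above:

from pyspark.sql.types import StructType, StructField, StringType

# Default data source (spark.sql.sources.default, parquet unless reconfigured):
users = sqlContext.read.load("examples/users.parquet")

# Explicit source plus a user-supplied schema:
schema = StructType([StructField("name", StringType(), True)])
people = sqlContext.read.load("examples/people.json", "json", schema)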
67 changes: 32 additions & 35 deletions python/pyspark/sql/dataframe.py
@@ -29,9 +29,10 @@
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string
from pyspark.sql.column import Column, _to_seq, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter
from pyspark.sql.types import *

__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"]

@@ -151,25 +152,6 @@ def insertInto(self, tableName, overwrite=False):
"""
self._jdf.insertInto(tableName, overwrite)

def _java_save_mode(self, mode):
"""Returns the Java save mode based on the Python save mode represented by a string.
"""
jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode
jmode = jSaveMode.ErrorIfExists
mode = mode.lower()
if mode == "append":
jmode = jSaveMode.Append
elif mode == "overwrite":
jmode = jSaveMode.Overwrite
elif mode == "ignore":
jmode = jSaveMode.Ignore
elif mode == "error":
pass
else:
raise ValueError(
"Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.")
return jmode

def saveAsTable(self, tableName, source=None, mode="error", **options):
"""Saves the contents of this :class:`DataFrame` to a data source as a table.

@@ -185,11 +167,7 @@ def saveAsTable(self, tableName, source=None, mode="error", **options):
* `error`: Throw an exception if data already exists.
* `ignore`: Silently ignore this operation if data already exists.
"""
if source is None:
source = self.sql_ctx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
jmode = self._java_save_mode(mode)
self._jdf.saveAsTable(tableName, source, jmode, options)
self.write.saveAsTable(tableName, source, mode, **options)
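To make the save modes concrete, here is a hypothetical sequence against a placeholder table name, assuming a context where saveAsTable is supported (all four modes appear in the docstring and in the _java_save_mode mapping removed above):

df.saveAsTable("people")                    # default mode "error": fails if "people" already exists
df.saveAsTable("people", mode="ignore")     # the table now exists, so this call is silently skipped
df.saveAsTable("people", mode="append")     # adds df's rows to the existing table
df.saveAsTable("people", mode="overwrite")  # replaces the table's contents with df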

def save(self, path=None, source=None, mode="error", **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
@@ -206,13 +184,22 @@ def save(self, path=None, source=None, mode="error", **options):
* `error`: Throw an exception if data already exists.
* `ignore`: Silently ignore this operation if data already exists.
"""
if path is not None:
options["path"] = path
if source is None:
source = self.sql_ctx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
jmode = self._java_save_mode(mode)
self._jdf.save(source, jmode, options)
return self.write.save(path, source, mode, **options)

@property
def write(self):
"""
Interface for saving the content of the :class:`DataFrame` out
into external storage.

:return: :class:`DataFrameWriter`

.. note:: Experimental

>>> df.write
<pyspark.sql.readwriter.DataFrameWriter object at ...>
"""
return DataFrameWriter(self)
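A short sketch of what the write property enables; the positional arguments mirror the delegation in save above (path, source, mode), and the output paths are placeholders:

# Equivalent to df.save("out/users.parquet") with the default data source:
df.write.save("out/users.parquet")

# Explicit source and save mode, matching write.save(path, source, mode):
df.write.save("out/users.json", "json", "overwrite")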

@property
def schema(self):
@@ -411,9 +398,19 @@ def unpersist(self, blocking=True):
self._jdf.unpersist(blocking)
return self

# def coalesce(self, numPartitions, shuffle=False):
# rdd = self._jdf.coalesce(numPartitions, shuffle, None)
# return DataFrame(rdd, self.sql_ctx)
def coalesce(self, numPartitions):
"""
Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions.

Similar to coalesce defined on an :class:`RDD`, this operation results in a
narrow dependency, e.g. if you go from 1000 partitions to 100 partitions,
there will not be a shuffle, instead each of the 100 new partitions will
claim 10 of the current partitions.

>>> df.coalesce(1).rdd.getNumPartitions()
1
"""
return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
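A small sketch contrasting coalesce with repartition (both defined in this file); the DataFrame construction is only a placeholder to get something with many partitions:

df = sqlContext.createDataFrame([(i,) for i in range(1000)], ["id"]).repartition(100)
df.coalesce(10).rdd.getNumPartitions()      # 10: partitions are merged without a shuffle
df.repartition(200).rdd.getNumPartitions()  # 200: growing the partition count requires a full shuffle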

def repartition(self, numPartitions):
"""Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions.