
Commit 5f68bc8
Author: Davies Liu
Message: update tests
Parent: 6437e9a

File tree: 4 files changed (+26, -19 lines)

python/pyspark/sql/context.py (1 addition, 0 deletions)
@@ -557,6 +557,7 @@ def read(self):
         ::note: Experimental
 
         >>> sqlContext.read
+        <pyspark.sql.readwriter.DataFrameReader object at ...>
         """
         return DataFrameReader(self)
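The added doctest pins down what the `read` property returns. A minimal sketch of the behavior it documents (the local SparkContext setup here is assumed for illustration and is not part of this commit):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext("local", "read-demo")  # assumed local setup
    sqlContext = SQLContext(sc)

    # SQLContext.read hands back a DataFrameReader bound to this context
    reader = sqlContext.read
    print(reader)  # <pyspark.sql.readwriter.DataFrameReader object at 0x...>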

python/pyspark/sql/dataframe.py (5 additions, 0 deletions)
@@ -193,6 +193,11 @@ def write(self):
         into external storage.
 
         :return :class:`DataFrameWriter`
+
+        ::note: Experimental
+
+        >>> df.write
+        <pyspark.sql.readwriter.DataFrameWriter object at ...>
         """
         return DataFrameWriter(self)
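The mirror image on the write side, continuing the sketch above (the sample rows are invented for illustration):

    # build a small DataFrame, then grab its writer
    df = sqlContext.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

    writer = df.write  # a DataFrameWriter for saving into external storage
    print(writer)      # <pyspark.sql.readwriter.DataFrameWriter object at 0x...>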

python/pyspark/sql/readwriter.py (6 additions, 5 deletions)
@@ -54,7 +54,7 @@ def load(self, path=None, format=None, schema=None, **options):
         if schema is not None:
             if not isinstance(schema, StructType):
                 raise TypeError("schema should be StructType")
-            jschema = self.sqlContext._ssql_ctx.parseDataType(schema.json())
+            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
             jreader = jreader.schema(jschema)
         for k in options:
             jreader = jreader.option(k, options[k])
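For context, `load` is the generic entry point this rename touches: an optional path and format, an optional `StructType` schema, and arbitrary keyword options forwarded to the source. A hedged usage sketch (path and schema invented; the pass-through option mirrors the test change further down):

    from pyspark.sql.types import StructType, StructField, StringType

    schema = StructType([StructField("name", StringType(), True)])

    # format, schema, and extra options all flow through DataFrameReader.load
    df = sqlContext.read.load(path="/tmp/people.json", format="json", schema=schema,
                              noUse="unrecognized options pass through to the source")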
@@ -79,7 +79,7 @@ def json(self, path, schema=None):
         >>> shutil.rmtree(jsonFile)
         >>> with open(jsonFile, 'w') as f:
         ...     f.writelines(jsonStrings)
-        >>> df1 = sqlContext.jsonFile(jsonFile)
+        >>> df1 = sqlContext.read.json(jsonFile)
         >>> df1.printSchema()
         root
          |-- field1: long (nullable = true)

@@ -92,7 +92,7 @@ def json(self, path, schema=None):
         ...     StructField("field2", StringType()),
         ...     StructField("field3",
         ...         StructType([StructField("field5", ArrayType(IntegerType()))]))])
-        >>> df2 = sqlContext.jsonFile(jsonFile, schema)
+        >>> df2 = sqlContext.read.json(jsonFile, schema)
         >>> df2.printSchema()
         root
          |-- field2: string (nullable = true)

@@ -103,7 +103,7 @@ def json(self, path, schema=None):
         if schema is None:
             jdf = self._jreader.json(path)
         else:
-            jschema = self.sqlContext._ssql_ctx.parseDataType(schema.json())
+            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
             jdf = self._jreader.schema(jschema).json(path)
         return self._df(jdf)
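These doctest edits retire the deprecated `sqlContext.jsonFile(...)` spelling in favor of the reader. Side by side (the path is hypothetical; without a schema the json source infers field types, as the `field1: long` line above shows):

    # old, deprecated:  df = sqlContext.jsonFile("people.json")
    df1 = sqlContext.read.json("people.json")          # schema inferred from the data
    df2 = sqlContext.read.json("people.json", schema)  # schema: an explicit StructType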

@@ -181,7 +181,8 @@ class DataFrameWriter(object):
     """
     def __init__(self, df):
         self._df = df
-        self._jwrite = df._df.write()
+        self._sqlContext = df.sql_ctx
+        self._jwrite = df._jdf.write()
 
     def save(self, path=None, format=None, mode="error", **options):
         """

python/pyspark/sql/tests.py (14 additions, 14 deletions)
@@ -480,29 +480,29 @@ def test_save_and_load(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
-        df.save(tmpPath, "org.apache.spark.sql.json", "error")
-        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.json(tmpPath)
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
-        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+        actual = self.sqlCtx.read.json(tmpPath, schema)
+        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
 
-        df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
-        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.json(tmpPath, "overwrite")
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
-                noUse="this options will not be used in save.")
-        actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
-                                  noUse="this options will not be used in load.")
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.save(format="json", mode="overwrite", path=tmpPath,
+                      noUse="this options will not be used in save.")
+        actual = self.sqlCtx.read.load(format="json", path=tmpPath,
+                                       noUse="this options will not be used in load.")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                     "org.apache.spark.sql.parquet")
         self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
         actual = self.sqlCtx.load(path=tmpPath)
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
         self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
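Stripped of the unittest scaffolding, the migrated round trip reads like this; a sketch assuming the `df` and `sqlContext` from the earlier examples (the scratch path comes from tempfile):

    import shutil
    import tempfile

    tmpPath = tempfile.mkdtemp()
    shutil.rmtree(tmpPath)        # the default "error" mode wants a non-existent path

    df.write.json(tmpPath)                   # was: df.save(tmpPath, "org.apache.spark.sql.json", "error")
    actual = sqlContext.read.json(tmpPath)   # was: sqlContext.load(tmpPath, "org.apache.spark.sql.json")
    assert sorted(df.collect()) == sorted(actual.collect())

    df.write.json(tmpPath, "overwrite")      # second positional argument is the save mode
    shutil.rmtree(tmpPath)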
