Skip to content

Commit c7200eb

Browse files
author
Davies Liu
committed
update tests
1 parent 9cbf01b commit c7200eb

File tree

1 file changed

+21
-28
lines changed

1 file changed

+21
-28
lines changed

python/pyspark/sql/tests.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -762,51 +762,44 @@ def test_save_and_load_table(self):
762762
df = self.df
763763
tmpPath = tempfile.mkdtemp()
764764
shutil.rmtree(tmpPath)
765-
df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
766-
actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
767-
"org.apache.spark.sql.json")
768-
self.assertTrue(
769-
sorted(df.collect()) ==
770-
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
771-
self.assertTrue(
772-
sorted(df.collect()) ==
773-
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
774-
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
765+
df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
766+
actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
767+
self.assertEqual(sorted(df.collect()),
768+
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
769+
self.assertEqual(sorted(df.collect()),
770+
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
771+
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
775772
self.sqlCtx.sql("DROP TABLE externalJsonTable")
776773

777-
df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
774+
df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
778775
schema = StructType([StructField("value", StringType(), True)])
779-
actual = self.sqlCtx.createExternalTable("externalJsonTable",
780-
source="org.apache.spark.sql.json",
776+
actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
781777
schema=schema, path=tmpPath,
782778
noUse="this options will not be used")
783-
self.assertTrue(
784-
sorted(df.collect()) ==
785-
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
786-
self.assertTrue(
787-
sorted(df.select("value").collect()) ==
788-
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
789-
self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
779+
self.assertEqual(sorted(df.collect()),
780+
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
781+
self.assertEqual(sorted(df.select("value").collect()),
782+
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
783+
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
790784
self.sqlCtx.sql("DROP TABLE savedJsonTable")
791785
self.sqlCtx.sql("DROP TABLE externalJsonTable")
792786

793787
defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
794788
"org.apache.spark.sql.parquet")
795789
self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
796-
df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
790+
df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
797791
actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
798-
self.assertTrue(
799-
sorted(df.collect()) ==
800-
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
801-
self.assertTrue(
802-
sorted(df.collect()) ==
803-
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
804-
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
792+
self.assertEqual(sorted(df.collect()),
793+
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
794+
self.assertEqual(sorted(df.collect()),
795+
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
796+
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
805797
self.sqlCtx.sql("DROP TABLE savedJsonTable")
806798
self.sqlCtx.sql("DROP TABLE externalJsonTable")
807799
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
808800

809801
shutil.rmtree(tmpPath)
810802

803+
811804
if __name__ == "__main__":
812805
unittest.main()

0 commit comments

Comments
 (0)