def test_save_and_load_table(self):
    """Round-trip a DataFrame through saveAsTable / createExternalTable.

    Exercises three variants of the DataFrameWriter.saveAsTable API:

    1. explicit short-name source ("json") with "append" mode,
    2. explicit source with "overwrite" mode plus a user-supplied schema
       (and an extra option that the data source should ignore),
    3. no explicit source, relying on ``spark.sql.sources.default``.

    After each save, the rows read back via SQL from both the managed
    table and the external table must equal the original DataFrame's rows.
    """
    df = self.df
    # mkdtemp + rmtree yields a fresh, non-existent path so the first
    # "append" write starts from an empty location.
    tmpPath = tempfile.mkdtemp()
    shutil.rmtree(tmpPath)

    # 1. Save with an explicit short source name and read back two ways.
    df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
    actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
    self.assertEqual(sorted(df.collect()),
                     sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
    self.assertEqual(sorted(df.collect()),
                     sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
    self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
    self.sqlCtx.sql("DROP TABLE externalJsonTable")

    # 2. Overwrite, then re-create the external table with an explicit
    #    one-column schema; unknown options ("noUse") must be ignored.
    df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
    schema = StructType([StructField("value", StringType(), True)])
    actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
                                             schema=schema, path=tmpPath,
                                             noUse="this options will not be used")
    self.assertEqual(sorted(df.collect()),
                     sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
    # The external table only exposes the "value" column from the schema.
    self.assertEqual(sorted(df.select("value").collect()),
                     sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
    self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
    self.sqlCtx.sql("DROP TABLE savedJsonTable")
    self.sqlCtx.sql("DROP TABLE externalJsonTable")

    # 3. Rely on the session default data source; restore the previous
    #    setting afterwards so later tests are unaffected.
    defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                "org.apache.spark.sql.parquet")
    self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
    df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
    actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
    self.assertEqual(sorted(df.collect()),
                     sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
    self.assertEqual(sorted(df.collect()),
                     sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
    self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
    self.sqlCtx.sql("DROP TABLE savedJsonTable")
    self.sqlCtx.sql("DROP TABLE externalJsonTable")
    self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

    shutil.rmtree(tmpPath)
if __name__ == "__main__":
    # Run the test suite when this file is executed as a script.
    unittest.main()