@@ -1122,7 +1122,7 @@ def applySchema(self, rdd, schema):
1122
1122
batched = isinstance (rdd ._jrdd_deserializer , BatchedSerializer )
1123
1123
jrdd = self ._pythonToJava (rdd ._jrdd , batched )
1124
1124
srdd = self ._ssql_ctx .applySchemaToPythonRDD (jrdd .rdd (), str (schema ))
1125
- return SchemaRDD (srdd , self )
1125
+ return SchemaRDD (srdd . toJavaSchemaRDD () , self )
1126
1126
1127
1127
def registerRDDAsTable (self , rdd , tableName ):
1128
1128
"""Registers the given RDD as a temporary table in the catalog.
@@ -1134,8 +1134,8 @@ def registerRDDAsTable(self, rdd, tableName):
1134
1134
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
1135
1135
"""
1136
1136
if (rdd .__class__ is SchemaRDD ):
1137
- jschema_rdd = rdd ._jschema_rdd
1138
- self ._ssql_ctx .registerRDDAsTable (jschema_rdd , tableName )
1137
+ srdd = rdd ._jschema_rdd . baseSchemaRDD ()
1138
+ self ._ssql_ctx .registerRDDAsTable (srdd , tableName )
1139
1139
else :
1140
1140
raise ValueError ("Can only register SchemaRDD as table" )
1141
1141
@@ -1151,7 +1151,7 @@ def parquetFile(self, path):
1151
1151
>>> sorted(srdd.collect()) == sorted(srdd2.collect())
1152
1152
True
1153
1153
"""
1154
- jschema_rdd = self ._ssql_ctx .parquetFile (path )
1154
+ jschema_rdd = self ._ssql_ctx .parquetFile (path ). toJavaSchemaRDD ()
1155
1155
return SchemaRDD (jschema_rdd , self )
1156
1156
1157
1157
def jsonFile (self , path , schema = None ):
@@ -1207,11 +1207,11 @@ def jsonFile(self, path, schema=None):
1207
1207
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
1208
1208
"""
1209
1209
if schema is None :
1210
- jschema_rdd = self ._ssql_ctx .jsonFile (path )
1210
+ srdd = self ._ssql_ctx .jsonFile (path )
1211
1211
else :
1212
1212
scala_datatype = self ._ssql_ctx .parseDataType (str (schema ))
1213
- jschema_rdd = self ._ssql_ctx .jsonFile (path , scala_datatype )
1214
- return SchemaRDD (jschema_rdd , self )
1213
+ srdd = self ._ssql_ctx .jsonFile (path , scala_datatype )
1214
+ return SchemaRDD (srdd . toJavaSchemaRDD () , self )
1215
1215
1216
1216
def jsonRDD (self , rdd , schema = None ):
1217
1217
"""Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
@@ -1275,11 +1275,11 @@ def func(iterator):
1275
1275
keyed ._bypass_serializer = True
1276
1276
jrdd = keyed ._jrdd .map (self ._jvm .BytesToString ())
1277
1277
if schema is None :
1278
- jschema_rdd = self ._ssql_ctx .jsonRDD (jrdd .rdd ())
1278
+ srdd = self ._ssql_ctx .jsonRDD (jrdd .rdd ())
1279
1279
else :
1280
1280
scala_datatype = self ._ssql_ctx .parseDataType (str (schema ))
1281
- jschema_rdd = self ._ssql_ctx .jsonRDD (jrdd .rdd (), scala_datatype )
1282
- return SchemaRDD (jschema_rdd , self )
1281
+ srdd = self ._ssql_ctx .jsonRDD (jrdd .rdd (), scala_datatype )
1282
+ return SchemaRDD (srdd . toJavaSchemaRDD () , self )
1283
1283
1284
1284
def sql (self , sqlQuery ):
1285
1285
"""Return a L{SchemaRDD} representing the result of the given query.
@@ -1290,7 +1290,7 @@ def sql(self, sqlQuery):
1290
1290
>>> srdd2.collect()
1291
1291
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
1292
1292
"""
1293
- return SchemaRDD (self ._ssql_ctx .sql (sqlQuery ), self )
1293
+ return SchemaRDD (self ._ssql_ctx .sql (sqlQuery ). toJavaSchemaRDD () , self )
1294
1294
1295
1295
def table (self , tableName ):
1296
1296
"""Returns the specified table as a L{SchemaRDD}.
@@ -1301,7 +1301,7 @@ def table(self, tableName):
1301
1301
>>> sorted(srdd.collect()) == sorted(srdd2.collect())
1302
1302
True
1303
1303
"""
1304
- return SchemaRDD (self ._ssql_ctx .table (tableName ), self )
1304
+ return SchemaRDD (self ._ssql_ctx .table (tableName ). toJavaSchemaRDD () , self )
1305
1305
1306
1306
def cacheTable (self , tableName ):
1307
1307
"""Caches the specified table in-memory."""
@@ -1353,7 +1353,7 @@ def hiveql(self, hqlQuery):
1353
1353
warnings .warn ("hiveql() is deprecated as the sql function now parses using HiveQL by" +
1354
1354
"default. The SQL dialect for parsing can be set using 'spark.sql.dialect'" ,
1355
1355
DeprecationWarning )
1356
- return SchemaRDD (self ._ssql_ctx .hiveql (hqlQuery ), self )
1356
+ return SchemaRDD (self ._ssql_ctx .hiveql (hqlQuery ). toJavaSchemaRDD () , self )
1357
1357
1358
1358
def hql (self , hqlQuery ):
1359
1359
"""
@@ -1524,6 +1524,8 @@ class SchemaRDD(RDD):
1524
1524
def __init__ (self , jschema_rdd , sql_ctx ):
1525
1525
self .sql_ctx = sql_ctx
1526
1526
self ._sc = sql_ctx ._sc
1527
+ clsName = jschema_rdd .getClass ().getName ()
1528
+ assert clsName .endswith ("JavaSchemaRDD" ), "jschema_rdd must be JavaSchemaRDD"
1527
1529
self ._jschema_rdd = jschema_rdd
1528
1530
self ._id = None
1529
1531
self .is_cached = False
@@ -1540,7 +1542,7 @@ def _jrdd(self):
1540
1542
L{pyspark.rdd.RDD} super class (map, filter, etc.).
1541
1543
"""
1542
1544
if not hasattr (self , '_lazy_jrdd' ):
1543
- self ._lazy_jrdd = self ._jschema_rdd .javaToPython ()
1545
+ self ._lazy_jrdd = self ._jschema_rdd .baseSchemaRDD (). javaToPython ()
1544
1546
return self ._lazy_jrdd
1545
1547
1546
1548
def id (self ):
@@ -1598,7 +1600,7 @@ def saveAsTable(self, tableName):
1598
1600
def schema (self ):
1599
1601
"""Returns the schema of this SchemaRDD (represented by
1600
1602
a L{StructType})."""
1601
- return _parse_datatype_string (self ._jschema_rdd .schema ().toString ())
1603
+ return _parse_datatype_string (self ._jschema_rdd .baseSchemaRDD (). schema ().toString ())
1602
1604
1603
1605
def schemaString (self ):
1604
1606
"""Returns the output schema in the tree format."""
@@ -1649,8 +1651,6 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
1649
1651
rdd = RDD (self ._jrdd , self ._sc , self ._jrdd_deserializer )
1650
1652
1651
1653
schema = self .schema ()
1652
- import pickle
1653
- pickle .loads (pickle .dumps (schema ))
1654
1654
1655
1655
def applySchema (_ , it ):
1656
1656
cls = _create_cls (schema )
@@ -1687,10 +1687,8 @@ def isCheckpointed(self):
1687
1687
1688
1688
def getCheckpointFile (self ):
1689
1689
checkpointFile = self ._jschema_rdd .getCheckpointFile ()
1690
- if checkpointFile .isDefined ():
1690
+ if checkpointFile .isPresent ():
1691
1691
return checkpointFile .get ()
1692
- else :
1693
- return None
1694
1692
1695
1693
def coalesce (self , numPartitions , shuffle = False ):
1696
1694
rdd = self ._jschema_rdd .coalesce (numPartitions , shuffle )
0 commit comments