
Commit 885d162

davies authored and JoshRosen committed
[SPARK-3500] [SQL] use JavaSchemaRDD as SchemaRDD._jschema_rdd
Currently, SchemaRDD._jschema_rdd is a (Scala) SchemaRDD, so Scala API methods such as coalesce() and repartition() cannot easily be called from Python: there is no way to supply the implicit parameter `ord`. Since _jrdd is a JavaRDD, _jschema_rdd should likewise be a JavaSchemaRDD.

This patch changes _jschema_rdd to a JavaSchemaRDD and adds an assertion for it. If a method is missing from JavaSchemaRDD, it is called via _jschema_rdd.baseSchemaRDD().xxx().

BTW, do we need JavaSQLContext?

Author: Davies Liu <[email protected]>

Closes apache#2369 from davies/fix_schemardd and squashes the following commits:

abee159 [Davies Liu] use JavaSchemaRDD as SchemaRDD._jschema_rdd
1 parent 71af030 commit 885d162
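
For context on the fix: py4j can only pass explicit arguments to JVM methods, so the Scala signature coalesce(numPartitions, shuffle)(implicit ord) on SchemaRDD could not be invoked from Python, while JavaSchemaRDD exposes a plain Java-friendly coalesce(numPartitions, shuffle). A minimal sketch of what the change enables, assuming a live `sc` and `sqlCtx` (it mirrors the new test added below):

    rdd = sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
    srdd = sqlCtx.jsonRDD(rdd)       # _jschema_rdd is now a JavaSchemaRDD
    srdd = srdd.coalesce(2, True)    # forwards to JavaSchemaRDD.coalesce; no implicit ord
    srdd = srdd.repartition(3)
    assert srdd.count() == 2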

File tree: 2 files changed (+46, −20 lines)


python/pyspark/sql.py

Lines changed: 18 additions & 20 deletions
@@ -1122,7 +1122,7 @@ def applySchema(self, rdd, schema):
         batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
         jrdd = self._pythonToJava(rdd._jrdd, batched)
         srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
-        return SchemaRDD(srdd, self)
+        return SchemaRDD(srdd.toJavaSchemaRDD(), self)

     def registerRDDAsTable(self, rdd, tableName):
         """Registers the given RDD as a temporary table in the catalog.
@@ -1134,8 +1134,8 @@ def registerRDDAsTable(self, rdd, tableName):
         >>> sqlCtx.registerRDDAsTable(srdd, "table1")
         """
         if (rdd.__class__ is SchemaRDD):
-            jschema_rdd = rdd._jschema_rdd
-            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
+            srdd = rdd._jschema_rdd.baseSchemaRDD()
+            self._ssql_ctx.registerRDDAsTable(srdd, tableName)
         else:
             raise ValueError("Can only register SchemaRDD as table")

@@ -1151,7 +1151,7 @@ def parquetFile(self, path):
         >>> sorted(srdd.collect()) == sorted(srdd2.collect())
         True
         """
-        jschema_rdd = self._ssql_ctx.parquetFile(path)
+        jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
         return SchemaRDD(jschema_rdd, self)

     def jsonFile(self, path, schema=None):
@@ -1207,11 +1207,11 @@ def jsonFile(self, path, schema=None):
         [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
         """
         if schema is None:
-            jschema_rdd = self._ssql_ctx.jsonFile(path)
+            srdd = self._ssql_ctx.jsonFile(path)
         else:
             scala_datatype = self._ssql_ctx.parseDataType(str(schema))
-            jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
-        return SchemaRDD(jschema_rdd, self)
+            srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
+        return SchemaRDD(srdd.toJavaSchemaRDD(), self)

     def jsonRDD(self, rdd, schema=None):
         """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
@@ -1275,11 +1275,11 @@ def func(iterator):
         keyed._bypass_serializer = True
         jrdd = keyed._jrdd.map(self._jvm.BytesToString())
         if schema is None:
-            jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
         else:
             scala_datatype = self._ssql_ctx.parseDataType(str(schema))
-            jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
-        return SchemaRDD(jschema_rdd, self)
+            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+        return SchemaRDD(srdd.toJavaSchemaRDD(), self)

     def sql(self, sqlQuery):
         """Return a L{SchemaRDD} representing the result of the given query.
@@ -1290,7 +1290,7 @@ def sql(self, sqlQuery):
         >>> srdd2.collect()
         [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
         """
-        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+        return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self)

     def table(self, tableName):
         """Returns the specified table as a L{SchemaRDD}.
@@ -1301,7 +1301,7 @@ def table(self, tableName):
         >>> sorted(srdd.collect()) == sorted(srdd2.collect())
         True
         """
-        return SchemaRDD(self._ssql_ctx.table(tableName), self)
+        return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self)

     def cacheTable(self, tableName):
         """Caches the specified table in-memory."""
@@ -1353,7 +1353,7 @@ def hiveql(self, hqlQuery):
         warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" +
                       "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
                       DeprecationWarning)
-        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)
+        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)

     def hql(self, hqlQuery):
         """
@@ -1524,6 +1524,8 @@ class SchemaRDD(RDD):
     def __init__(self, jschema_rdd, sql_ctx):
         self.sql_ctx = sql_ctx
         self._sc = sql_ctx._sc
+        clsName = jschema_rdd.getClass().getName()
+        assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
         self._jschema_rdd = jschema_rdd
         self._id = None
         self.is_cached = False
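
The assertion added here fails fast when a raw Scala SchemaRDD is passed in: py4j forwards getClass().getName() to the JVM object, so the check runs against the JVM class name. A hypothetical illustration (the table name is assumed):

    scala_srdd = sqlCtx._ssql_ctx.table("people")    # a Scala SchemaRDD
    SchemaRDD(scala_srdd, sqlCtx)                    # AssertionError: jschema_rdd must be JavaSchemaRDD
    SchemaRDD(scala_srdd.toJavaSchemaRDD(), sqlCtx)  # OK: class name ends with "JavaSchemaRDD"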
@@ -1540,7 +1542,7 @@ def _jrdd(self):
         L{pyspark.rdd.RDD} super class (map, filter, etc.).
         """
         if not hasattr(self, '_lazy_jrdd'):
-            self._lazy_jrdd = self._jschema_rdd.javaToPython()
+            self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
         return self._lazy_jrdd

     def id(self):
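
This hunk shows the rule the rest of the patch follows: call JavaSchemaRDD directly when it has the method, otherwise unwrap via baseSchemaRDD() first. Condensed (method names taken from this diff, surrounding code assumed):

    # Exposed on JavaSchemaRDD -> call it directly:
    rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
    # Only on the Scala SchemaRDD -> unwrap first:
    jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()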
@@ -1598,7 +1600,7 @@ def saveAsTable(self, tableName):
     def schema(self):
         """Returns the schema of this SchemaRDD (represented by
         a L{StructType})."""
-        return _parse_datatype_string(self._jschema_rdd.schema().toString())
+        return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())

     def schemaString(self):
         """Returns the output schema in the tree format."""
@@ -1649,8 +1651,6 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
         rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)

         schema = self.schema()
-        import pickle
-        pickle.loads(pickle.dumps(schema))

         def applySchema(_, it):
             cls = _create_cls(schema)
@@ -1687,10 +1687,8 @@ def isCheckpointed(self):

     def getCheckpointFile(self):
         checkpointFile = self._jschema_rdd.getCheckpointFile()
-        if checkpointFile.isDefined():
+        if checkpointFile.isPresent():
             return checkpointFile.get()
-        else:
-            return None

     def coalesce(self, numPartitions, shuffle=False):
         rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
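
The isDefined() -> isPresent() change follows from the wrapper swap: the Scala RDD's getCheckpointFile returns a Scala Option, while the Java API returns a Guava-style Optional, whose presence check is isPresent(). Dropping the else branch is safe because a Python function that falls off the end returns None; roughly:

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()  # Optional, not Option
        if checkpointFile.isPresent():
            return checkpointFile.get()
        # no else needed: falling off the end returns None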

python/pyspark/tests.py

Lines changed: 28 additions & 0 deletions
@@ -607,6 +607,34 @@ def test_broadcast_in_udf(self):
         [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])

+    def test_basic_functions(self):
+        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+        srdd = self.sqlCtx.jsonRDD(rdd)
+        srdd.count()
+        srdd.collect()
+        srdd.schemaString()
+        srdd.schema()
+
+        # cache and checkpoint
+        self.assertFalse(srdd.is_cached)
+        srdd.persist()
+        srdd.unpersist()
+        srdd.cache()
+        self.assertTrue(srdd.is_cached)
+        self.assertFalse(srdd.isCheckpointed())
+        self.assertEqual(None, srdd.getCheckpointFile())
+
+        srdd = srdd.coalesce(2, True)
+        srdd = srdd.repartition(3)
+        srdd = srdd.distinct()
+        srdd.intersection(srdd)
+        self.assertEqual(2, srdd.count())
+
+        srdd.registerTempTable("temp")
+        srdd = self.sqlCtx.sql("select foo from temp")
+        srdd.count()
+        srdd.collect()
+

 class TestIO(PySparkTestCase):