
Commit 1d9c13a

Update applySchema API.

1 parent 85e9b51

3 files changed: +34 -52 lines

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 18 additions & 35 deletions
@@ -88,33 +88,18 @@ class SQLContext(@transient val sparkContext: SparkContext)
     new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd)))
 
   /**
-   * Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function
-   * that will be applied to each partition of the RDD to convert RDD records to [[Row]]s.
+   * :: DeveloperApi ::
+   * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+   * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+   * the provided schema. Otherwise, there will be runtime exception.
    *
    * @group userf
    */
-  def applySchema[A](rdd: RDD[A], schema: StructType, f: A => Row): SchemaRDD =
-    applySchemaToPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f))
-
-  /**
-   * Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function
-   * that will be applied to each partition of the RDD to convert RDD records to [[Row]]s.
-   * Similar to `RDD.mapPartitions`, this function can be used to improve performance where there
-   * is other setup work that can be amortized and used repeatedly for all of the
-   * elements in a partition.
-   * @group userf
-   */
-  def applySchemaToPartitions[A](
-      rdd: RDD[A],
-      schema: StructType,
-      f: Iterator[A] => Iterator[Row]): SchemaRDD =
-    new SchemaRDD(this, makeCustomRDDScan(rdd, schema, f))
-
-  protected[sql] def makeCustomRDDScan[A](
-      rdd: RDD[A],
-      schema: StructType,
-      f: Iterator[A] => Iterator[Row]): LogicalPlan =
-    SparkLogicalPlan(ExistingRdd(schema.toAttributes, rdd.mapPartitions(f)))
+  @DeveloperApi
+  def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = {
+    val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))
+    new SchemaRDD(this, logicalPlan)
+  }
 
   /**
    * Loads a Parquet file, returning the result as a [[SchemaRDD]].
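
For context only, not part of the commit: a minimal usage sketch of the reworked API. The caller now builds an RDD[Row] and a StructType directly instead of passing a record-conversion function. The input path, field names, and the assumption that Row, StructType, StructField, and StringType are exported from org.apache.spark.sql (and that a SparkContext `sc` exists) are illustrative, not taken from this commit.

```scala
// Hypothetical sketch of calling the new applySchema(rowRDD, schema).
// "people.txt" and the name/age fields are made up.
import org.apache.spark.sql._

val sqlContext = new SQLContext(sc)   // assumes an existing SparkContext `sc`

val schema =
  StructType(
    StructField("name", StringType, nullable = true) ::
    StructField("age", StringType, nullable = true) :: Nil)

// Every Row has to line up with the schema; applySchema itself does not check
// this, so a mismatch only surfaces as a runtime exception when querying.
val rowRDD = sc.textFile("people.txt")   // lines like "Alice,29"
  .map(_.split(","))
  .map(p => Row(p(0), p(1).trim))

val people: SchemaRDD = sqlContext.applySchema(rowRDD, schema)
```
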
@@ -133,11 +118,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
   def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0)
 
   /**
+   * :: Experimental ::
    * Loads a JSON file (one object per line) and applies the given schema,
    * returning the result as a [[SchemaRDD]].
    *
    * @group userf
    */
+  @Experimental
   def jsonFile(path: String, schema: StructType): SchemaRDD = {
     val json = sparkContext.textFile(path)
     jsonRDD(json, schema)
@@ -162,32 +149,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
   def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0)
 
   /**
+   * :: Experimental ::
    * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema,
    * returning the result as a [[SchemaRDD]].
    *
    * @group userf
    */
+  @Experimental
   def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = {
     val appliedSchema =
       Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0)))
-
-    applySchemaToPartitions(
-      json,
-      appliedSchema,
-      JsonRDD.jsonStringToRow(appliedSchema, _: Iterator[String]))
+    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+    applySchema(rowRDD, appliedSchema)
   }
 
   /**
    * :: Experimental ::
    */
   @Experimental
   def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
-    val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))
-
-    applySchemaToPartitions(
-      json,
-      schema,
-      JsonRDD.jsonStringToRow(schema, _: Iterator[String]))
+    val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))
+    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+    applySchema(rowRDD, appliedSchema)
   }
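
Similarly, a hedged sketch of the now-@Experimental schema-taking JSON entry points. The file path, field names, and sample record are made up, and `sc`, `sqlContext`, and the LongType alias are assumed as in the previous sketch.

```scala
// Hypothetical usage of jsonFile/jsonRDD with a user-supplied schema instead of
// sampling-based inference.
import org.apache.spark.sql._

val jsonSchema =
  StructType(
    StructField("name", StringType, nullable = true) ::
    StructField("age", LongType, nullable = true) :: Nil)

// One JSON object per line, e.g. {"name": "Alice", "age": 29}
val fromFile = sqlContext.jsonFile("people.json", jsonSchema)

// Or from an existing RDD[String] of JSON records. Passing a null schema falls
// back to inference, per the Option(schema).getOrElse(...) in the diff above.
val records = sc.parallelize("""{"name": "Bob", "age": 31}""" :: Nil)
val fromRDD = sqlContext.jsonRDD(records, jsonSchema)
```
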

sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala

Lines changed: 16 additions & 15 deletions
@@ -32,18 +32,17 @@ import org.apache.spark.sql.Logging
 private[sql] object JsonRDD extends Logging {
 
   private[sql] def jsonStringToRow(
-      schema: StructType,
-      jsonIter: Iterator[String]): Iterator[Row] = {
-    parseJson(jsonIter).map(parsed => asRow(parsed, schema))
+      json: RDD[String],
+      schema: StructType): RDD[Row] = {
+    parseJson(json).map(parsed => asRow(parsed, schema))
   }
 
   private[sql] def inferSchema(
       json: RDD[String],
       samplingRatio: Double = 1.0): StructType = {
     require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
     val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1)
-    val allKeys =
-      schemaData.mapPartitions(iter => parseJson(iter)).map(allKeysWithValueTypes).reduce(_ ++ _)
+    val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _)
     createSchema(allKeys)
   }

@@ -255,7 +254,7 @@ private[sql] object JsonRDD extends Logging {
     case atom => atom
   }
 
-  private def parseJson(jsonIter: Iterator[String]): Iterator[Map[String, Any]] = {
+  private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = {
     // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72],
     // ObjectMapper will not return BigDecimal when
     // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled
@@ -264,15 +263,17 @@ private[sql] object JsonRDD extends Logging {
     // for every float number, which will be slow.
     // So, right now, we will have Infinity for those BigDecimal number.
     // TODO: Support BigDecimal.
-    // Also, when there is a key appearing multiple times (a duplicate key),
-    // the ObjectMapper will take the last value associated with this duplicate key.
-    // For example: for {"key": 1, "key":2}, we will get "key"->2.
-    val mapper = new ObjectMapper()
-    jsonIter.map {
-      record =>
-        val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]]))
-        parsed.asInstanceOf[Map[String, Any]]
-    }
+    json.mapPartitions(iter => {
+      // Also, when there is a key appearing multiple times (a duplicate key),
+      // the ObjectMapper will take the last value associated with this duplicate key.
+      // For example: for {"key": 1, "key":2}, we will get "key"->2.
+      val mapper = new ObjectMapper()
+      iter.map {
+        record =>
+          val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]]))
+          parsed.asInstanceOf[Map[String, Any]]
+      }
+    })
   }
 
   private def toLong(value: Any): Long = {
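
The parseJson rewrite above follows the usual RDD.mapPartitions pattern for amortizing per-record setup cost: the Jackson ObjectMapper is now built once per partition and reused for every record in it. A stripped-down sketch of that pattern, using a shallow Java-to-Scala conversion where the commit's code calls its own `scalafy` helper; the function name `parseRecords` is made up.

```scala
// Sketch of the per-partition setup pattern used by the new parseJson.
import scala.collection.JavaConverters._
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.spark.rdd.RDD

def parseRecords(json: RDD[String]): RDD[Map[String, Any]] =
  json.mapPartitions { iter =>
    val mapper = new ObjectMapper()   // created once per partition, not per record
    iter.map { record =>
      // Shallow conversion for illustration; the real code converts nested
      // maps and lists recursively via scalafy.
      mapper.readValue(record, classOf[java.util.Map[String, Any]]).asScala.toMap
    }
  }
```
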

sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala

Lines changed: 0 additions & 2 deletions
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.json
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.LeafNode
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
