
Commit 38f634e

Remove Option from createRelation.

1 parent 65e9c73 · commit 38f634e

4 files changed: +41 -17 lines

  sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
  sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala

sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala

Lines changed: 15 additions & 4 deletions

@@ -21,16 +21,27 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.sources._
 
-private[sql] class DefaultSource extends SchemaRelationProvider {
-  /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+  /** Returns a new base relation with the parameters. */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+    JSONRelation(fileName, samplingRatio, None)(sqlContext)
+  }
+
+  /** Returns a new base relation with the given schema and parameters. */
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
-    JSONRelation(fileName, samplingRatio, schema)(sqlContext)
+    JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
   }
 }
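With DefaultSource extending both RelationProvider and SchemaRelationProvider, the JSON source can be created from the data sources DDL either with or without declared columns. A minimal sketch of both paths, not part of the diff: the table names and the file path are hypothetical, and the column-definition syntax is assumed to be the one that accompanies the userSpecifiedSchema field in CreateTableUsing.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object JsonSourceSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("json-source-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // No columns declared: CreateTableUsing passes no schema, so the
    // RelationProvider overload runs and JSONRelation infers the schema,
    // here from a 50% sample of the records.
    sqlContext.sql(
      """CREATE TEMPORARY TABLE inferredPeople
        |USING org.apache.spark.sql.json
        |OPTIONS (path '/tmp/people.json', samplingRatio '0.5')""".stripMargin)

    // Columns declared: CreateTableUsing passes the parsed StructType, so the
    // SchemaRelationProvider overload runs and no inference pass is needed.
    sqlContext.sql(
      """CREATE TEMPORARY TABLE declaredPeople (name STRING, age INT)
        |USING org.apache.spark.sql.json
        |OPTIONS (path '/tmp/people.json')""".stripMargin)

    sqlContext.sql("SELECT name FROM declaredPeople").collect().foreach(println)
    sc.stop()
  }
}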

sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala

Lines changed: 22 additions & 9 deletions

@@ -190,15 +190,28 @@ private[sql] case class CreateTableUsing(
             sys.error(s"Failed to load class for data source: $provider")
         }
     }
-    val relation = clazz.newInstance match {
-      case dataSource: org.apache.spark.sql.sources.RelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options))
-      case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options), userSpecifiedSchema)
+
+    val relation = userSpecifiedSchema match {
+      case Some(schema: StructType) => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+        }
+      }
+      case None => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options))
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+        }
+      }
     }
 
     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
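A consequence of the match above is that a mismatch between the DDL and the provider now fails fast with an error naming the missing trait. A hedged sketch of the first error branch follows; the package, class, and table names are invented for illustration, the column-definition DDL is again assumed, and the error is expected to surface as the RuntimeException thrown by sys.error.

package org.example.sources

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider}

// Hypothetical source that can only infer its schema: it extends
// RelationProvider but not SchemaRelationProvider.
class InferOnlySource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val ctx = sqlContext
    new BaseRelation {
      override def sqlContext: SQLContext = ctx
      override def schema: StructType = StructType(Nil)
    }
  }
}

object SchemaMismatchSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("schema-mismatch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Declaring columns routes CreateTableUsing into the Some(schema) branch,
    // which InferOnlySource does not implement, so sys.error should raise a
    // RuntimeException naming SchemaRelationProvider.
    try {
      sqlContext.sql(
        """CREATE TEMPORARY TABLE bad (a INT)
          |USING org.example.sources.InferOnlySource
          |OPTIONS (path '/tmp/unused')""".stripMargin)
    } catch {
      case e: RuntimeException => println(e.getMessage)
    }

    sc.stop()
  }
}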

sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ trait SchemaRelationProvider {
   def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation
+      schema: StructType): BaseRelation
 }
 
 /**
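For implementers, the updated trait means the user-declared StructType arrives directly rather than wrapped in an Option, so there is no None case to handle. A minimal hypothetical sketch of a source written against the new signature; the package, class, and option names are invented for illustration.

package org.example.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources.{BaseRelation, SchemaRelationProvider, TableScan}

class RangeSource extends SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation = {
    // The schema is guaranteed here: CreateTableUsing only calls this overload
    // when the user declared columns, so no .get or default is needed.
    RangeRelation(parameters("to").toInt, schema)(sqlContext)
  }
}

case class RangeRelation(
    to: Int,
    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
  extends TableScan {

  override def schema = userSpecifiedSchema

  // Emits one single-column row per integer; assumes the declared schema has
  // exactly one integer column.
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(1 to to).map(i => Row(i))
}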

sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -45,18 +45,18 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
   }
 }
 
 case class AllDataTypesScan(
     from: Int,
     to: Int,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
   extends TableScan {
 
-  override def schema = userSpecifiedSchema.get
+  override def schema = userSpecifiedSchema
 
   override def buildScan() = {
     sqlContext.sparkContext.parallelize(from to to).map { i =>
