Updates #22

Merged 2 commits on Jan 10, 2015
@@ -21,16 +21,27 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.sources._
 
-private[sql] class DefaultSource extends SchemaRelationProvider {
-  /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+  /** Returns a new base relation with the parameters. */
   override def createRelation(
       sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+    JSONRelation(fileName, samplingRatio, None)(sqlContext)
+  }
+
+  /** Returns a new base relation with the given schema and parameters. */
+  override def createRelation(
+      sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
-    JSONRelation(fileName, samplingRatio, schema)(sqlContext)
+    JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
   }
 }
 
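The two overloads above correspond to the two DDL forms dispatched in ddl.scala further down: a statement without a column list goes through RelationProvider (the schema is inferred by sampling the JSON data), while a statement that lists columns goes through the new StructType overload of SchemaRelationProvider. A rough usage sketch, assuming an existing sqlContext; the table names, path, and columns are illustrative only:

  // Schema inferred from the data; samplingRatio is the optional knob read above.
  sqlContext.sql(
    """CREATE TEMPORARY TABLE jsonInferred
      |USING org.apache.spark.sql.json
      |OPTIONS (path 'examples/people.json', samplingRatio '0.5')""".stripMargin)

  // User-specified schema; dispatched to the StructType overload.
  sqlContext.sql(
    """CREATE TEMPORARY TABLE jsonTyped (name STRING, age INT)
      |USING org.apache.spark.sql.json
      |OPTIONS (path 'examples/people.json')""".stripMargin)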
@@ -22,37 +22,37 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job}
-import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 
 import parquet.hadoop.ParquetInputFormat
 import parquet.hadoop.util.ContextUtil
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.{Partition => SparkPartition, Logging}
 import org.apache.spark.rdd.{NewHadoopPartition, RDD}
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+
+import org.apache.spark.sql.{SQLConf, Row, SQLContext}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, StructField, StructType}
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.{SQLConf, SQLContext}
 
 import scala.collection.JavaConversions._
 
 
 /**
  * Allows creation of parquet based tables using the syntax
  * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option
  * required is `path`, which should be the location of a collection of, optionally partitioned,
  * parquet files.
  */
-class DefaultSource extends SchemaRelationProvider {
+class DefaultSource extends RelationProvider {
   /** Returns a new base relation with the given parameters. */
   override def createRelation(
       sqlContext: SQLContext,
-      parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      parameters: Map[String, String]): BaseRelation = {
     val path =
       parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables."))
 
-    ParquetRelation2(path, schema)(sqlContext)
+    ParquetRelation2(path)(sqlContext)
   }
 }
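The class comment above documents the single required option, path. With ParquetRelation2 now always reading its schema from the Parquet files themselves (see the dataSchema change below), DefaultSource only needs to be a RelationProvider, and the documented usage is unchanged. A minimal sketch, assuming an existing sqlContext; the table name and path are illustrative:

  sqlContext.sql(
    """CREATE TEMPORARY TABLE parquetTable
      |USING org.apache.spark.sql.parquet
      |OPTIONS (path '/data/events')""".stripMargin)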

@@ -82,9 +82,7 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files:
  * discovery.
  */
 @DeveloperApi
-case class ParquetRelation2(
-    path: String,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
   extends CatalystScan with Logging {
 
   def sparkContext = sqlContext.sparkContext
@@ -135,13 +133,12 @@ case class ParquetRelation2(
 
   override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum
 
-  val dataSchema = userSpecifiedSchema.getOrElse(
-    StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
-      ParquetTypesConverter.readSchemaFromFile(
-        partitions.head.files.head.getPath,
-        Some(sparkContext.hadoopConfiguration),
-        sqlContext.isParquetBinaryAsString))
-  )
+  val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
+    ParquetTypesConverter.readSchemaFromFile(
+      partitions.head.files.head.getPath,
+      Some(sparkContext.hadoopConfiguration),
+      sqlContext.isParquetBinaryAsString))
 
   val dataIncludesKey =
     partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true)
 
31 changes: 22 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -190,15 +190,28 @@ private[sql] case class CreateTableUsing(
             sys.error(s"Failed to load class for data source: $provider")
         }
     }
-    val relation = clazz.newInstance match {
-      case dataSource: org.apache.spark.sql.sources.RelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options))
-      case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options), userSpecifiedSchema)
+
+    val relation = userSpecifiedSchema match {
+      case Some(schema: StructType) => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+        }
+      }
+      case None => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options))
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+        }
+      }
     }
 
     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
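The effect of this match is that the DDL form now selects the required trait: a column list demands SchemaRelationProvider, its absence demands RelationProvider, and a mismatch hits one of the sys.error branches above. A sketch of the failure mode, assuming a hypothetical provider com.example.PathOnlySource that implements only RelationProvider:

  // Fails with "... should extend SchemaRelationProvider." because a schema is supplied.
  sqlContext.sql(
    """CREATE TEMPORARY TABLE typedTable (a INT, b STRING)
      |USING com.example.PathOnlySource
      |OPTIONS (path 'some/path')""".stripMargin)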
@@ -68,7 +68,7 @@ trait SchemaRelationProvider {
   def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation
+      schema: StructType): BaseRelation
 }
 
 /**
@@ -45,18 +45,18 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
   }
 }
 
 case class AllDataTypesScan(
     from: Int,
     to: Int,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
   extends TableScan {
 
-  override def schema = userSpecifiedSchema.get
+  override def schema = userSpecifiedSchema
 
   override def buildScan() = {
     sqlContext.sparkContext.parallelize(from to to).map { i =>