diff --git a/.gitignore b/.gitignore index a4ec12ca6b53f..7ec8d45e12c6b 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,4 @@ metastore_db/ metastore/ warehouse/ TempStatsStore/ +sql/hive-thriftserver/test_warehouses diff --git a/assembly/pom.xml b/assembly/pom.xml index 567a8dd2a0d94..703f15925bc44 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -165,6 +165,16 @@ + + hive-thriftserver + + + org.apache.spark + spark-hive-thriftserver_${scala.binary.version} + ${project.version} + + + spark-ganglia-lgpl diff --git a/bagel/pom.xml b/bagel/pom.xml index 90c4b095bb611..bd51b112e26fa 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-bagel_2.10 - bagel + bagel jar Spark Project Bagel diff --git a/bin/beeline b/bin/beeline new file mode 100755 index 0000000000000..09fe366c609fa --- /dev/null +++ b/bin/beeline @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +# Find the java binary +if [ -n "${JAVA_HOME}" ]; then + RUNNER="${JAVA_HOME}/bin/java" +else + if [ `command -v java` ]; then + RUNNER="java" + else + echo "JAVA_HOME is not set" >&2 + exit 1 + fi +fi + +# Compute classpath using external script +classpath_output=$($FWDIR/bin/compute-classpath.sh) +if [[ "$?" != "0" ]]; then + echo "$classpath_output" + exit 1 +else + CLASSPATH=$classpath_output +fi + +CLASS="org.apache.hive.beeline.BeeLine" +exec "$RUNNER" -cp "$CLASSPATH" $CLASS "$@" diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index e81e8c060cb98..16b794a1592e8 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -52,6 +52,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" fi diff --git a/bin/spark-shell b/bin/spark-shell index 850e9507ec38f..756c8179d12b6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -46,11 +46,11 @@ function main(){ # (see https://github.com/sbt/sbt/issues/562). 
stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" fi } diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 4b9708a8c03f3..b56d69801171c 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -19,4 +19,4 @@ rem set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell %* --class org.apache.spark.repl.Main +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql new file mode 100755 index 0000000000000..bba7f897b19bc --- /dev/null +++ b/bin/spark-sql @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# +# Shell script for starting the Spark SQL CLI + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/spark-sql [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/core/pom.xml b/core/pom.xml index 1054cec4d77bb..a24743495b0e1 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-core_2.10 - core + core jar Spark Project Core diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index f010c03223ef4..09a60571238ea 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -19,7 +19,6 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleHandle @@ -63,8 +62,7 @@ class ShuffleDependency[K, V, C]( val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, - val mapSideCombine: Boolean = false, - val sortOrder: Option[SortOrder] = None) + val mapSideCombine: Boolean = false) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3b5642b6caa36..c9cec33ebaa66 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -46,6 +46,10 @@ object SparkSubmit { private val CLUSTER = 2 private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + // A special jar name that indicates the class being run is inside of Spark itself, and therefore + // no user jar is needed. + private val SPARK_INTERNAL = "spark-internal" + // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" @@ -257,7 +261,9 @@ object SparkSubmit { // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (clusterManager == YARN && deployMode == CLUSTER) { childMainClass = "org.apache.spark.deploy.yarn.Client" - childArgs += ("--jar", args.primaryResource) + if (args.primaryResource != SPARK_INTERNAL) { + childArgs += ("--jar", args.primaryResource) + } childArgs += ("--class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) } @@ -332,7 +338,7 @@ object SparkSubmit { * Return whether the given primary resource represents a user jar. 
*/ private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) + !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) } /** @@ -349,6 +355,10 @@ object SparkSubmit { primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL } + private[spark] def isInternal(primaryResource: String): Boolean = { + primaryResource == SPARK_INTERNAL + } + /** * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 3ab67a43a3b55..01d0ae541a66b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -204,8 +204,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { - // Delineates parsing of Spark options from parsing of user options. var inSparkOpts = true + + // Delineates parsing of Spark options from parsing of user options. parse(opts) def parse(opts: Seq[String]): Unit = opts match { @@ -318,7 +319,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { SparkSubmit.printErrorAndExit(errMessage) case v => primaryResource = - if (!SparkSubmit.isShell(v)) { + if (!SparkSubmit.isShell(v) && !SparkSubmit.isInternal(v)) { Utils.resolveURI(v).toString } else { v diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index afd7075f686b9..d85f962783931 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -58,12 +58,6 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { val part = new RangePartitioner(numPartitions, self, ascending) new ShuffledRDD[K, V, V, P](self, part) - .setKeyOrdering(ordering) - .setSortOrder(if (ascending) SortOrder.ASCENDING else SortOrder.DESCENDING) + .setKeyOrdering(if (ascending) ordering else ordering.reverse) } } - -private[spark] object SortOrder extends Enumeration { - type SortOrder = Value - val ASCENDING, DESCENDING = Value -} diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index da4a8c3dc22b1..bf02f68d0d3d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { @@ -52,8 +51,6 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( private var mapSideCombine: Boolean = false - private var sortOrder: Option[SortOrder] = None - /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = { this.serializer = Option(serializer) @@ -78,15 +75,8 @@ class ShuffledRDD[K, V, 
C, P <: Product2[K, C] : ClassTag]( this } - /** Set sort order for RDD's sorting. */ - def setSortOrder(sortOrder: SortOrder): ShuffledRDD[K, V, C, P] = { - this.sortOrder = Option(sortOrder) - this - } - override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializer, - keyOrdering, aggregator, mapSideCombine, sortOrder)) + List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) } override val partitioner = Some(part) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 76cdb8f4f8e8a..c8059496a1bdf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -18,7 +18,6 @@ package org.apache.spark.shuffle.hash import org.apache.spark.{InterruptibleIterator, TaskContext} -import org.apache.spark.rdd.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} @@ -51,16 +50,22 @@ class HashShuffleReader[K, C]( iter } - val sortedIter = for (sortOrder <- dep.sortOrder; ordering <- dep.keyOrdering) yield { - val buf = aggregatedIter.toArray - if (sortOrder == SortOrder.ASCENDING) { - buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator - } else { - buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator - } + // Sort the output if there is a sort ordering defined. + dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Define a Comparator for the whole record based on the key Ordering. + val cmp = new Ordering[Product2[K, C]] { + override def compare(o1: Product2[K, C], o2: Product2[K, C]): Int = { + keyOrd.compare(o1._1, o2._1) + } + } + val sortBuffer: Array[Product2[K, C]] = aggregatedIter.toArray + // TODO: do external sort. + scala.util.Sorting.quickSort(sortBuffer)(cmp) + sortBuffer.iterator + case None => + aggregatedIter } - - sortedIter.getOrElse(aggregatedIter) } /** Close this reader */ diff --git a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala index d769b54fa2fae..f77488ef3d449 100644 --- a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import org.apache.commons.lang.SystemUtils +import org.apache.commons.lang3.SystemUtils import org.slf4j.Logger import sun.misc.{Signal, SignalHandler} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 38830103d1e8d..33de24d1ae6d7 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -53,7 +53,7 @@ if [[ ! "$@" =~ --package-only ]]; then -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ -Dmaven.javadoc.skip=true \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ --batch-mode release:prepare @@ -61,7 +61,7 @@ if [[ ! 
"$@" =~ --package-only ]]; then -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dmaven.javadoc.skip=true \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ release:perform cd .. @@ -111,10 +111,10 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" -make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" make_binary_release "hadoop2" \ - "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" # Copy data echo "Copying release tarballs" diff --git a/dev/run-tests b/dev/run-tests index 51e4def0f835a..98ec969dc1b37 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -65,7 +65,7 @@ echo "=========================================================================" # (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. if [ -n "$_RUN_SQL_TESTS" ]; then - echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive" sbt/sbt clean package \ + echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive -Phive-thriftserver" sbt/sbt clean package \ assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" else echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \ diff --git a/dev/scalastyle b/dev/scalastyle index a02d06912f238..d9f2b91a3a091 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,7 +17,7 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt # Check style with YARN alpha built too echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 38728534a46e0..156e0aebdebe6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -136,7 +136,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext.createSchemaRDD // Define the schema using a case class. -// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, +// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, // you can use custom classes that implement the Product interface. case class Person(name: String, age: Int) @@ -548,7 +548,6 @@ results = hiveContext.hql("FROM src SELECT key, value").collect() - # Writing Language-Integrated Relational Queries **Language-Integrated queries are currently only supported in Scala.** @@ -573,4 +572,200 @@ prefixed with a tick (`'`). Implicit conversions turn these symbols into expres evaluated by the SQL execution engine. A full list of the functions supported can be found in the [ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). 
- \ No newline at end of file + + +## Running the Thrift JDBC server + +The Thrift JDBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test +the JDBC server with the beeline script that comes with either Spark or Hive 0.12. In order to use Hive +you must first run `sbt/sbt -Phive-thriftserver assembly/assembly` (or use `-Phive-thriftserver` +for Maven). + +To start the JDBC server, run the following in the Spark directory: + + ./sbin/start-thriftserver.sh + +The default port the server listens on is 10000. To listen on a customized host and port, please set +the `HIVE_SERVER2_THRIFT_PORT` and `HIVE_SERVER2_THRIFT_BIND_HOST` environment variables. You may +run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. Now you can +use beeline to test the Thrift JDBC server: + + ./bin/beeline + +Connect to the JDBC server in beeline with: + + beeline> !connect jdbc:hive2://localhost:10000 + +Beeline will ask you for a username and password. In non-secure mode, simply enter the username on +your machine and a blank password. For secure mode, please follow the instructions given in the +[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients). + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. + +You may also use the beeline script that comes with Hive. + +### Migration Guide for Shark Users + +#### Reducer number + +In Shark, the default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark +SQL deprecates this property in favor of a new property, `spark.sql.shuffle.partitions`, whose default value +is 200. Users may customize this property via `SET`: + +``` +SET spark.sql.shuffle.partitions=10; +SELECT page, count(*) c FROM logs_last_month_cached +GROUP BY page ORDER BY c DESC LIMIT 10; +``` + +You may also put this property in `hive-site.xml` to override the default value. + +For now, the `mapred.reduce.tasks` property is still recognized, and is converted to +`spark.sql.shuffle.partitions` automatically. + +#### Caching + +The `shark.cache` table property no longer exists, and tables whose names end with `_cached` are no +longer automatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to +let users control table caching explicitly: + +``` +CACHE TABLE logs_last_month; +UNCACHE TABLE logs_last_month; +``` + +**NOTE** `CACHE TABLE tbl` is lazy: it only marks table `tbl` as "needs to be cached if necessary", +but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be +cached, you may simply count the table immediately after executing `CACHE TABLE`: + +``` +CACHE TABLE logs_last_month; +SELECT COUNT(1) FROM logs_last_month; +``` + +Several caching-related features are not supported yet: + +* User defined partition level cache eviction policy +* RDD reloading +* In-memory cache write through policy + +### Compatibility with Apache Hive + +#### Deploying in Existing Hive Warehouses + +The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive +installations. You do not need to modify your existing Hive Metastore or change the data placement +or partitioning of your tables.
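Since the Thrift JDBC server corresponds to HiveServer2, existing Hive 0.12 JDBC clients should also be able to connect to it directly. The snippet below is a minimal, illustrative Scala sketch and is not part of this patch: it assumes the Hive JDBC driver (`org.apache.hive.jdbc.HiveDriver`) is on the classpath, that the server is running on the default `localhost:10000`, and that a table named `logs_last_month` (borrowed from the examples above) exists.

```
import java.sql.DriverManager

object ThriftServerClientExample {
  def main(args: Array[String]): Unit = {
    // Register the HiveServer2 JDBC driver shipped with Hive 0.12.
    Class.forName("org.apache.hive.jdbc.HiveDriver")

    // In non-secure mode, any username and a blank password are accepted.
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default", "user", "")
    try {
      val stmt = conn.createStatement()
      // The statement is parsed and executed by Spark SQL on the server side.
      val rs = stmt.executeQuery(
        "SELECT page, count(*) c FROM logs_last_month GROUP BY page ORDER BY c DESC LIMIT 10")
      while (rs.next()) {
        println(s"${rs.getString(1)}\t${rs.getLong(2)}")
      }
    } finally {
      conn.close()
    }
  }
}
```

Because the connection carries plain SQL statements, the `SET spark.sql.shuffle.partitions=...` and `CACHE TABLE` commands from the migration guide above can also be issued through `stmt.execute(...)`.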
+ +#### Supported Hive Features + +Spark SQL supports the vast majority of Hive features, such as: + +* Hive query statements, including: + * `SELECT` + * `GROUP BY` + * `ORDER BY` + * `CLUSTER BY` + * `SORT BY` +* All Hive operators, including: + * Relational operators (`=`, `⇔`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc) + * Arithmetic operators (`+`, `-`, `*`, `/`, `%`, etc) + * Logical operators (`AND`, `&&`, `OR`, `||`, etc) + * Complex type constructors + * Mathematical functions (`sign`, `ln`, `cos`, etc) + * String functions (`instr`, `length`, `printf`, etc) +* User defined functions (UDF) +* User defined aggregation functions (UDAF) +* User defined serialization formats (SerDe's) +* Joins + * `JOIN` + * `{LEFT|RIGHT|FULL} OUTER JOIN` + * `LEFT SEMI JOIN` + * `CROSS JOIN` +* Unions +* Subqueries + * `SELECT col FROM ( SELECT a + b AS col from t1) t2` +* Sampling +* Explain +* Partitioned tables +* All Hive DDL Functions, including: + * `CREATE TABLE` + * `CREATE TABLE AS SELECT` + * `ALTER TABLE` +* Most Hive Data types, including: + * `TINYINT` + * `SMALLINT` + * `INT` + * `BIGINT` + * `BOOLEAN` + * `FLOAT` + * `DOUBLE` + * `STRING` + * `BINARY` + * `TIMESTAMP` + * `ARRAY<>` + * `MAP<>` + * `STRUCT<>` + +#### Unsupported Hive Functionality + +Below is a list of Hive features that we don't support yet. Most of these features are rarely used +in Hive deployments. + +**Major Hive Features** + +* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL + doesn't support buckets yet. + +**Esoteric Hive Features** + +* Tables with partitions using different input formats: In Spark SQL, all table partitions need to + have the same input format. +* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions + (e.g. condition "`key < 10`"), Spark SQL will output wrong results for the `NULL` tuple. +* `UNIONTYPE` +* Unique join +* Single query multi insert +* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at + the moment. + +**Hive Input/Output Formats** + +* File format for CLI: For results shown back to the CLI, Spark SQL only supports TextOutputFormat. +* Hadoop archive + +**Hive Optimizations** + +A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are +not necessary due to Spark SQL's in-memory computational model. Others are slotted for future +releases of Spark SQL. + +* Block level bitmap indexes and virtual columns (used to build indexes) +* Automatically convert a join to map join: For joining a large table with multiple small tables, + Hive automatically converts the join into a map join. We are adding this auto conversion in the + next release. +* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you + need to control the degree of parallelism post-shuffle using "SET + spark.sql.shuffle.partitions=[num_tasks];". We are going to add auto-setting of parallelism in the + next release. +* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still + launches tasks to compute the result. +* Skew data flag: Spark SQL does not follow the skew data flags in Hive. +* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint.
+* Merge multiple small files for query results: if the result output contains multiple small files, + Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS + metadata. Spark SQL does not support that. + +## Running the Spark SQL CLI + +The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute +queries input from command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server. + +To start the Spark SQL CLI, run the following in the Spark directory: + + ./bin/spark-sql + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +You may run `./bin/spark-sql --help` for a complete list of all available +options. diff --git a/examples/pom.xml b/examples/pom.xml index bd1c387c2eb91..c4ed0f5a6a02b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-examples_2.10 - examples + examples jar Spark Project Examples diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 61a6aff543aed..874b8a7959bb6 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-flume_2.10 - streaming-flume + streaming-flume jar Spark Project External Flume diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 4762c50685a93..25a5c0a4d7d77 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-kafka_2.10 - streaming-kafka + streaming-kafka jar Spark Project External Kafka diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 32c530e600ce0..f31ed655f6779 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-mqtt_2.10 - streaming-mqtt + streaming-mqtt jar Spark Project External MQTT diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 637adb0f00da0..56bb24c2a072e 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-twitter_2.10 - streaming-twitter + streaming-twitter jar Spark Project External Twitter diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index e4d758a04a4cd..54b0242c54e78 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-zeromq_2.10 - streaming-zeromq + streaming-zeromq jar Spark Project External ZeroMQ diff --git a/graphx/pom.xml b/graphx/pom.xml index 7e3bcf29dcfbc..6dd52fc618b1e 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-graphx_2.10 - graphx + graphx jar Spark Project GraphX diff --git a/mllib/pom.xml b/mllib/pom.xml index 92b07e2357db1..f27cf520dc9fa 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-mllib_2.10 - mllib + mllib jar Spark Project ML Library diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 3f6ff859374c7..da7c633bbd2af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import 
org.apache.spark.mllib.util.TestingUtils._ object LogisticRegressionSuite { @@ -81,9 +82,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD) // Test the weights - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + assert(model.weights(0) ~== -1.52 relTol 0.01) + assert(model.intercept ~== 2.00 relTol 0.01) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -113,9 +113,9 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD, initialWeights) - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + // Test the weights + assert(model.weights(0) ~== -1.50 relTol 0.01) + assert(model.intercept ~== 1.97 relTol 0.01) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 34bc4537a7b3a..afa1f79b95a12 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -21,8 +21,9 @@ import scala.util.Random import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class KMeansSuite extends FunSuite with LocalSparkContext { @@ -41,26 +42,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // centered at the mean of the points var model = KMeans.train(data, k = 1, maxIterations = 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train( data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) } test("no distinct points") { @@ -104,26 +105,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext { var model = KMeans.train(data, k = 1, maxIterations = 1) 
assert(model.clusterCenters.size === 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) } test("single cluster with sparse data") { @@ -149,31 +150,39 @@ class KMeansSuite extends FunSuite with LocalSparkContext { val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0))) var model = KMeans.train(data, k = 1, maxIterations = 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) data.unpersist() } test("k-means|| initialization") { + + case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] { + @Override def compare(that: VectorWithCompare): Int = { + if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > + that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1 + } + } + val points = Seq( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), @@ -188,15 +197,19 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // unselected point as long as it hasn't yet selected all of them var model = KMeans.train(rdd, k = 5, maxIterations = 1) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) // Iterations of Lloyd's should not change the answer either model = KMeans.train(rdd, k = 5, 
maxIterations = 10) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) // Neither should more runs model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) } test("two clusters") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 1c9844f289fe0..994e0feb8629e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -20,27 +20,28 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 - assert(AreaUnderCurve.of(curve) === auc) + assert(AreaUnderCurve.of(curve) ~== auc absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) == auc) + assert(AreaUnderCurve.of(rddCurve) ~== auc absTol 1E-5) } test("auc of an empty curve") { val curve = Seq.empty[(Double, Double)] - assert(AreaUnderCurve.of(curve) === 0.0) + assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) === 0.0) + assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5) } test("auc of a curve with a single point") { val curve = Seq((1.0, 1.0)) - assert(AreaUnderCurve.of(curve) === 0.0) + assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) === 0.0) + assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 94db1dc183230..a733f88b60b80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -20,25 +20,14 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals +import org.apache.spark.mllib.util.TestingUtils._ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { - // TODO: move utility functions to TestingUtils. 
+ def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 - def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = { - actual.zip(expected).forall { case (x1, x2) => - x1.almostEquals(x2) - } - } - - def elementsAlmostEqual( - actual: Seq[(Double, Double)], - expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = { - actual.zip(expected).forall { case ((x1, y1), (x2, y2)) => - x1.almostEquals(x2) && y1.almostEquals(y2) - } - } + def cond2(x: ((Double, Double), (Double, Double))): Boolean = + (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5) test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( @@ -57,16 +46,17 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) val pr = recall.zip(precision) val prCurve = Seq((0.0, 1.0)) ++ pr - val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold)) - assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve)) - assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve))) - assert(elementsAlmostEqual(metrics.pr().collect(), prCurve)) - assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve))) - assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))) - assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))) - assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision))) - assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall))) + + assert(metrics.thresholds().collect().zip(threshold).forall(cond1)) + assert(metrics.roc().collect().zip(rocCurve).forall(cond2)) + assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5) + assert(metrics.pr().collect().zip(prCurve).forall(cond2)) + assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5) + assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2)) + assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2)) + assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2)) + assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index dfb2eb7f0d14e..bf040110e228b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -126,19 +127,14 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD( dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept) - def compareDouble(x: Double, y: Double, tol: 
Double = 1E-3): Boolean = { - math.abs(x - y) / (math.abs(y) + 1e-15) < tol - } - - assert(compareDouble( - loss1(0), - loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + - math.pow(initialWeightsWithIntercept(1), 2)) / 2), + assert( + loss1(0) ~= (loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + + math.pow(initialWeightsWithIntercept(1), 2)) / 2) absTol 1E-5, """For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""") assert( - compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) && - compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)), + (newWeights1(0) ~= (newWeights0(0) - initialWeightsWithIntercept(0)) absTol 1E-5) && + (newWeights1(1) ~= (newWeights0(1) - initialWeightsWithIntercept(1)) absTol 1E-5), "The different between newWeights with/without regularization " + "should be initialWeightsWithIntercept.") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index ff414742e8393..5f4c24115ac80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { @@ -49,10 +50,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { lazy val dataRDD = sc.parallelize(data, 2).cache() - def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { - math.abs(x - y) / (math.abs(y) + 1e-15) < tol - } - test("LBFGS loss should be decreasing and match the result of Gradient Descent.") { val regParam = 0 @@ -126,15 +123,15 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { miniBatchFrac, initialWeightsWithIntercept) - assert(compareDouble(lossGD(0), lossLBFGS(0)), + assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5, "The first losses of LBFGS and GD should be the same.") // The 2% difference here is based on observation, but is not theoretically guaranteed. 
- assert(compareDouble(lossGD.last, lossLBFGS.last, 0.02), + assert(lossGD.last ~= lossLBFGS.last relTol 0.02, "The last losses of LBFGS and GD should be within 2% difference.") - assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && - compareDouble(weightLBFGS(1), weightGD(1), 0.02), + assert( + (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02), "The weight differences between LBFGS and GD should be within 2%.") } @@ -226,8 +223,8 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { initialWeightsWithIntercept) // for class LBFGS and the optimize method, we only look at the weights - assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && - compareDouble(weightLBFGS(1), weightGD(1), 0.02), + assert( + (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02), "The weight differences between LBFGS and GD should be within 2%.") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index bbf385229081a..b781a6aed9a8c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -21,7 +21,9 @@ import scala.util.Random import org.scalatest.FunSuite -import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas} +import org.jblas.{DoubleMatrix, SimpleBlas} + +import org.apache.spark.mllib.util.TestingUtils._ class NNLSSuite extends FunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ @@ -73,7 +75,7 @@ class NNLSSuite extends FunSuite { val ws = NNLS.createWorkspace(n) val x = NNLS.solve(ata, atb, ws) for (i <- 0 until n) { - assert(Math.abs(x(i) - goodx(i)) < 1e-3) + assert(x(i) ~== goodx(i) absTol 1E-3) assert(x(i) >= 0) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 4b7b019d820b4..db13f142df517 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -89,15 +89,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.dense(-1.0, 0.0, 6.0)) .add(Vectors.dense(3.0, -3.0, 0.0)) - assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") assert(summarizer.count === 2) } @@ -107,15 +107,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0)))) 
.add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0)))) - assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") assert(summarizer.count === 2) } @@ -129,17 +129,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.dense(1.7, -0.6, 0.0)) .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) - assert(summarizer.mean.almostEquals( - Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals( - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") assert(summarizer.count === 6) } @@ -157,17 +157,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { val summarizer = summarizer1.merge(summarizer2) - assert(summarizer.mean.almostEquals( - Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals( - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") assert(summarizer.count === 6) } @@ -186,24 +186,24 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { val summarizer3 = (new 
MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer) assert(summarizer3.count === 0) - assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + assert(summarizer1.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") - assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + assert(summarizer2.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") - assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + assert(summarizer1.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") - assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + assert(summarizer2.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") - assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + assert(summarizer1.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") - assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + assert(summarizer2.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") - assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + assert(summarizer1.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + assert(summarizer2.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + assert(summarizer1.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") - assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 64b1ba7527183..29cc42d8cbea7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -18,28 +18,155 @@ package org.apache.spark.mllib.util import org.apache.spark.mllib.linalg.Vector +import org.scalatest.exceptions.TestFailedException object TestingUtils { + val ABS_TOL_MSG = " using absolute tolerance" + val REL_TOL_MSG = " using relative tolerance" + + /** + * Private helper function for comparing two values using relative tolerance. + * Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue, + * the relative tolerance is meaningless, so the exception will be raised to warn users. + */ + private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + val absX = math.abs(x) + val absY = math.abs(y) + val diff = math.abs(x - y) + if (x == y) { + true + } else if (absX < Double.MinPositiveValue || absY < Double.MinPositiveValue) { + throw new TestFailedException( + s"$x or $y is extremely close to zero, so the relative tolerance is meaningless.", 0) + } else { + diff < eps * math.min(absX, absY) + } + } + + /** + * Private helper function for comparing two values using absolute tolerance. 
+ */ + private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + math.abs(x - y) < eps + } + + case class CompareDoubleRightSide( + fun: (Double, Double, Double) => Boolean, y: Double, eps: Double, method: String) + + /** + * Implicit class for comparing two double values using relative tolerance or absolute tolerance. + */ implicit class DoubleWithAlmostEquals(val x: Double) { - // An improved version of AlmostEquals would always divide by the larger number. - // This will avoid the problem of diving by zero. - def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = { - if(x == y) { - true - } else if(math.abs(x) > math.abs(y)) { - math.abs(x - y) / math.abs(x) < epsilon - } else { - math.abs(x - y) / math.abs(y) < epsilon + + /** + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~=(r: CompareDoubleRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two values are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareDoubleRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two values are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareDoubleRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected $x and ${r.y} to be within ${r.eps}${r.method}.", 0) } + true } + + /** + * Throws exception when the difference of two values are within eps; otherwise, returns true. + */ + def !~==(r: CompareDoubleRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method}.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison, + x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. + */ + def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison, + x, eps, REL_TOL_MSG) + + override def toString = x.toString } + case class CompareVectorRightSide( + fun: (Vector, Vector, Double) => Boolean, y: Vector, eps: Double, method: String) + + /** + * Implicit class for comparing two vectors using relative tolerance or absolute tolerance. + */ implicit class VectorWithAlmostEquals(val x: Vector) { - def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = { - x.toArray.corresponds(y.toArray) { - _.almostEquals(_, epsilon) + + /** + * When the difference of two vectors are within eps, returns true; otherwise, returns false. + */ + def ~=(r: CompareVectorRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two vectors are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareVectorRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two vectors are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareVectorRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) } + true } + + /** + * Throws exception when the difference of two vectors are within eps; otherwise, returns true. 
+ */ + def !~==(r: CompareVectorRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( + (x: Vector, y: Vector, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps) + }, x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. Note that comparing against sparse vector + * with elements having value of zero will raise exception because it involves with + * comparing against zero. + */ + def relTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( + (x: Vector, y: Vector, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) + }, x, eps, REL_TOL_MSG) + + override def toString = x.toString } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala new file mode 100644 index 0000000000000..b0ecb33c28483 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.apache.spark.mllib.linalg.Vectors +import org.scalatest.FunSuite +import org.apache.spark.mllib.util.TestingUtils._ +import org.scalatest.exceptions.TestFailedException + +class TestingUtilsSuite extends FunSuite { + + test("Comparing doubles using relative error.") { + + assert(23.1 ~== 23.52 relTol 0.02) + assert(23.1 ~== 22.74 relTol 0.02) + assert(23.1 ~= 23.52 relTol 0.02) + assert(23.1 ~= 22.74 relTol 0.02) + assert(!(23.1 !~= 23.52 relTol 0.02)) + assert(!(23.1 !~= 22.74 relTol 0.02)) + + // Should throw exception with message when test fails. + intercept[TestFailedException](23.1 !~== 23.52 relTol 0.02) + intercept[TestFailedException](23.1 !~== 22.74 relTol 0.02) + intercept[TestFailedException](23.1 ~== 23.63 relTol 0.02) + intercept[TestFailedException](23.1 ~== 22.34 relTol 0.02) + + assert(23.1 !~== 23.63 relTol 0.02) + assert(23.1 !~== 22.34 relTol 0.02) + assert(23.1 !~= 23.63 relTol 0.02) + assert(23.1 !~= 22.34 relTol 0.02) + assert(!(23.1 ~= 23.63 relTol 0.02)) + assert(!(23.1 ~= 22.34 relTol 0.02)) + + // Comparing against zero should fail the test and throw exception with message + // saying that the relative error is meaningless in this situation. 
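+ // (Clarifying note added here, not part of the original patch: RelativeErrorComparison throws whenever either operand is smaller than Double.MinPositiveValue in magnitude, so comparing 0.0 against any non-zero value raises TestFailedException regardless of the tolerance chosen.)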
+ intercept[TestFailedException](0.1 ~== 0.0 relTol 0.032) + intercept[TestFailedException](0.1 ~= 0.0 relTol 0.032) + intercept[TestFailedException](0.1 !~== 0.0 relTol 0.032) + intercept[TestFailedException](0.1 !~= 0.0 relTol 0.032) + intercept[TestFailedException](0.0 ~== 0.1 relTol 0.032) + intercept[TestFailedException](0.0 ~= 0.1 relTol 0.032) + intercept[TestFailedException](0.0 !~== 0.1 relTol 0.032) + intercept[TestFailedException](0.0 !~= 0.1 relTol 0.032) + + // Comparisons of numbers very close to zero. + assert(10 * Double.MinPositiveValue ~== 9.5 * Double.MinPositiveValue relTol 0.01) + assert(10 * Double.MinPositiveValue !~== 11 * Double.MinPositiveValue relTol 0.01) + + assert(-Double.MinPositiveValue ~== 1.18 * -Double.MinPositiveValue relTol 0.012) + assert(-Double.MinPositiveValue ~== 1.38 * -Double.MinPositiveValue relTol 0.012) + } + + test("Comparing doubles using absolute error.") { + + assert(17.8 ~== 17.99 absTol 0.2) + assert(17.8 ~== 17.61 absTol 0.2) + assert(17.8 ~= 17.99 absTol 0.2) + assert(17.8 ~= 17.61 absTol 0.2) + assert(!(17.8 !~= 17.99 absTol 0.2)) + assert(!(17.8 !~= 17.61 absTol 0.2)) + + // Should throw exception with message when test fails. + intercept[TestFailedException](17.8 !~== 17.99 absTol 0.2) + intercept[TestFailedException](17.8 !~== 17.61 absTol 0.2) + intercept[TestFailedException](17.8 ~== 18.01 absTol 0.2) + intercept[TestFailedException](17.8 ~== 17.59 absTol 0.2) + + assert(17.8 !~== 18.01 absTol 0.2) + assert(17.8 !~== 17.59 absTol 0.2) + assert(17.8 !~= 18.01 absTol 0.2) + assert(17.8 !~= 17.59 absTol 0.2) + assert(!(17.8 ~= 18.01 absTol 0.2)) + assert(!(17.8 ~= 17.59 absTol 0.2)) + + // Comparisons of numbers very close to zero, and both side of zeros + assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + + assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + } + + test("Comparing vectors using relative error.") { + + //Comparisons of two dense vectors + assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + assert(!(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01)) + assert(!(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01)) + + // Should throw exception with message when test fails. + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + + // Comparing against zero should fail the test and throw exception with message + // saying that the relative error is meaningless in this situation. 
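+ // (Illustrative note, not part of the original patch: the sparse vector in the second case expands to [3.13, 0.0] via toArray, so its implicit zero element triggers the same near-zero exception as the dense case.)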
+ intercept[TestFailedException]( + Vectors.dense(Array(3.1, 0.01)) ~== Vectors.dense(Array(3.13, 0.0)) relTol 0.01) + + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 0.01)) ~== Vectors.sparse(2, Array(0), Array(3.13)) relTol 0.01) + + // Comparisons of two sparse vectors + assert(Vectors.dense(Array(3.1, 3.5)) ~== + Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01) + + assert(Vectors.dense(Array(3.1, 3.5)) !~== + Vectors.sparse(2, Array(0, 1), Array(3.135, 3.534)) relTol 0.01) + } + + test("Comparing vectors using absolute error.") { + + //Comparisons of two dense vectors + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~== + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~== + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~= + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~= + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) !~= + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)) + + assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) ~= + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)) + + // Should throw exception with message when test fails. + intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) !~== + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) ~== + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + // Comparisons of two sparse vectors + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== + Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) !~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + // Comparisons of a dense vector and a sparse vector + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~== + Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== + Vectors.dense(Array(3.1, 1E-3, 2.4)) absTol 1E-6) + } +} diff --git a/pom.xml b/pom.xml index d2e6b3c0ed5a4..93ef3b91b5bce 100644 --- a/pom.xml +++ b/pom.xml @@ -252,9 +252,9 @@ 3.3.2 - commons-codec - commons-codec - 1.5 + commons-codec + commons-codec + 1.5 com.google.code.findbugs @@ -1139,5 +1139,15 @@ + + hive-thriftserver + + false + + + sql/hive-thriftserver + + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 62576f84dd031..1629bc2cba8ba 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -30,11 +30,11 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming, - streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = - Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql", - 
"streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq").map(ProjectRef(buildLocation, _)) + val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, sql, + streaming, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = + Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", + "spark", "sql", "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", + "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl") @@ -100,7 +100,7 @@ object SparkBuild extends PomBuild { Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1))) - case _ => + case _ => } override val userPropertiesMap = System.getProperties.toMap @@ -158,7 +158,7 @@ object SparkBuild extends PomBuild { /* Enable Mima for all projects except spark, hive, catalyst, sql and repl */ // TODO: Add Sql to mima checks - allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)). + allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl).contains(x)). foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) /* Enable Assembly for all assembly projects */ diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh new file mode 100755 index 0000000000000..8398e6f19b511 --- /dev/null +++ b/sbin/start-thriftserver.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# +# Shell script for starting the Spark SQL Thrift server + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-thriftserver [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6decde3fcd62d..531bfddbf237b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -32,7 +32,7 @@ Spark Project Catalyst http://spark.apache.org/ - catalyst + catalyst diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 1d5f033f0d274..a357c6ffb8977 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -43,8 +43,7 @@ case class NativeCommand(cmd: String) extends Command { */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - BoundReference(0, AttributeReference("key", StringType, nullable = false)()), - BoundReference(1, AttributeReference("value", StringType, nullable = false)())) + BoundReference(1, AttributeReference("", StringType, nullable = false)())) } /** diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c309c43804d97..3a038a2db6173 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -32,7 +32,7 @@ Spark Project SQL http://spark.apache.org/ - sql + sql diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2b787e14f3f15..41920c00b5a2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -30,12 +30,13 @@ import scala.collection.JavaConverters._ * SQLConf is thread-safe (internally synchronized so safe to be used in multiple threads). */ trait SQLConf { + import SQLConf._ /** ************************ Spark SQL Params/Hints ******************* */ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? /** Number of partitions to use for shuffle operators. */ - private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt + private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -43,11 +44,10 @@ trait SQLConf { * effectively disables auto conversion. * Hive setting: hive.auto.convert.join.noconditionaltask.size. */ - private[spark] def autoConvertJoinSize: Int = - get("spark.sql.auto.convert.join.size", "10000").toInt + private[spark] def autoConvertJoinSize: Int = get(AUTO_CONVERT_JOIN_SIZE, "10000").toInt /** A comma-separated list of table names marked to be broadcasted during joins. 
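 * (Hypothetical usage, not part of the original patch: sqlContext.set(SQLConf.JOIN_BROADCAST_TABLES, "table1,table2") would mark those tables for broadcast during joins.)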
*/ - private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "") + private[spark] def joinBroadcastTables: String = get(JOIN_BROADCAST_TABLES, "") /** ********************** SQLConf functionality methods ************ */ @@ -61,7 +61,7 @@ trait SQLConf { def set(key: String, value: String): Unit = { require(key != null, "key cannot be null") - require(value != null, s"value cannot be null for ${key}") + require(value != null, s"value cannot be null for $key") settings.put(key, value) } @@ -90,3 +90,13 @@ trait SQLConf { } } + +object SQLConf { + val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" + val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" + val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" + + object Deprecated { + val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 98d2f89c8ae71..9293239131d52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SQLConf, SQLContext} trait Command { /** @@ -44,28 +45,53 @@ trait Command { case class SetCommand( key: Option[String], value: Option[String], output: Seq[Attribute])( @transient context: SQLContext) - extends LeafNode with Command { + extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match { + override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => - context.set(k, v) - Array(k -> v) + if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") + context.set(SQLConf.SHUFFLE_PARTITIONS, v) + Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") + } else { + context.set(k, v) + Array(s"$k=$v") + } // Query the value bound to key k. case (Some(k), _) => - Array(k -> context.getOption(k).getOrElse("")) + // TODO (lian) This is just a workaround to make the Simba ODBC driver work. + // Should remove this once we get the ODBC driver updated. + if (k == "-v") { + val hiveJars = Seq( + "hive-exec-0.12.0.jar", + "hive-service-0.12.0.jar", + "hive-common-0.12.0.jar", + "hive-hwi-0.12.0.jar", + "hive-0.12.0.jar").mkString(":") + + Array( + "system:java.class.path=" + hiveJars, + "system:sun.java.command=shark.SharkServer2") + } + else { + Array(s"$k=${context.getOption(k).getOrElse("")}") + } // Query all key-value pairs that are set in the SQLConf of the context. 
case (None, None) => - context.getAll + context.getAll.map { case (k, v) => + s"$k=$v" + } case _ => throw new IllegalArgumentException() } def execute(): RDD[Row] = { - val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) } + val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) } context.sparkContext.parallelize(rows, 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 08293f7f0ca30..1a58d73d9e7f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -54,10 +54,10 @@ class SQLConfSuite extends QueryTest { assert(get(testKey, testVal + "_") == testVal) assert(TestSQLContext.get(testKey, testVal + "_") == testVal) - sql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - sql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + sql("set some.property=20") + assert(get("some.property", "0") == "20") + sql("set some.property = 40") + assert(get("some.property", "0") == "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" @@ -70,4 +70,9 @@ class SQLConfSuite extends QueryTest { clear() } + test("deprecated property") { + clear() + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6736189c96d4b..de9e8aa4f62ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -424,25 +424,25 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(testKey, testVal), - Seq(testKey + testKey, testVal + testVal)) + Seq(s"$testKey=$testVal"), + Seq(s"${testKey + testKey}=${testVal + testVal}")) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(nonexistentKey, "")) + Seq(Seq(s"$nonexistentKey=")) ) clear() } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml new file mode 100644 index 0000000000000..7fac90fdc596d --- /dev/null +++ b/sql/hive-thriftserver/pom.xml @@ -0,0 +1,82 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-hive-thriftserver_2.10 + jar + Spark Project Hive + http://spark.apache.org/ + + hive-thriftserver + + + + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + + + org.spark-project.hive + hive-cli + ${hive.version} + + + org.spark-project.hive + hive-jdbc + ${hive.version} + + + org.spark-project.hive + hive-beeline + ${hive.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala new file mode 100644 index 0000000000000..ddbc2a79fb512 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService +import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +/** + * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a + * `HiveThriftServer2` thrift server. + */ +private[hive] object HiveThriftServer2 extends Logging { + var LOG = LogFactory.getLog(classOf[HiveServer2]) + + def main(args: Array[String]) { + val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + + if (!optionsProcessor.process(args)) { + logger.warn("Error starting HiveThriftServer2 with given arguments") + System.exit(-1) + } + + val ss = new SessionState(new HiveConf(classOf[SessionState])) + + // Set all properties specified via command line. 
+ val hiveConf: HiveConf = ss.getConf + hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => + logger.debug(s"HiveConf var: $k=$v") + } + + SessionState.start(ss) + + logger.info("Starting SparkContext") + SparkSQLEnv.init() + SessionState.start(ss) + + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.sparkContext.stop() + } + } + ) + + try { + val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) + server.init(hiveConf) + server.start() + logger.info("HiveThriftServer2 started") + } catch { + case e: Exception => + logger.error("Error starting HiveThriftServer2", e) + System.exit(-1) + } + } +} + +private[hive] class HiveThriftServer2(hiveContext: HiveContext) + extends HiveServer2 + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + setSuperField(this, "cliService", sparkSqlCliService) + addService(sparkSqlCliService) + + val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService) + setSuperField(this, "thriftCLIService", thriftCliService) + addService(thriftCliService) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala new file mode 100644 index 0000000000000..599294dfbb7d7 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +private[hive] object ReflectionUtils { + def setSuperField(obj : Object, fieldName: String, fieldValue: Object) { + setAncestorField(obj, 1, fieldName, fieldValue) + } + + def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) { + val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.set(obj, fieldValue) + } + + def getSuperField[T](obj: AnyRef, fieldName: String): T = { + getAncestorField[T](obj, 1, fieldName) + } + + def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = { + val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(clazz).asInstanceOf[T] + } + + def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = { + invoke(clazz, null, methodName, args: _*) + } + + def invoke( + clazz: Class[_], + obj: AnyRef, + methodName: String, + args: (Class[_], AnyRef)*): AnyRef = { + + val (types, values) = args.unzip + val method = clazz.getDeclaredMethod(methodName, types: _*) + method.setAccessible(true) + method.invoke(obj, values.toSeq: _*) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala new file mode 100755 index 0000000000000..27268ecb923e9 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io._ +import java.util.{ArrayList => JArrayList} + +import jline.{ConsoleReader, History} +import org.apache.commons.lang.StringUtils +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.thrift.transport.TSocket + +import org.apache.spark.sql.Logging + +private[hive] object SparkSQLCLIDriver { + private var prompt = "spark-sql" + private var continuedPrompt = "".padTo(prompt.length, ' ') + private var transport:TSocket = _ + + installSignalHandler() + + /** + * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(), + * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while + * a command is being processed by the current thread. + */ + def installSignalHandler() { + HiveInterruptUtils.add(new HiveInterruptCallback { + override def interrupt() { + // Handle remote execution mode + if (SparkSQLEnv.sparkContext != null) { + SparkSQLEnv.sparkContext.cancelAllJobs() + } else { + if (transport != null) { + // Force closing of TCP connection upon session termination + transport.getSocket.close() + } + } + } + }) + } + + def main(args: Array[String]) { + val oproc = new OptionsProcessor() + if (!oproc.process_stage1(args)) { + System.exit(1) + } + + // NOTE: It is critical to do this here so that log4j is reinitialized + // before any of the other core hive classes are loaded + var logInitFailed = false + var logInitDetailMessage: String = null + try { + logInitDetailMessage = LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => + logInitFailed = true + logInitDetailMessage = e.getMessage + } + + val sessionState = new CliSessionState(new HiveConf(classOf[SessionState])) + + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + if (!oproc.process_stage2(sessionState)) { + System.exit(2) + } + + if (!sessionState.getIsSilent) { + if (logInitFailed) System.err.println(logInitDetailMessage) + else SessionState.getConsole.printInfo(logInitDetailMessage) + } + + // Set all properties specified via command line. 
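+ // (Note added for clarity, not part of the original patch: cmdProperties holds the key=value pairs supplied on the command line, e.g. via --hiveconf, which are copied into the HiveConf and recorded as overridden configurations below.)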
+ val conf: HiveConf = sessionState.getConf + sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] => + conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + sessionState.getOverriddenConfigurations.put( + item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + } + + SessionState.start(sessionState) + + // Clean up after we exit + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.stop() + } + } + ) + + // "-h" option has been passed, so connect to Hive thrift server. + if (sessionState.getHost != null) { + sessionState.connect() + if (sessionState.isRemoteMode) { + prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt + continuedPrompt = "".padTo(prompt.length, ' ') + } + } + + if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) { + // Hadoop-20 and above - we need to augment classpath using hiveconf + // components. + // See also: code in ExecDriver.java + var loader = conf.getClassLoader + val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS) + if (StringUtils.isNotBlank(auxJars)) { + loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ",")) + } + conf.setClassLoader(loader) + Thread.currentThread().setContextClassLoader(loader) + } + + val cli = new SparkSQLCLIDriver + cli.setHiveVariables(oproc.getHiveVariables) + + // TODO work around for set the log output to console, because the HiveContext + // will set the output into an invalid buffer. + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + // Execute -i init files (always in silent mode) + cli.processInitFiles(sessionState) + + if (sessionState.execString != null) { + System.exit(cli.processLine(sessionState.execString)) + } + + try { + if (sessionState.fileName != null) { + System.exit(cli.processFile(sessionState.fileName)) + } + } catch { + case e: FileNotFoundException => + System.err.println(s"Could not open input file for reading. (${e.getMessage})") + System.exit(3) + } + + val reader = new ConsoleReader() + reader.setBellEnabled(false) + // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) + CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + + val historyDirectory = System.getProperty("user.home") + + try { + if (new File(historyDirectory).exists()) { + val historyFile = historyDirectory + File.separator + ".hivehistory" + reader.setHistory(new History(new File(historyFile))) + } else { + System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + " does not exist. History will not be available during this session.") + } + } catch { + case e: Exception => + System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + "history file. 
History will not be available during this session.") + System.err.println(e.getMessage) + } + + val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") + clientTransportTSocketField.setAccessible(true) + + transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] + + var ret = 0 + var prefix = "" + val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", + classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) + + def promptWithCurrentDB = s"$prompt$currentDB" + def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) + + var currentPrompt = promptWithCurrentDB + var line = reader.readLine(currentPrompt + "> ") + + while (line != null) { + if (prefix.nonEmpty) { + prefix += '\n' + } + + if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { + line = prefix + line + ret = cli.processLine(line, true) + prefix = "" + currentPrompt = promptWithCurrentDB + } else { + prefix = prefix + line + currentPrompt = continuedPromptWithDBSpaces + } + + line = reader.readLine(currentPrompt + "> ") + } + + sessionState.close() + + System.exit(ret) + } +} + +private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { + private val sessionState = SessionState.get().asInstanceOf[CliSessionState] + + private val LOG = LogFactory.getLog("CliDriver") + + private val console = new SessionState.LogHelper(LOG) + + private val conf: Configuration = + if (sessionState != null) sessionState.getConf else new Configuration() + + // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver + // because the Hive unit tests do not go through the main() code path. + if (!sessionState.isRemoteMode) { + SparkSQLEnv.init() + } + + override def processCmd(cmd: String): Int = { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + if (cmd_trimmed.toLowerCase.equals("quit") || + cmd_trimmed.toLowerCase.equals("exit") || + tokens(0).equalsIgnoreCase("source") || + cmd_trimmed.startsWith("!") || + tokens(0).toLowerCase.equals("list") || + sessionState.isRemoteMode) { + val start = System.currentTimeMillis() + super.processCmd(cmd) + val end = System.currentTimeMillis() + val timeTaken: Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds") + 0 + } else { + var ret = 0 + val hconf = conf.asInstanceOf[HiveConf] + val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) + + if (proc != null) { + if (proc.isInstanceOf[Driver]) { + val driver = new SparkSQLDriver + + driver.init() + val out = sessionState.out + val start:Long = System.currentTimeMillis() + if (sessionState.getIsVerbose) { + out.println(cmd) + } + + ret = driver.run(cmd).getResponseCode + if (ret != 0) { + driver.close() + return ret + } + + val res = new JArrayList[String]() + + if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { + // Print the column names. 
+ Option(driver.getSchema.getFieldSchemas).map { fields => + out.println(fields.map(_.getName).mkString("\t")) + } + } + + try { + while (!out.checkError() && driver.getResults(res)) { + res.foreach(out.println) + res.clear() + } + } catch { + case e:IOException => + console.printError( + s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} + |${org.apache.hadoop.util.StringUtils.stringifyException(e)} + """.stripMargin) + ret = 1 + } + + val cret = driver.close() + if (ret == 0) { + ret = cret + } + + val end = System.currentTimeMillis() + if (end > start) { + val timeTaken:Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds", null) + } + + // Destroy the driver to release all the locks. + driver.destroy() + } else { + if (sessionState.getIsVerbose) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } + ret = proc.run(cmd_1).getResponseCode + } + } + ret + } + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala new file mode 100644 index 0000000000000..42cbf363b274f --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io.IOException +import java.util.{List => JList} +import javax.security.auth.login.LoginException + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hive.service.Service.STATE +import org.apache.hive.service.auth.HiveAuthFactory +import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.{AbstractService, Service, ServiceException} + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +private[hive] class SparkSQLCLIService(hiveContext: HiveContext) + extends CLIService + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + setSuperField(this, "sessionManager", sparkSqlSessionManager) + addService(sparkSqlSessionManager) + + try { + HiveAuthFactory.loginFromKeytab(hiveConf) + val serverUserName = ShimLoader.getHadoopShims + .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf)) + setSuperField(this, "serverUserName", serverUserName) + } catch { + case e @ (_: IOException | _: LoginException) => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + } + + initCompositeService(hiveConf) + } +} + +private[thriftserver] trait ReflectedCompositeService { this: AbstractService => + def initCompositeService(hiveConf: HiveConf) { + // Emulating `CompositeService.init(hiveConf)` + val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") + serviceList.foreach(_.init(hiveConf)) + + // Emulating `AbstractService.init(hiveConf)` + invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) + setAncestorField(this, 3, "hiveConf", hiveConf) + invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED) + getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala new file mode 100644 index 0000000000000..5202aa9903e03 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.util.{ArrayList => JArrayList} + +import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} + +private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver with Logging { + + private var tableSchema: Schema = _ + private var hiveResponse: Seq[String] = _ + + override def init(): Unit = { + } + + private def getResultSetSchema(query: context.QueryExecution): Schema = { + val analyzed = query.analyzed + logger.debug(s"Result Schema: ${analyzed.output}") + if (analyzed.output.size == 0) { + new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + } else { + val fieldSchemas = analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + + new Schema(fieldSchemas, null) + } + } + + override def run(command: String): CommandProcessorResponse = { + val execution = context.executePlan(context.hql(command).logicalPlan) + + // TODO unify the error code + try { + hiveResponse = execution.stringResult() + tableSchema = getResultSetSchema(execution) + new CommandProcessorResponse(0) + } catch { + case cause: Throwable => + logger.error(s"Failed in [$command]", cause) + new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) + } + } + + override def close(): Int = { + hiveResponse = null + tableSchema = null + 0 + } + + override def getSchema: Schema = tableSchema + + override def getResults(res: JArrayList[String]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.addAll(hiveResponse) + hiveResponse = null + true + } + } + + override def destroy() { + super.destroy() + hiveResponse = null + tableSchema = null + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala new file mode 100644 index 0000000000000..451c3bd7b9352 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hadoop.hive.ql.session.SessionState + +import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.{SparkConf, SparkContext} + +/** A singleton object for the master program. The slaves should not access this. */ +private[hive] object SparkSQLEnv extends Logging { + logger.debug("Initializing SparkSQLEnv") + + var hiveContext: HiveContext = _ + var sparkContext: SparkContext = _ + + def init() { + if (hiveContext == null) { + sparkContext = new SparkContext(new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + + sparkContext.addSparkListener(new StatsReportListener()) + + hiveContext = new HiveContext(sparkContext) { + @transient override lazy val sessionState = SessionState.get() + @transient override lazy val hiveconf = sessionState.getConf + } + } + } + + /** Cleans up and shuts down the Spark SQL environments. */ + def stop() { + logger.debug("Shutting down Spark SQL Environment") + // Stop the SparkContext + if (SparkSQLEnv.sparkContext != null) { + sparkContext.stop() + sparkContext = null + hiveContext = null + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000000..6b3275b4eaf04 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.session.SessionManager + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala new file mode 100644 index 0000000000000..a4e1f3e762e89 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.server + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.math.{random, round} + +import java.sql.Timestamp +import java.util.{Map => JMap} + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} + +/** + * Executes queries using Spark SQL, and maintains a list of handles to active queries. 
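+ * (Added note, not part of the original patch: this manager is registered by SparkSQLSessionManager above as the session manager's "operationManager", so each incoming statement is handled by the ExecuteStatementOperation created below.)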
+ */ +class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging { + val handleToOperation = ReflectionUtils + .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + + override def newExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + async: Boolean): ExecuteStatementOperation = synchronized { + + val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) { + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + logger.debug("CLOSING") + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + if (!iter.hasNext) { + new RowSet() + } else { + val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No. + var curRow = 0 + var rowSet = new ArrayBuffer[Row](maxRows) + + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = new Row() + var curCol = 0 + + while (curCol < sparkRow.length) { + dataTypes(curCol) match { + case StringType => + row.addString(sparkRow(curCol).asInstanceOf[String]) + case IntegerType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) + case BooleanType => + row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) + case DoubleType => + row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) + case FloatType => + row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) + case DecimalType => + val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal + row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) + case ByteType => + row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) + case ShortType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) + case TimestampType => + row.addColumnValue( + ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) + row.addColumnValue(ColumnValue.stringValue(hiveString)) + } + curCol += 1 + } + rowSet += row + curRow += 1 + } + new RowSet(rowSet, 0) + } + } + + def getResultSetSchema: TableSchema = { + logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + def run(): Unit = { + logger.info(s"Running query '$statement'") + setState(OperationState.RUNNING) + try { + result = hiveContext.hql(statement) + logger.debug(result.queryExecution.toString()) + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + iter = result.queryExecution.toRdd.toLocalIterator + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) + } catch { + // Actually do need to catch Throwable as some 
failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logger.error("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + } + } + + handleToOperation.put(operation.getHandle, operation) + operation + } +} diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt new file mode 100644 index 0000000000000..850f8014b6f05 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt @@ -0,0 +1,5 @@ +238val_238 +86val_86 +311val_311 +27val_27 +165val_165 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala new file mode 100644 index 0000000000000..69f19f826a802 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, InputStreamReader, PrintWriter} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { + val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli") + val METASTORE_PATH = TestUtils.getMetastorePath("cli") + + override def beforeAll() { + val pb = new ProcessBuilder( + "../../bin/spark-sql", + "--master", + "local", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) + + process = pb.start() + outputWriter = new PrintWriter(process.getOutputStream, true) + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "spark-sql>") + } + + override def afterAll() { + process.destroy() + process.waitFor() + } + + test("simple commands") { + val dataFilePath = getDataFile("data/files/small_kv.txt") + executeQuery("create table hive_test1(key int, val string);") + executeQuery("load data local inpath '" + dataFilePath + "' overwrite into table hive_test1;") + executeQuery("cache table hive_test1", "Time taken") + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala new file mode 100644 index 0000000000000..fe3403b3292ec --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent._ + +import java.io.{BufferedReader, InputStreamReader} +import java.net.ServerSocket +import java.sql.{Connection, DriverManager, Statement} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.util.getTempFilePath + +/** + * Test for the HiveThriftServer2 using JDBC. + */ +class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging { + + val WAREHOUSE_PATH = getTempFilePath("warehouse") + val METASTORE_PATH = getTempFilePath("metastore") + + val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver" + val TABLE = "test" + val HOST = "localhost" + val PORT = { + // Let the system choose a random available port to avoid collision with other parallel + // builds.
+ val socket = new ServerSocket(0) + val port = socket.getLocalPort + socket.close() + port + } + + // If verbose is true, the test program will print all outputs coming from the Hive Thrift server. + val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean + + Class.forName(DRIVER_NAME) + + override def beforeAll() { launchServer() } + + override def afterAll() { stopServer() } + + private def launchServer(args: Seq[String] = Seq.empty) { + // Forking a new process to start the Hive Thrift server. The reason to do this is it is + // hard to clean up Hive resources entirely, so we just start a new process and kill + // that process for cleanup. + val defaultArgs = Seq( + "../../sbin/start-thriftserver.sh", + "--master local", + "--hiveconf", + "hive.root.logger=INFO,console", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH") + val pb = new ProcessBuilder(defaultArgs ++ args) + val environment = pb.environment() + environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString) + environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST) + process = pb.start() + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "ThriftBinaryCLIService listening on") + + // Spawn a thread to read the output from the forked process. + // Note that this is necessary since in some configurations, log4j could be blocked + // if its output to stderr are not read, and eventually blocking the entire test suite. + future { + while (true) { + val stdout = readFrom(inputReader) + val stderr = readFrom(errorReader) + if (VERBOSE && stdout.length > 0) { + println(stdout) + } + if (VERBOSE && stderr.length > 0) { + println(stderr) + } + Thread.sleep(50) + } + } + } + + private def stopServer() { + process.destroy() + process.waitFor() + } + + test("test query execution against a Hive Thrift server") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test") + stmt.execute("DROP TABLE IF EXISTS test_cached") + stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") + stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CACHE TABLE test_cached") + + var rs = stmt.executeQuery("select count(*) from test") + rs.next() + assert(rs.getInt(1) === 5) + + rs = stmt.executeQuery("select count(*) from test_cached") + rs.next() + assert(rs.getInt(1) === 4) + + stmt.close() + } + + def getConnection: Connection = { + val connectURI = s"jdbc:hive2://localhost:$PORT/" + DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") + } + + def createStatement(): Statement = getConnection.createStatement() +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala new file mode 100644 index 0000000000000..bb2242618fbef --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, PrintWriter} +import java.text.SimpleDateFormat +import java.util.Date + +import org.apache.hadoop.hive.common.LogUtils +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException + +object TestUtils { + val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss") + + def getWarehousePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" + + timestamp.format(new Date) + } + + def getMetastorePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" + + timestamp.format(new Date) + } + + // No-op method whose only purpose is to force this object's initializer (below) to run. + def init() { } + + // Initialize Hive's log4j configuration when this object is first loaded. + try { + LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => // Ignore the error. + } +} + +trait TestUtils { + var process: Process = null + var outputWriter: PrintWriter = null + var inputReader: BufferedReader = null + var errorReader: BufferedReader = null + + def executeQuery( + cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = { + println("Executing: " + cmd + ", expecting output: " + outputMessage) + outputWriter.write(cmd + "\n") + outputWriter.flush() + waitForQuery(timeout, outputMessage) + } + + protected def waitForQuery(timeout: Long, message: String): String = { + if (waitForOutput(errorReader, message, timeout)) { + Thread.sleep(500) + readOutput() + } else { + assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput()) + null + } + } + + // Wait for the specified string to appear in the given reader's output, polling until the timeout expires. + protected def waitForOutput( + reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = { + val startTime = System.currentTimeMillis + var out = "" + while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) { + out += readFrom(reader) + } + out.contains(str) + } + + // Read stdout output and filter out garbage collection messages.
+ protected def readOutput(): String = { + val output = readFrom(inputReader) + // Remove GC Messages + val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC")) + .mkString("\n") + filteredOutput + } + + protected def readFrom(reader: BufferedReader): String = { + var out = "" + var c = 0 + while (reader.ready) { + c = reader.read() + out += c.asInstanceOf[Char] + } + out + } + + protected def getDataFile(name: String) = { + Thread.currentThread().getContextClassLoader.getResource(name) + } +} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 1699ffe06ce15..93d00f7c37c9b 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -32,7 +32,7 @@ Spark Project Hive http://spark.apache.org/ - hive + hive diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 201c85f3d501e..84d43eaeea51d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -255,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, ShortType, DecimalType, TimestampType, BinaryType) - protected def toHiveString(a: (Any, DataType)): String = a match { + protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e6ab68b563f8d..d18ccf8167487 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -610,7 +610,7 @@ private[hive] object HiveQl { // TOK_DESTINATION means to overwrite the table. val resultDestination = (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = if (intoClause.isEmpty) true else false + val overwrite = intoClause.isEmpty nodeToDest( resultDestination, withLimit, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index c3942578d6b5a..82c88280d7754 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -24,6 +24,8 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector + import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} @@ -31,13 +33,16 @@ import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast} +import org.apache.spark.sql.catalyst.types.DataType + /** * A trait for subclasses that handle table scans. 
*/ private[hive] sealed trait TableReader { - def makeRDDForTable(hiveTable: HiveTable): RDD[_] + def makeRDDForTable(hiveTable: HiveTable): RDD[Row] - def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] + def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] } @@ -46,7 +51,10 @@ private[hive] sealed trait TableReader { * data warehouse directory. */ private[hive] -class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext) +class HadoopTableReader( + @transient attributes: Seq[Attribute], + @transient relation: MetastoreRelation, + @transient sc: HiveContext) extends TableReader { // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless @@ -63,10 +71,10 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon def hiveConf = _broadcastedHiveConf.value.value - override def makeRDDForTable(hiveTable: HiveTable): RDD[_] = + override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = makeRDDForTable( hiveTable, - _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -81,14 +89,14 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon def makeRDDForTable( hiveTable: HiveTable, deserializerClass: Class[_ <: Deserializer], - filterOpt: Option[PathFilter]): RDD[_] = { + filterOpt: Option[PathFilter]): RDD[Row] = { assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. - val tableDesc = _tableDesc + val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val tablePath = hiveTable.getPath @@ -99,23 +107,20 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + val attrsWithIndex = attributes.zipWithIndex + val mutableRow = new GenericMutableRow(attrsWithIndex.length) val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - // Deserialize each Writable to get the row value. - iter.map { - case v: Writable => deserializer.deserialize(v) - case value => - sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}") - } + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) } deserializedHadoopRDD } - override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = { + override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = { val partitionToDeserializer = partitions.map(part => (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) @@ -132,9 +137,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon * subdirectory of each partition being read. If None, then all files are accepted. 
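For context on the `_broadcastedHiveConf.value.value` calls above: a Hadoop Configuration is not Java-serializable, so the reader wraps it in SerializableWritable and broadcasts it once, and each task then unwraps a local copy inside mapPartitions instead of capturing the enclosing reader object in the closure. The construction of that broadcast is not shown in this hunk; a rough sketch of the pattern, with illustrative names:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.{SerializableWritable, SparkContext}
    import org.apache.spark.rdd.RDD

    // Broadcast the (non-serializable) Hadoop configuration once per reader...
    def broadcastConf(sc: SparkContext, conf: Configuration) =
      sc.broadcast(new SerializableWritable(conf))

    // ...and unwrap it per partition, so each task can build its own deserializer
    // from a plain, local Configuration instance.
    def withLocalConf(sc: SparkContext, conf: Configuration, rdd: RDD[String]): RDD[String] = {
      val broadcasted = broadcastConf(sc, conf)
      rdd.mapPartitions { iter =>
        val localConf = broadcasted.value.value
        iter.map(_ + localConf.get("io.file.buffer.size", "65536"))
      }
    }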
*/ def makeRDDForPartitionedTable( - partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], - filterOpt: Option[PathFilter]): RDD[_] = { - + partitionToDeserializer: Map[HivePartition, + Class[_ <: Deserializer]], + filterOpt: Option[PathFilter]): RDD[Row] = { val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) val partPath = partition.getPartitionPath @@ -156,33 +161,42 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon } // Create local references so that the outer object isn't serialized. - val tableDesc = _tableDesc + val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val localDeserializer = partDeserializer + val mutableRow = new GenericMutableRow(attributes.length) + + // split the attributes (output schema) into 2 categories: + // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the + // index of the attribute in the output Row. + val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => { + relation.partitionKeys.indexOf(attr._1) >= 0 + }) + + def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = { + partitionKeys.foreach { case (attr, ordinal) => + // get partition key ordinal for a given attribute + val partOridinal = relation.partitionKeys.indexOf(attr) + row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null) + } + } + // fill the partition key for the given MutableRow Object + fillPartitionKeys(partValues, mutableRow) val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) hivePartitionRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value - val rowWithPartArr = new Array[Object](2) - - // The update and deserializer initialization are intentionally - // kept out of the below iter.map loop to save performance. - rowWithPartArr.update(1, partValues) val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) - // Map each tuple to a row object - iter.map { value => - val deserializedRow = deserializer.deserialize(value) - rowWithPartArr.update(0, deserializedRow) - rowWithPartArr.asInstanceOf[Object] - } + // fill the non partition key attributes + HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow) } }.toSeq // Even if we don't use any partitions, we still need an empty RDD if (hivePartitionRDDs.size == 0) { - new EmptyRDD[Object](sc.sparkContext) + new EmptyRDD[Row](sc.sparkContext) } else { new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) } @@ -225,10 +239,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon // Only take the value (skip the key) because Hive works only with values. rdd.map(_._2) } - } -private[hive] object HadoopTableReader { +private[hive] object HadoopTableReader extends HiveInspectors { /** * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to * instantiate a HadoopRDD. 
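The "Curried" comment above describes a helper with two parameter lists: the first (the input path) is applied up front on the driver, leaving a JobConf => Unit closure that HadoopRDD applies later, when it materializes the JobConf for each task. Only the tail of the real helper's body appears in the next hunk, so the following is a minimal sketch of the shape being described, with illustrative names and a reduced argument list:

    import org.apache.hadoop.mapred.{FileInputFormat, JobConf}

    // Two parameter lists: the first is filled in at RDD-construction time,
    // the second is invoked later to finish configuring each task's JobConf.
    def initializeJobConfForPath(path: String)(jobConf: JobConf): Unit = {
      FileInputFormat.setInputPaths(jobConf, path)
      jobConf.set("io.file.buffer.size", System.getProperty("spark.buffer.size", "65536"))
    }

    // Partial application yields the JobConf => Unit closure that HadoopRDD expects.
    def jobConfInitializer(path: String): JobConf => Unit = initializeJobConfForPath(path) _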
@@ -241,4 +254,40 @@ private[hive] object HadoopTableReader { val bufferSize = System.getProperty("spark.buffer.size", "65536") jobConf.set("io.file.buffer.size", bufferSize) } + + /** + * Transform the raw data(Writable object) into the Row object for an iterable input + * @param iter Iterable input which represented as Writable object + * @param deserializer Deserializer associated with the input writable object + * @param attrs Represents the row attribute names and its zero-based position in the MutableRow + * @param row reusable MutableRow object + * + * @return Iterable Row object that transformed from the given iterable input. + */ + def fillObject( + iter: Iterator[Writable], + deserializer: Deserializer, + attrs: Seq[(Attribute, Int)], + row: GenericMutableRow): Iterator[Row] = { + val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] + // get the field references according to the attributes(output of the reader) required + val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) } + + // Map each tuple to a row object + iter.map { value => + val raw = deserializer.deserialize(value) + var idx = 0; + while (idx < fieldRefs.length) { + val fieldRef = fieldRefs(idx)._1 + val fieldIdx = fieldRefs(idx)._2 + val fieldValue = soi.getStructFieldData(raw, fieldRef) + + row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector()) + + idx += 1 + } + + row: Row + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index e7016fa16eea9..8920e2a76a27f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ -import org.apache.spark.util.MutablePair /** * :: DeveloperApi :: @@ -50,8 +49,7 @@ case class HiveTableScan( relation: MetastoreRelation, partitionPruningPred: Option[Expression])( @transient val context: HiveContext) - extends LeafNode - with HiveInspectors { + extends LeafNode { require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") @@ -67,42 +65,7 @@ case class HiveTableScan( } @transient - private[this] val hadoopReader = new HadoopTableReader(relation.tableDesc, context) - - /** - * The hive object inspector for this table, which can be used to extract values from the - * serialized row representation. - */ - @transient - private[this] lazy val objectInspector = - relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector] - - /** - * Functions that extract the requested attributes from the hive output. Partitioned values are - * casted from string to its declared data type. 
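The fillObject helper above delegates the actual value conversion to unwrapData, which HadoopTableReader now inherits from HiveInspectors and which is not included in this diff: it turns the object behind each Hive ObjectInspector into a plain value before that value is written into the reused mutable row. A simplified sketch of the idea (an assumption about its shape; the real version also handles map and struct inspectors, among others):

    import scala.collection.JavaConversions._

    import org.apache.hadoop.hive.serde2.objectinspector.{ListObjectInspector, ObjectInspector, PrimitiveObjectInspector}

    def unwrapSketch(data: Any, oi: ObjectInspector): Any = oi match {
      // Primitive columns: let the inspector produce the corresponding Java value.
      case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
      // Array columns: unwrap each element recursively.
      case li: ListObjectInspector =>
        val elementOi = li.getListElementObjectInspector
        Option(li.getList(data)).map(_.map(e => unwrapSketch(e, elementOi)).toSeq).orNull
      // Everything else is passed through untouched in this sketch.
      case _ => data
    }

As in fillObject itself, the caller reuses a single mutable row per partition, so each produced record has to be consumed before the next one is generated.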
- */ - @transient - protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = { - attributes.map { a => - val ordinal = relation.partitionKeys.indexOf(a) - if (ordinal >= 0) { - val dataType = relation.partitionKeys(ordinal).dataType - (_: Any, partitionKeys: Array[String]) => { - castFromString(partitionKeys(ordinal), dataType) - } - } else { - val ref = objectInspector.getAllStructFieldRefs - .find(_.getFieldName == a.name) - .getOrElse(sys.error(s"Can't find attribute $a")) - val fieldObjectInspector = ref.getFieldObjectInspector - - (row: Any, _: Array[String]) => { - val data = objectInspector.getStructFieldData(row, ref) - unwrapData(data, fieldObjectInspector) - } - } - } - } + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context) private[this] def castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) @@ -114,6 +77,7 @@ case class HiveTableScan( val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",") if (attributes.size == relation.output.size) { + // SQLContext#pruneFilterProject guarantees no duplicated value in `attributes` ColumnProjectionUtils.setFullyReadColumns(hiveConf) } else { ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) @@ -140,12 +104,6 @@ case class HiveTableScan( addColumnMetadataToConf(context.hiveconf) - private def inputRdd = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) - } - /** * Prunes partitions not involve the query plan. * @@ -169,44 +127,10 @@ case class HiveTableScan( } } - override def execute() = { - inputRdd.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val mutableRow = new GenericMutableRow(attributes.length) - val mutablePair = new MutablePair[Any, Array[String]]() - val buffered = iterator.buffered - - // NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern - // matching are avoided intentionally. 
- val rowsAndPartitionKeys = buffered.head match { - // With partition keys - case _: Array[Any] => - buffered.map { case array: Array[Any] => - val deserializedRow = array(0) - val partitionKeys = array(1).asInstanceOf[Array[String]] - mutablePair.update(deserializedRow, partitionKeys) - } - - // Without partition keys - case _ => - val emptyPartitionKeys = Array.empty[String] - buffered.map { deserializedRow => - mutablePair.update(deserializedRow, emptyPartitionKeys) - } - } - - rowsAndPartitionKeys.map { pair => - var i = 0 - while (i < attributes.length) { - mutableRow(i) = attributeFunctions(i)(pair._1, pair._2) - i += 1 - } - mutableRow: Row - } - } - } + override def execute() = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } override def output = attributes diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 new file mode 100644 index 0000000000000..f369f21e1833f --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 @@ -0,0 +1,2 @@ +100 100 2010-01-01 +200 200 2010-01-02 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a8623b64c656f..a022a1e2dc70e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -419,10 +419,10 @@ class HiveQuerySuite extends HiveComparisonTest { hql(s"set $testKey=$testVal") assert(get(testKey, testVal + "_") == testVal) - hql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - hql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + hql("set some.property=20") + assert(get("some.property", "0") == "20") + hql("set some.property = 40") + assert(get("some.property", "0") == "40") hql(s"set $testKey=$testVal") assert(get(testKey, "0") == testVal) @@ -436,63 +436,61 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { case Row(key: String, value: String) => key -> value }.toSet clear() // "set" itself returns all config variables currently specified in SQLConf. 
assert(hql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } hql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + hql(s"SET").collect().map(_.getString(0)) } // "set key" - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + hql(s"SET $nonexistentKey").collect().map(_.getString(0)) } // Assert that sql() should have the same effects as hql() by repeating the above using sql(). clear() assert(sql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + sql("SET").collect().map(_.getString(0)) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + sql("SET").collect().map(_.getString(0)) } - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(sql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + sql(s"SET $nonexistentKey").collect().map(_.getString(0)) } clear() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala new file mode 100644 index 0000000000000..bcb00f871d185 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.hive.test.TestHive + +class HiveTableScanSuite extends HiveComparisonTest { + // MINOR HACK: You must run a query before calling reset the first time. + TestHive.hql("SHOW TABLES") + TestHive.reset() + + TestHive.hql("""CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING) + | ROW FORMAT SERDE + | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' + | STORED AS RCFILE + """.stripMargin) + TestHive.hql("""FROM src + | INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01') + | SELECT 100,100 LIMIT 1 + """.stripMargin) + TestHive.hql("""ALTER TABLE part_scan_test SET SERDE + | 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' + """.stripMargin) + TestHive.hql("""FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02') + | SELECT 200,200 LIMIT 1 + """.stripMargin) + + createQueryTest("partition_based_table_scan_with_different_serde", + "SELECT * from part_scan_test", false) +} diff --git a/streaming/pom.xml b/streaming/pom.xml index f60697ce745b7..b99f306b8f2cc 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming_2.10 - streaming + streaming jar Spark Project Streaming diff --git a/tools/pom.xml b/tools/pom.xml index c0ee8faa7a615..97abb6b2b63e0 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -27,7 +27,7 @@ org.apache.spark spark-tools_2.10 - tools + tools jar Spark Project Tools diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 5b13a1f002d6e..51744ece0412d 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-alpha + yarn-alpha org.apache.spark diff --git a/yarn/pom.xml b/yarn/pom.xml index efb473aa1b261..3faaf053634d6 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -29,7 +29,7 @@ pom Spark Project YARN Parent POM - yarn + yarn diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index ceaf9f9d71001..b6c8456d06684 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-stable + yarn-stable org.apache.spark