diff --git a/.gitignore b/.gitignore
index a4ec12ca6b53f..7ec8d45e12c6b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -58,3 +58,4 @@ metastore_db/
metastore/
warehouse/
TempStatsStore/
+sql/hive-thriftserver/test_warehouses
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 567a8dd2a0d94..703f15925bc44 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -165,6 +165,16 @@
+
+ hive-thriftserver
+
+
+ org.apache.spark
+ spark-hive-thriftserver_${scala.binary.version}
+ ${project.version}
+
+
+
spark-ganglia-lgpl
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 90c4b095bb611..bd51b112e26fa 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-bagel_2.10
- bagel
+ bagel
jar
Spark Project Bagel
diff --git a/bin/beeline b/bin/beeline
new file mode 100755
index 0000000000000..09fe366c609fa
--- /dev/null
+++ b/bin/beeline
@@ -0,0 +1,45 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Figure out where Spark is installed
+FWDIR="$(cd `dirname $0`/..; pwd)"
+
+# Find the java binary
+if [ -n "${JAVA_HOME}" ]; then
+ RUNNER="${JAVA_HOME}/bin/java"
+else
+ if [ `command -v java` ]; then
+ RUNNER="java"
+ else
+ echo "JAVA_HOME is not set" >&2
+ exit 1
+ fi
+fi
+
+# Compute classpath using external script
+classpath_output=$($FWDIR/bin/compute-classpath.sh)
+if [[ "$?" != "0" ]]; then
+ echo "$classpath_output"
+ exit 1
+else
+ CLASSPATH=$classpath_output
+fi
+
+CLASS="org.apache.hive.beeline.BeeLine"
+exec "$RUNNER" -cp "$CLASSPATH" $CLASS "$@"
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index e81e8c060cb98..16b794a1592e8 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -52,6 +52,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then
CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes"
fi
diff --git a/bin/spark-shell b/bin/spark-shell
index 850e9507ec38f..756c8179d12b6 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -46,11 +46,11 @@ function main(){
# (see https://github.com/sbt/sbt/issues/562).
stty -icanon min 1 -echo > /dev/null 2>&1
export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix"
- $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main
+ $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@"
stty icanon echo > /dev/null 2>&1
else
export SPARK_SUBMIT_OPTS
- $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main
+ $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@"
fi
}
diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd
index 4b9708a8c03f3..b56d69801171c 100755
--- a/bin/spark-shell.cmd
+++ b/bin/spark-shell.cmd
@@ -19,4 +19,4 @@ rem
set SPARK_HOME=%~dp0..
-cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell %* --class org.apache.spark.repl.Main
+cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %*
diff --git a/bin/spark-sql b/bin/spark-sql
new file mode 100755
index 0000000000000..bba7f897b19bc
--- /dev/null
+++ b/bin/spark-sql
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#
+# Shell script for starting the Spark SQL CLI
+
+# Enter posix mode for bash
+set -o posix
+
+# Figure out where Spark is installed
+FWDIR="$(cd `dirname $0`/..; pwd)"
+
+if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
+ echo "Usage: ./bin/spark-sql [options]"
+ $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
+ exit 0
+fi
+
+CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
+exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal "$@"
diff --git a/core/pom.xml b/core/pom.xml
index 1054cec4d77bb..a24743495b0e1 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-core_2.10
- core
+ core
jar
Spark Project Core
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index f010c03223ef4..09a60571238ea 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -19,7 +19,6 @@ package org.apache.spark
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.SortOrder.SortOrder
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle
@@ -63,8 +62,7 @@ class ShuffleDependency[K, V, C](
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
- val mapSideCombine: Boolean = false,
- val sortOrder: Option[SortOrder] = None)
+ val mapSideCombine: Boolean = false)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 3b5642b6caa36..c9cec33ebaa66 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -46,6 +46,10 @@ object SparkSubmit {
private val CLUSTER = 2
private val ALL_DEPLOY_MODES = CLIENT | CLUSTER
+ // A special jar name that indicates the class being run is inside of Spark itself, and therefore
+ // no user jar is needed.
+ private val SPARK_INTERNAL = "spark-internal"
+
// Special primary resource names that represent shells rather than application jars.
private val SPARK_SHELL = "spark-shell"
private val PYSPARK_SHELL = "pyspark-shell"
@@ -257,7 +261,9 @@ object SparkSubmit {
// In yarn-cluster mode, use yarn.Client as a wrapper around the user class
if (clusterManager == YARN && deployMode == CLUSTER) {
childMainClass = "org.apache.spark.deploy.yarn.Client"
- childArgs += ("--jar", args.primaryResource)
+ if (args.primaryResource != SPARK_INTERNAL) {
+ childArgs += ("--jar", args.primaryResource)
+ }
childArgs += ("--class", args.mainClass)
if (args.childArgs != null) {
args.childArgs.foreach { arg => childArgs += ("--arg", arg) }
@@ -332,7 +338,7 @@ object SparkSubmit {
* Return whether the given primary resource represents a user jar.
*/
private def isUserJar(primaryResource: String): Boolean = {
- !isShell(primaryResource) && !isPython(primaryResource)
+ !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource)
}
/**
@@ -349,6 +355,10 @@ object SparkSubmit {
primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL
}
+ private[spark] def isInternal(primaryResource: String): Boolean = {
+ primaryResource == SPARK_INTERNAL
+ }
+
/**
* Merge a sequence of comma-separated file lists, some of which may be null to indicate
* no files, into a single comma-separated string.
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 3ab67a43a3b55..01d0ae541a66b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -204,8 +204,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
/** Fill in values by parsing user options. */
private def parseOpts(opts: Seq[String]): Unit = {
- // Delineates parsing of Spark options from parsing of user options.
var inSparkOpts = true
+
+ // Delineates parsing of Spark options from parsing of user options.
parse(opts)
def parse(opts: Seq[String]): Unit = opts match {
@@ -318,7 +319,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
SparkSubmit.printErrorAndExit(errMessage)
case v =>
primaryResource =
- if (!SparkSubmit.isShell(v)) {
+ if (!SparkSubmit.isShell(v) && !SparkSubmit.isInternal(v)) {
Utils.resolveURI(v).toString
} else {
v
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index afd7075f686b9..d85f962783931 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -58,12 +58,6 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
val part = new RangePartitioner(numPartitions, self, ascending)
new ShuffledRDD[K, V, V, P](self, part)
- .setKeyOrdering(ordering)
- .setSortOrder(if (ascending) SortOrder.ASCENDING else SortOrder.DESCENDING)
+ .setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
}
-
-private[spark] object SortOrder extends Enumeration {
- type SortOrder = Value
- val ASCENDING, DESCENDING = Value
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index da4a8c3dc22b1..bf02f68d0d3d3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.SortOrder.SortOrder
import org.apache.spark.serializer.Serializer
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
@@ -52,8 +51,6 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
private var mapSideCombine: Boolean = false
- private var sortOrder: Option[SortOrder] = None
-
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
this.serializer = Option(serializer)
@@ -78,15 +75,8 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
this
}
- /** Set sort order for RDD's sorting. */
- def setSortOrder(sortOrder: SortOrder): ShuffledRDD[K, V, C, P] = {
- this.sortOrder = Option(sortOrder)
- this
- }
-
override def getDependencies: Seq[Dependency[_]] = {
- List(new ShuffleDependency(prev, part, serializer,
- keyOrdering, aggregator, mapSideCombine, sortOrder))
+ List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
}
override val partitioner = Some(part)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 76cdb8f4f8e8a..c8059496a1bdf 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -18,7 +18,6 @@
package org.apache.spark.shuffle.hash
import org.apache.spark.{InterruptibleIterator, TaskContext}
-import org.apache.spark.rdd.SortOrder
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
@@ -51,16 +50,22 @@ class HashShuffleReader[K, C](
iter
}
- val sortedIter = for (sortOrder <- dep.sortOrder; ordering <- dep.keyOrdering) yield {
- val buf = aggregatedIter.toArray
- if (sortOrder == SortOrder.ASCENDING) {
- buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
- } else {
- buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
- }
+ // Sort the output if there is a sort ordering defined.
+ dep.keyOrdering match {
+ case Some(keyOrd: Ordering[K]) =>
+ // Define a Comparator for the whole record based on the key Ordering.
+ val cmp = new Ordering[Product2[K, C]] {
+ override def compare(o1: Product2[K, C], o2: Product2[K, C]): Int = {
+ keyOrd.compare(o1._1, o2._1)
+ }
+ }
+ val sortBuffer: Array[Product2[K, C]] = aggregatedIter.toArray
+ // TODO: do external sort.
+ scala.util.Sorting.quickSort(sortBuffer)(cmp)
+ sortBuffer.iterator
+ case None =>
+ aggregatedIter
}
-
- sortedIter.getOrElse(aggregatedIter)
}
/** Close this reader */
diff --git a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala
index d769b54fa2fae..f77488ef3d449 100644
--- a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala
+++ b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala
@@ -17,7 +17,7 @@
package org.apache.spark.util
-import org.apache.commons.lang.SystemUtils
+import org.apache.commons.lang3.SystemUtils
import org.slf4j.Logger
import sun.misc.{Signal, SignalHandler}
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 38830103d1e8d..33de24d1ae6d7 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -53,7 +53,7 @@ if [[ ! "$@" =~ --package-only ]]; then
-Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \
-Dmaven.javadoc.skip=true \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\
+ -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\
-Dtag=$GIT_TAG -DautoVersionSubmodules=true \
--batch-mode release:prepare
@@ -61,7 +61,7 @@ if [[ ! "$@" =~ --package-only ]]; then
-Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
-Dmaven.javadoc.skip=true \
- -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\
+ -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\
release:perform
cd ..
@@ -111,10 +111,10 @@ make_binary_release() {
spark-$RELEASE_VERSION-bin-$NAME.tgz.sha
}
-make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4"
-make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0"
+make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4"
+make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0"
make_binary_release "hadoop2" \
- "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0"
+ "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0"
# Copy data
echo "Copying release tarballs"
diff --git a/dev/run-tests b/dev/run-tests
index 51e4def0f835a..98ec969dc1b37 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -65,7 +65,7 @@ echo "========================================================================="
# (either resolution or compilation) prompts the user for input either q, r,
# etc to quit or retry. This echo is there to make it not block.
if [ -n "$_RUN_SQL_TESTS" ]; then
- echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive" sbt/sbt clean package \
+ echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive -Phive-thriftserver" sbt/sbt clean package \
assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
else
echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \
diff --git a/dev/scalastyle b/dev/scalastyle
index a02d06912f238..d9f2b91a3a091 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -17,7 +17,7 @@
# limitations under the License.
#
-echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt
+echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt
# Check style with YARN alpha built too
echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
>> scalastyle.txt
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 38728534a46e0..156e0aebdebe6 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -136,7 +136,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.createSchemaRDD
// Define the schema using a case class.
-// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit,
+// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit,
// you can use custom classes that implement the Product interface.
case class Person(name: String, age: Int)
@@ -548,7 +548,6 @@ results = hiveContext.hql("FROM src SELECT key, value").collect()
-
# Writing Language-Integrated Relational Queries
**Language-Integrated queries are currently only supported in Scala.**
@@ -573,4 +572,200 @@ prefixed with a tick (`'`). Implicit conversions turn these symbols into expres
evaluated by the SQL execution engine. A full list of the functions supported can be found in the
[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD).
-
\ No newline at end of file
+
+
+## Running the Thrift JDBC server
+
+The Thrift JDBC server implemented here corresponds to
+[`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12.
+You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.12. To use
+the Thrift JDBC server you must first run '`sbt/sbt -Phive-thriftserver assembly/assembly`' (or use the
+`-Phive-thriftserver` profile for Maven).
+
+To start the JDBC server, run the following in the Spark directory:
+
+ ./sbin/start-thriftserver.sh
+
+The default port the server listens on is 10000. To listen on a custom host and port, set
+the `HIVE_SERVER2_THRIFT_PORT` and `HIVE_SERVER2_THRIFT_BIND_HOST` environment variables. You may
+run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. Now you can
+use beeline to test the Thrift JDBC server:
+
+ ./bin/beeline
+
+Connect to the JDBC server in beeline with:
+
+ beeline> !connect jdbc:hive2://localhost:10000
+
+Beeline will ask you for a username and password. In non-secure mode, simply enter the username on
+your machine and a blank password. For secure mode, please follow the instructions given in the
+[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients).
+
+Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
+
+You may also use the beeline script that comes with Hive.
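+
+Beyond beeline, any client that speaks the HiveServer2 protocol can connect over JDBC. The sketch below
+assumes the Hive 0.12 JDBC driver (`org.apache.hive.jdbc.HiveDriver`) and its dependencies are on the
+classpath, and that the server is running on the default host and port; adjust the URL to match your
+deployment.
+
+```
+import java.sql.DriverManager
+
+// Assumed HiveServer2 JDBC driver class from Hive 0.12; it must be on the classpath.
+Class.forName("org.apache.hive.jdbc.HiveDriver")
+
+// Default host/port used by ./sbin/start-thriftserver.sh; blank user/password in non-secure mode.
+val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default", "", "")
+val stmt = conn.createStatement()
+val rs = stmt.executeQuery("SHOW TABLES")
+while (rs.next()) {
+  println(rs.getString(1))
+}
+conn.close()
+```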
+
+### Migration Guide for Shark Users
+
+#### Reducer number
+
+In Shark, the default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark
+SQL deprecates this property in favor of a new property, `spark.sql.shuffle.partitions`, whose default
+value is 200. Users may customize this property via `SET`:
+
+```
+SET spark.sql.shuffle.partitions=10;
+SELECT page, count(*) c FROM logs_last_month_cached
+GROUP BY page ORDER BY c DESC LIMIT 10;
+```
+
+You may also put this property in `hive-site.xml` to override the default value.
+
+For now, the `mapred.reduce.tasks` property is still recognized, and is converted to
+`spark.sql.shuffle.partitions` automatically.
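+
+The same setting can also be applied programmatically. The following is only a sketch, assuming a
+`HiveContext` built as shown earlier in this guide and that `SET` statements are accepted through
+`hql`:
+
+```
+// Adjust post-shuffle parallelism from Scala, then run the query above.
+val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc)
+hiveContext.hql("SET spark.sql.shuffle.partitions=10")
+val topPages = hiveContext.hql(
+  "SELECT page, count(*) c FROM logs_last_month_cached GROUP BY page ORDER BY c DESC LIMIT 10")
+topPages.collect().foreach(println)
+```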
+
+#### Caching
+
+The `shark.cache` table property no longer exists, and tables whose names end with `_cached` are no
+longer automatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to
+let users control table caching explicitly:
+
+```
+CACHE TABLE logs_last_month;
+UNCACHE TABLE logs_last_month;
+```
+
+**NOTE:** `CACHE TABLE tbl` is lazy: it only marks table `tbl` as "needs to be cached if necessary",
+but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be
+cached, you may simply count the table immediately after executing `CACHE TABLE`:
+
+```
+CACHE TABLE logs_last_month;
+SELECT COUNT(1) FROM logs_last_month;
+```
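+
+If you are working from a `HiveContext` in Scala rather than through SQL statements, a similar effect
+can be achieved with the context's caching methods; a small sketch, assuming the `cacheTable` and
+`uncacheTable` methods on `SQLContext`:
+
+```
+// Scala counterpart of CACHE TABLE / UNCACHE TABLE.
+hiveContext.cacheTable("logs_last_month")
+hiveContext.uncacheTable("logs_last_month")
+```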
+
+Several caching-related features are not supported yet:
+
+* User defined partition level cache eviction policy
+* RDD reloading
+* In-memory cache write through policy
+
+### Compatibility with Apache Hive
+
+#### Deploying in Existing Hive Warehouses
+
+The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive
+installations. You do not need to modify your existing Hive Metastore or change the data placement
+or partitioning of your tables.
+
+#### Supported Hive Features
+
+Spark SQL supports the vast majority of Hive features, such as:
+
+* Hive query statements, including:
+ * `SELECT`
+ * `GROUP BY`
+ * `ORDER BY`
+ * `CLUSTER BY`
+ * `SORT BY`
+* All Hive operators, including:
+ * Relational operators (`=`, `<=>`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc)
+ * Arithmetic operators (`+`, `-`, `*`, `/`, `%`, etc)
+ * Logical operators (`AND`, `&&`, `OR`, `||`, etc)
+ * Complex type constructors
+ * Mathematical functions (`sign`, `ln`, `cos`, etc)
+ * String functions (`instr`, `length`, `printf`, etc)
+* User defined functions (UDF)
+* User defined aggregation functions (UDAF)
+* User defined serialization formats (SerDes)
+* Joins
+ * `JOIN`
+ * `{LEFT|RIGHT|FULL} OUTER JOIN`
+ * `LEFT SEMI JOIN`
+ * `CROSS JOIN`
+* Unions
+* Subqueries
+ * `SELECT col FROM ( SELECT a + b AS col from t1) t2`
+* Sampling
+* Explain
+* Partitioned tables
+* All Hive DDL Functions, including:
+ * `CREATE TABLE`
+ * `CREATE TABLE AS SELECT`
+ * `ALTER TABLE`
+* Most Hive Data types, including:
+ * `TINYINT`
+ * `SMALLINT`
+ * `INT`
+ * `BIGINT`
+ * `BOOLEAN`
+ * `FLOAT`
+ * `DOUBLE`
+ * `STRING`
+ * `BINARY`
+ * `TIMESTAMP`
+ * `ARRAY<>`
+ * `MAP<>`
+ * `STRUCT<>`
+
+#### Unsupported Hive Functionality
+
+Below is a list of Hive features that we don't support yet. Most of these features are rarely used
+in Hive deployments.
+
+**Major Hive Features**
+
+* Tables with buckets: bucketing is hash partitioning within a Hive table partition. Spark SQL
+ doesn't support buckets yet.
+
+**Esoteric Hive Features**
+
+* Tables with partitions using different input formats: In Spark SQL, all table partitions need to
+ have the same input format.
+* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions
+ (e.g. condition "`key < 10`"), Spark SQL will output wrong results for the `NULL` tuple.
+* `UNIONTYPE`
+* Unique join
+* Single query multi insert
+* Collecting column statistics: Spark SQL does not piggyback scans to collect column statistics at
+ the moment.
+
+**Hive Input/Output Formats**
+
+* File format for CLI: For results shown back to the CLI, Spark SQL only supports `TextOutputFormat`.
+* Hadoop archive
+
+**Hive Optimizations**
+
+A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are
+not necessary due to Spark SQL's in-memory computational model. Others are slotted for future
+releases of Spark SQL.
+
+* Block level bitmap indexes and virtual columns (used to build indexes)
+* Automatically convert a join to map join: For joining a large table with multiple small tables,
+ Hive automatically converts the join into a map join. We are adding this auto conversion in the
+ next release.
+* Automatically determine the number of reducers for joins and group-bys: Currently in Spark SQL, you
+ need to control the degree of parallelism post-shuffle using
+ `SET spark.sql.shuffle.partitions=[num_tasks];`. We are going to add auto-setting of parallelism in the
+ next release.
+* Metadata-only query: For queries that can be answered by using only metadata, Spark SQL still
+ launches tasks to compute the result.
+* Skew data flag: Spark SQL does not follow the skew data flags in Hive.
+* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint.
+* Merge multiple small files for query results: If the result output contains multiple small files,
+ Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS
+ metadata. Spark SQL does not support that.
+
+## Running the Spark SQL CLI
+
+The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute
+queries input from the command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server.
+
+To start the Spark SQL CLI, run the following in the Spark directory:
+
+ ./bin/spark-sql
+
+Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
+You may run `./bin/spark-sql --help` for a complete list of all available
+options.
diff --git a/examples/pom.xml b/examples/pom.xml
index bd1c387c2eb91..c4ed0f5a6a02b 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-examples_2.10
- examples
+ examples
jar
Spark Project Examples
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 61a6aff543aed..874b8a7959bb6 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-flume_2.10
- streaming-flume
+ streaming-flume
jar
Spark Project External Flume
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 4762c50685a93..25a5c0a4d7d77 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-kafka_2.10
- streaming-kafka
+ streaming-kafka
jar
Spark Project External Kafka
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 32c530e600ce0..f31ed655f6779 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-mqtt_2.10
- streaming-mqtt
+ streaming-mqtt
jar
Spark Project External MQTT
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index 637adb0f00da0..56bb24c2a072e 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-twitter_2.10
- streaming-twitter
+ streaming-twitter
jar
Spark Project External Twitter
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index e4d758a04a4cd..54b0242c54e78 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-zeromq_2.10
- streaming-zeromq
+ streaming-zeromq
jar
Spark Project External ZeroMQ
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 7e3bcf29dcfbc..6dd52fc618b1e 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-graphx_2.10
- graphx
+ graphx
jar
Spark Project GraphX
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 92b07e2357db1..f27cf520dc9fa 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-mllib_2.10
- mllib
+ mllib
jar
Spark Project ML Library
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 3f6ff859374c7..da7c633bbd2af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.Matchers
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
object LogisticRegressionSuite {
@@ -81,9 +82,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
val model = lr.run(testRDD)
// Test the weights
- val weight0 = model.weights(0)
- assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
- assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+ assert(model.weights(0) ~== -1.52 relTol 0.01)
+ assert(model.intercept ~== 2.00 relTol 0.01)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
@@ -113,9 +113,9 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
val model = lr.run(testRDD, initialWeights)
- val weight0 = model.weights(0)
- assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
- assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+ // Test the weights
+ assert(model.weights(0) ~== -1.50 relTol 0.01)
+ assert(model.intercept ~== 1.97 relTol 0.01)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 34bc4537a7b3a..afa1f79b95a12 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -21,8 +21,9 @@ import scala.util.Random
import org.scalatest.FunSuite
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
class KMeansSuite extends FunSuite with LocalSparkContext {
@@ -41,26 +42,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// centered at the mean of the points
var model = KMeans.train(data, k = 1, maxIterations = 1)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 2)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 5)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(
data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
}
test("no distinct points") {
@@ -104,26 +105,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.size === 1)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 2)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 5)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
initializationMode = K_MEANS_PARALLEL)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
}
test("single cluster with sparse data") {
@@ -149,31 +150,39 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0)))
var model = KMeans.train(data, k = 1, maxIterations = 1)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 2)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 5)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
initializationMode = K_MEANS_PARALLEL)
- assert(model.clusterCenters.head === center)
+ assert(model.clusterCenters.head ~== center absTol 1E-5)
data.unpersist()
}
test("k-means|| initialization") {
+
+ case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] {
+ override def compare(that: VectorWithCompare): Int = {
+ if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) >
+ that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1
+ }
+ }
+
val points = Seq(
Vectors.dense(1.0, 2.0, 6.0),
Vectors.dense(1.0, 3.0, 0.0),
@@ -188,15 +197,19 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// unselected point as long as it hasn't yet selected all of them
var model = KMeans.train(rdd, k = 5, maxIterations = 1)
- assert(Set(model.clusterCenters: _*) === Set(points: _*))
+
+ assert(model.clusterCenters.sortBy(VectorWithCompare(_))
+ .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
// Iterations of Lloyd's should not change the answer either
model = KMeans.train(rdd, k = 5, maxIterations = 10)
- assert(Set(model.clusterCenters: _*) === Set(points: _*))
+ assert(model.clusterCenters.sortBy(VectorWithCompare(_))
+ .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
// Neither should more runs
model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5)
- assert(Set(model.clusterCenters: _*) === Set(points: _*))
+ assert(model.clusterCenters.sortBy(VectorWithCompare(_))
+ .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
}
test("two clusters") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
index 1c9844f289fe0..994e0feb8629e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
@@ -20,27 +20,28 @@ package org.apache.spark.mllib.evaluation
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
test("auc computation") {
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
val auc = 4.0
- assert(AreaUnderCurve.of(curve) === auc)
+ assert(AreaUnderCurve.of(curve) ~== auc absTol 1E-5)
val rddCurve = sc.parallelize(curve, 2)
- assert(AreaUnderCurve.of(rddCurve) == auc)
+ assert(AreaUnderCurve.of(rddCurve) ~== auc absTol 1E-5)
}
test("auc of an empty curve") {
val curve = Seq.empty[(Double, Double)]
- assert(AreaUnderCurve.of(curve) === 0.0)
+ assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5)
val rddCurve = sc.parallelize(curve, 2)
- assert(AreaUnderCurve.of(rddCurve) === 0.0)
+ assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5)
}
test("auc of a curve with a single point") {
val curve = Seq((1.0, 1.0))
- assert(AreaUnderCurve.of(curve) === 0.0)
+ assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5)
val rddCurve = sc.parallelize(curve, 2)
- assert(AreaUnderCurve.of(rddCurve) === 0.0)
+ assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index 94db1dc183230..a733f88b60b80 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -20,25 +20,14 @@ package org.apache.spark.mllib.evaluation
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals
+import org.apache.spark.mllib.util.TestingUtils._
class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
- // TODO: move utility functions to TestingUtils.
+ def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
- def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = {
- actual.zip(expected).forall { case (x1, x2) =>
- x1.almostEquals(x2)
- }
- }
-
- def elementsAlmostEqual(
- actual: Seq[(Double, Double)],
- expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = {
- actual.zip(expected).forall { case ((x1, y1), (x2, y2)) =>
- x1.almostEquals(x2) && y1.almostEquals(y2)
- }
- }
+ def cond2(x: ((Double, Double), (Double, Double))): Boolean =
+ (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
test("binary evaluation metrics") {
val scoreAndLabels = sc.parallelize(
@@ -57,16 +46,17 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
val pr = recall.zip(precision)
val prCurve = Seq((0.0, 1.0)) ++ pr
- val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
+ val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
- assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold))
- assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve))
- assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve)))
- assert(elementsAlmostEqual(metrics.pr().collect(), prCurve))
- assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve)))
- assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1)))
- assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2)))
- assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision)))
- assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall)))
+
+ assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
+ assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
+ assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
+ assert(metrics.pr().collect().zip(prCurve).forall(cond2))
+ assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
+ assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
+ assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
+ assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
+ assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index dfb2eb7f0d14e..bf040110e228b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.{FunSuite, Matchers}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
object GradientDescentSuite {
@@ -126,19 +127,14 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers
val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD(
dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept)
- def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
- math.abs(x - y) / (math.abs(y) + 1e-15) < tol
- }
-
- assert(compareDouble(
- loss1(0),
- loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) +
- math.pow(initialWeightsWithIntercept(1), 2)) / 2),
+ assert(
+ loss1(0) ~= (loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) +
+ math.pow(initialWeightsWithIntercept(1), 2)) / 2) absTol 1E-5,
"""For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""")
assert(
- compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) &&
- compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)),
+ (newWeights1(0) ~= (newWeights0(0) - initialWeightsWithIntercept(0)) absTol 1E-5) &&
+ (newWeights1(1) ~= (newWeights0(1) - initialWeightsWithIntercept(1)) absTol 1E-5),
"The different between newWeights with/without regularization " +
"should be initialWeightsWithIntercept.")
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index ff414742e8393..5f4c24115ac80 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -24,6 +24,7 @@ import org.scalatest.{FunSuite, Matchers}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
@@ -49,10 +50,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
lazy val dataRDD = sc.parallelize(data, 2).cache()
- def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
- math.abs(x - y) / (math.abs(y) + 1e-15) < tol
- }
-
test("LBFGS loss should be decreasing and match the result of Gradient Descent.") {
val regParam = 0
@@ -126,15 +123,15 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
miniBatchFrac,
initialWeightsWithIntercept)
- assert(compareDouble(lossGD(0), lossLBFGS(0)),
+ assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5,
"The first losses of LBFGS and GD should be the same.")
// The 2% difference here is based on observation, but is not theoretically guaranteed.
- assert(compareDouble(lossGD.last, lossLBFGS.last, 0.02),
+ assert(lossGD.last ~= lossLBFGS.last relTol 0.02,
"The last losses of LBFGS and GD should be within 2% difference.")
- assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
- compareDouble(weightLBFGS(1), weightGD(1), 0.02),
+ assert(
+ (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02),
"The weight differences between LBFGS and GD should be within 2%.")
}
@@ -226,8 +223,8 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
initialWeightsWithIntercept)
// for class LBFGS and the optimize method, we only look at the weights
- assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
- compareDouble(weightLBFGS(1), weightGD(1), 0.02),
+ assert(
+ (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02),
"The weight differences between LBFGS and GD should be within 2%.")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
index bbf385229081a..b781a6aed9a8c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
@@ -21,7 +21,9 @@ import scala.util.Random
import org.scalatest.FunSuite
-import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas}
+import org.jblas.{DoubleMatrix, SimpleBlas}
+
+import org.apache.spark.mllib.util.TestingUtils._
class NNLSSuite extends FunSuite {
/** Generate an NNLS problem whose optimal solution is the all-ones vector. */
@@ -73,7 +75,7 @@ class NNLSSuite extends FunSuite {
val ws = NNLS.createWorkspace(n)
val x = NNLS.solve(ata, atb, ws)
for (i <- 0 until n) {
- assert(Math.abs(x(i) - goodx(i)) < 1e-3)
+ assert(x(i) ~== goodx(i) absTol 1E-3)
assert(x(i) >= 0)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 4b7b019d820b4..db13f142df517 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -89,15 +89,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
.add(Vectors.dense(-1.0, 0.0, 6.0))
.add(Vectors.dense(3.0, -3.0, 0.0))
- assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
+ assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch")
- assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
+ assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch")
- assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
+ assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch")
- assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
+ assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch")
- assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
+ assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 2)
}
@@ -107,15 +107,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
.add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0))))
.add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0))))
- assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
+ assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch")
- assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
+ assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch")
- assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
+ assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch")
- assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
+ assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch")
- assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
+ assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 2)
}
@@ -129,17 +129,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
.add(Vectors.dense(1.7, -0.6, 0.0))
.add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
- assert(summarizer.mean.almostEquals(
- Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
+ assert(summarizer.mean ~==
+ Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch")
- assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
+ assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch")
- assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
+ assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch")
- assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
+ assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")
- assert(summarizer.variance.almostEquals(
- Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
+ assert(summarizer.variance ~==
+ Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 6)
}
@@ -157,17 +157,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
val summarizer = summarizer1.merge(summarizer2)
- assert(summarizer.mean.almostEquals(
- Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
+ assert(summarizer.mean ~==
+ Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch")
- assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
+ assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch")
- assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
+ assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch")
- assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
+ assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")
- assert(summarizer.variance.almostEquals(
- Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
+ assert(summarizer.variance ~==
+ Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 6)
}
@@ -186,24 +186,24 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
val summarizer3 = (new MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer)
assert(summarizer3.count === 0)
- assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
+ assert(summarizer1.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch")
- assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
+ assert(summarizer2.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch")
- assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
+ assert(summarizer1.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch")
- assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
+ assert(summarizer2.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch")
- assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
+ assert(summarizer1.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch")
- assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
+ assert(summarizer2.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch")
- assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
+ assert(summarizer1.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch")
- assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
+ assert(summarizer2.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch")
- assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
+ assert(summarizer1.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
- assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
+ assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
index 64b1ba7527183..29cc42d8cbea7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
@@ -18,28 +18,155 @@
package org.apache.spark.mllib.util
import org.apache.spark.mllib.linalg.Vector
+import org.scalatest.exceptions.TestFailedException
object TestingUtils {
+ val ABS_TOL_MSG = " using absolute tolerance"
+ val REL_TOL_MSG = " using relative tolerance"
+
+ /**
+ * Private helper function for comparing two values using relative tolerance.
+ * Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue,
+ * the relative tolerance is meaningless, so the exception will be raised to warn users.
+ */
+ private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ val absX = math.abs(x)
+ val absY = math.abs(y)
+ val diff = math.abs(x - y)
+ if (x == y) {
+ true
+ } else if (absX < Double.MinPositiveValue || absY < Double.MinPositiveValue) {
+ throw new TestFailedException(
+ s"$x or $y is extremely close to zero, so the relative tolerance is meaningless.", 0)
+ } else {
+ diff < eps * math.min(absX, absY)
+ }
+ }
+
+ /**
+ * Private helper function for comparing two values using absolute tolerance.
+ */
+ private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ math.abs(x - y) < eps
+ }
+
+ case class CompareDoubleRightSide(
+ fun: (Double, Double, Double) => Boolean, y: Double, eps: Double, method: String)
+
+ /**
+ * Implicit class for comparing two double values using relative tolerance or absolute tolerance.
+ */
implicit class DoubleWithAlmostEquals(val x: Double) {
- // An improved version of AlmostEquals would always divide by the larger number.
- // This will avoid the problem of diving by zero.
- def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = {
- if(x == y) {
- true
- } else if(math.abs(x) > math.abs(y)) {
- math.abs(x - y) / math.abs(x) < epsilon
- } else {
- math.abs(x - y) / math.abs(y) < epsilon
+
+ /**
+ * When the difference of two values are within eps, returns true; otherwise, returns false.
+ */
+ def ~=(r: CompareDoubleRightSide): Boolean = r.fun(x, r.y, r.eps)
+
+ /**
+ * When the difference of two values are within eps, returns false; otherwise, returns true.
+ */
+ def !~=(r: CompareDoubleRightSide): Boolean = !r.fun(x, r.y, r.eps)
+
+ /**
+ * Throws exception when the difference of two values are NOT within eps;
+ * otherwise, returns true.
+ */
+ def ~==(r: CompareDoubleRightSide): Boolean = {
+ if (!r.fun(x, r.y, r.eps)) {
+ throw new TestFailedException(
+ s"Expected $x and ${r.y} to be within ${r.eps}${r.method}.", 0)
}
+ true
}
+
+ /**
+ * Throws exception when the difference of two values are within eps; otherwise, returns true.
+ */
+ def !~==(r: CompareDoubleRightSide): Boolean = {
+ if (r.fun(x, r.y, r.eps)) {
+ throw new TestFailedException(
+ s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method}.", 0)
+ }
+ true
+ }
+
+ /**
+ * Comparison using absolute tolerance.
+ */
+ def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison,
+ x, eps, ABS_TOL_MSG)
+
+ /**
+ * Comparison using relative tolerance.
+ */
+ def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison,
+ x, eps, REL_TOL_MSG)
+
+ override def toString = x.toString
}
+ case class CompareVectorRightSide(
+ fun: (Vector, Vector, Double) => Boolean, y: Vector, eps: Double, method: String)
+
+ /**
+ * Implicit class for comparing two vectors using relative tolerance or absolute tolerance.
+ */
implicit class VectorWithAlmostEquals(val x: Vector) {
- def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = {
- x.toArray.corresponds(y.toArray) {
- _.almostEquals(_, epsilon)
+
+ /**
+ * When the difference of the two vectors is within eps, returns true; otherwise, returns false.
+ */
+ def ~=(r: CompareVectorRightSide): Boolean = r.fun(x, r.y, r.eps)
+
+ /**
+ * When the difference of the two vectors is within eps, returns false; otherwise, returns true.
+ */
+ def !~=(r: CompareVectorRightSide): Boolean = !r.fun(x, r.y, r.eps)
+
+ /**
+ * Throws an exception when the difference of the two vectors is NOT within eps;
+ * otherwise, returns true.
+ */
+ def ~==(r: CompareVectorRightSide): Boolean = {
+ if (!r.fun(x, r.y, r.eps)) {
+ throw new TestFailedException(
+ s"Expected $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0)
}
+ true
}
+
+ /**
+ * Throws an exception when the difference of the two vectors is within eps; otherwise, returns true.
+ */
+ def !~==(r: CompareVectorRightSide): Boolean = {
+ if (r.fun(x, r.y, r.eps)) {
+ throw new TestFailedException(
+ s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0)
+ }
+ true
+ }
+
+ /**
+ * Comparison using absolute tolerance.
+ */
+ def absTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide(
+ (x: Vector, y: Vector, eps: Double) => {
+ x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps)
+ }, x, eps, ABS_TOL_MSG)
+
+ /**
+ * Comparison using relative tolerance. Note that comparing against a sparse vector
+ * with elements whose value is zero will raise an exception, because it involves
+ * comparing against zero.
+ */
+ def relTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide(
+ (x: Vector, y: Vector, eps: Double) => {
+ x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps)
+ }, x, eps, REL_TOL_MSG)
+
+ override def toString = x.toString
}
+
}
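
The RelativeErrorComparison helper above refuses operands smaller than Double.MinPositiveValue because a relative bound degenerates at zero: the allowed difference is eps * min(|x|, |y|), which collapses to 0 whenever either side is 0. A minimal standalone sketch of that reasoning (illustration only, not part of the patch):

    object RelTolNearZeroSketch extends App {
      val eps = 0.01
      // Same criterion as RelativeErrorComparison: scale the tolerance by the smaller magnitude.
      def relClose(x: Double, y: Double): Boolean =
        math.abs(x - y) < eps * math.min(math.abs(x), math.abs(y))

      println(relClose(23.1, 23.2)) // true: the difference 0.1 is below 0.01 * 23.1
      println(relClose(0.1, 0.0))   // false for every eps: the bound is 0.01 * 0 = 0
    }

This is why the suite below expects a TestFailedException whenever one side of a relTol comparison is exactly zero.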
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
new file mode 100644
index 0000000000000..b0ecb33c28483
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.scalatest.FunSuite
+import org.apache.spark.mllib.util.TestingUtils._
+import org.scalatest.exceptions.TestFailedException
+
+class TestingUtilsSuite extends FunSuite {
+
+ test("Comparing doubles using relative error.") {
+
+ assert(23.1 ~== 23.52 relTol 0.02)
+ assert(23.1 ~== 22.74 relTol 0.02)
+ assert(23.1 ~= 23.52 relTol 0.02)
+ assert(23.1 ~= 22.74 relTol 0.02)
+ assert(!(23.1 !~= 23.52 relTol 0.02))
+ assert(!(23.1 !~= 22.74 relTol 0.02))
+
+ // Should throw exception with message when test fails.
+ intercept[TestFailedException](23.1 !~== 23.52 relTol 0.02)
+ intercept[TestFailedException](23.1 !~== 22.74 relTol 0.02)
+ intercept[TestFailedException](23.1 ~== 23.63 relTol 0.02)
+ intercept[TestFailedException](23.1 ~== 22.34 relTol 0.02)
+
+ assert(23.1 !~== 23.63 relTol 0.02)
+ assert(23.1 !~== 22.34 relTol 0.02)
+ assert(23.1 !~= 23.63 relTol 0.02)
+ assert(23.1 !~= 22.34 relTol 0.02)
+ assert(!(23.1 ~= 23.63 relTol 0.02))
+ assert(!(23.1 ~= 22.34 relTol 0.02))
+
+ // Comparing against zero should fail the test and throw exception with message
+ // saying that the relative error is meaningless in this situation.
+ intercept[TestFailedException](0.1 ~== 0.0 relTol 0.032)
+ intercept[TestFailedException](0.1 ~= 0.0 relTol 0.032)
+ intercept[TestFailedException](0.1 !~== 0.0 relTol 0.032)
+ intercept[TestFailedException](0.1 !~= 0.0 relTol 0.032)
+ intercept[TestFailedException](0.0 ~== 0.1 relTol 0.032)
+ intercept[TestFailedException](0.0 ~= 0.1 relTol 0.032)
+ intercept[TestFailedException](0.0 !~== 0.1 relTol 0.032)
+ intercept[TestFailedException](0.0 !~= 0.1 relTol 0.032)
+
+ // Comparisons of numbers very close to zero.
+ assert(10 * Double.MinPositiveValue ~== 9.5 * Double.MinPositiveValue relTol 0.01)
+ assert(10 * Double.MinPositiveValue !~== 11 * Double.MinPositiveValue relTol 0.01)
+
+ assert(-Double.MinPositiveValue ~== 1.18 * -Double.MinPositiveValue relTol 0.012)
+ assert(-Double.MinPositiveValue ~== 1.38 * -Double.MinPositiveValue relTol 0.012)
+ }
+
+ test("Comparing doubles using absolute error.") {
+
+ assert(17.8 ~== 17.99 absTol 0.2)
+ assert(17.8 ~== 17.61 absTol 0.2)
+ assert(17.8 ~= 17.99 absTol 0.2)
+ assert(17.8 ~= 17.61 absTol 0.2)
+ assert(!(17.8 !~= 17.99 absTol 0.2))
+ assert(!(17.8 !~= 17.61 absTol 0.2))
+
+ // Should throw exception with message when test fails.
+ intercept[TestFailedException](17.8 !~== 17.99 absTol 0.2)
+ intercept[TestFailedException](17.8 !~== 17.61 absTol 0.2)
+ intercept[TestFailedException](17.8 ~== 18.01 absTol 0.2)
+ intercept[TestFailedException](17.8 ~== 17.59 absTol 0.2)
+
+ assert(17.8 !~== 18.01 absTol 0.2)
+ assert(17.8 !~== 17.59 absTol 0.2)
+ assert(17.8 !~= 18.01 absTol 0.2)
+ assert(17.8 !~= 17.59 absTol 0.2)
+ assert(!(17.8 ~= 18.01 absTol 0.2))
+ assert(!(17.8 ~= 17.59 absTol 0.2))
+
+ // Comparisons of numbers very close to zero, and on both sides of zero
+ assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
+ assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
+
+ assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
+ assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
+ }
+
+ test("Comparing vectors using relative error.") {
+
+ // Comparisons of two dense vectors
+ assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01)
+ assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01)
+ assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01)
+ assert(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01)
+ assert(!(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01))
+ assert(!(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01))
+
+ // Should throw exception with message when test fails.
+ intercept[TestFailedException](
+ Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01)
+
+ intercept[TestFailedException](
+ Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01)
+
+ // Comparing against zero should fail the test and throw exception with message
+ // saying that the relative error is meaningless in this situation.
+ intercept[TestFailedException](
+ Vectors.dense(Array(3.1, 0.01)) ~== Vectors.dense(Array(3.13, 0.0)) relTol 0.01)
+
+ intercept[TestFailedException](
+ Vectors.dense(Array(3.1, 0.01)) ~== Vectors.sparse(2, Array(0), Array(3.13)) relTol 0.01)
+
+ // Comparisons of two sparse vectors
+ assert(Vectors.dense(Array(3.1, 3.5)) ~==
+ Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01)
+
+ assert(Vectors.dense(Array(3.1, 3.5)) !~==
+ Vectors.sparse(2, Array(0, 1), Array(3.135, 3.534)) relTol 0.01)
+ }
+
+ test("Comparing vectors using absolute error.") {
+
+ // Comparisons of two dense vectors
+ assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~==
+ Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)
+
+ assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~==
+ Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)
+
+ assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~=
+ Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)
+
+ assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~=
+ Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)
+
+ assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) !~=
+ Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6))
+
+ assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) ~=
+ Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6))
+
+ // Should throw exception with message when test fails.
+ intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) !~==
+ Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)
+
+ intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) ~==
+ Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)
+
+ // Comparisons of two sparse vectors
+ assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~==
+ Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) absTol 1E-6)
+
+ assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) ~==
+ Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6)
+
+ assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~==
+ Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) absTol 1E-6)
+
+ assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) !~==
+ Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6)
+
+ // Comparisons of a dense vector and a sparse vector
+ assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~==
+ Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) absTol 1E-6)
+
+ assert(Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) ~==
+ Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6)
+
+ assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~==
+ Vectors.dense(Array(3.1, 1E-3, 2.4)) absTol 1E-6)
+ }
+}
diff --git a/pom.xml b/pom.xml
index d2e6b3c0ed5a4..93ef3b91b5bce 100644
--- a/pom.xml
+++ b/pom.xml
@@ -252,9 +252,9 @@
3.3.2
- commons-codec
- commons-codec
- 1.5
+ commons-codec
+ commons-codec
+ 1.5
com.google.code.findbugs
@@ -1139,5 +1139,15 @@
+
+ hive-thriftserver
+
+ false
+
+
+ sql/hive-thriftserver
+
+
+
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 62576f84dd031..1629bc2cba8ba 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -30,11 +30,11 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
- val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming,
- streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) =
- Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql",
- "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
- "streaming-zeromq").map(ProjectRef(buildLocation, _))
+ val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, sql,
+ streaming, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) =
+ Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
+ "spark", "sql", "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt",
+ "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _))
val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) =
Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl")
@@ -100,7 +100,7 @@ object SparkBuild extends PomBuild {
Properties.envOrNone("SBT_MAVEN_PROPERTIES") match {
case Some(v) =>
v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1)))
- case _ =>
+ case _ =>
}
override val userPropertiesMap = System.getProperties.toMap
@@ -158,7 +158,7 @@ object SparkBuild extends PomBuild {
/* Enable Mima for all projects except spark, hive, catalyst, sql and repl */
// TODO: Add Sql to mima checks
- allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)).
+ allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl).contains(x)).
foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x))
/* Enable Assembly for all assembly projects */
diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh
new file mode 100755
index 0000000000000..8398e6f19b511
--- /dev/null
+++ b/sbin/start-thriftserver.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#
+# Shell script for starting the Spark SQL Thrift server
+
+# Enter posix mode for bash
+set -o posix
+
+# Figure out where Spark is installed
+FWDIR="$(cd `dirname $0`/..; pwd)"
+
+if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
+ echo "Usage: ./sbin/start-thriftserver [options]"
+ $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
+ exit 0
+fi
+
+CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2"
+exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 6decde3fcd62d..531bfddbf237b 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -32,7 +32,7 @@
Spark Project Catalyst
http://spark.apache.org/
- catalyst
+ catalyst
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index 1d5f033f0d274..a357c6ffb8977 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -43,8 +43,7 @@ case class NativeCommand(cmd: String) extends Command {
*/
case class SetCommand(key: Option[String], value: Option[String]) extends Command {
override def output = Seq(
- BoundReference(0, AttributeReference("key", StringType, nullable = false)()),
- BoundReference(1, AttributeReference("value", StringType, nullable = false)()))
+ BoundReference(1, AttributeReference("", StringType, nullable = false)()))
}
/**
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index c309c43804d97..3a038a2db6173 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -32,7 +32,7 @@
Spark Project SQL
http://spark.apache.org/
- sql
+ sql
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 2b787e14f3f15..41920c00b5a2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -30,12 +30,13 @@ import scala.collection.JavaConverters._
* SQLConf is thread-safe (internally synchronized so safe to be used in multiple threads).
*/
trait SQLConf {
+ import SQLConf._
/** ************************ Spark SQL Params/Hints ******************* */
// TODO: refactor so that these hints accessors don't pollute the name space of SQLContext?
/** Number of partitions to use for shuffle operators. */
- private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt
+ private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt
/**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
@@ -43,11 +44,10 @@ trait SQLConf {
* effectively disables auto conversion.
* Hive setting: hive.auto.convert.join.noconditionaltask.size.
*/
- private[spark] def autoConvertJoinSize: Int =
- get("spark.sql.auto.convert.join.size", "10000").toInt
+ private[spark] def autoConvertJoinSize: Int = get(AUTO_CONVERT_JOIN_SIZE, "10000").toInt
/** A comma-separated list of table names marked to be broadcasted during joins. */
- private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "")
+ private[spark] def joinBroadcastTables: String = get(JOIN_BROADCAST_TABLES, "")
/** ********************** SQLConf functionality methods ************ */
@@ -61,7 +61,7 @@ trait SQLConf {
def set(key: String, value: String): Unit = {
require(key != null, "key cannot be null")
- require(value != null, s"value cannot be null for ${key}")
+ require(value != null, s"value cannot be null for $key")
settings.put(key, value)
}
@@ -90,3 +90,13 @@ trait SQLConf {
}
}
+
+object SQLConf {
+ val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size"
+ val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
+ val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables"
+
+ object Deprecated {
+ val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 98d2f89c8ae71..9293239131d52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.execution
+import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SQLConf, SQLContext}
trait Command {
/**
@@ -44,28 +45,53 @@ trait Command {
case class SetCommand(
key: Option[String], value: Option[String], output: Seq[Attribute])(
@transient context: SQLContext)
- extends LeafNode with Command {
+ extends LeafNode with Command with Logging {
- override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match {
+ override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match {
// Set value for key k.
case (Some(k), Some(v)) =>
- context.set(k, v)
- Array(k -> v)
+ if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
+ logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+ s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
+ context.set(SQLConf.SHUFFLE_PARTITIONS, v)
+ Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")
+ } else {
+ context.set(k, v)
+ Array(s"$k=$v")
+ }
// Query the value bound to key k.
case (Some(k), _) =>
- Array(k -> context.getOption(k).getOrElse(""))
+ // TODO (lian) This is just a workaround to make the Simba ODBC driver work.
+ // Should remove this once we get the ODBC driver updated.
+ if (k == "-v") {
+ val hiveJars = Seq(
+ "hive-exec-0.12.0.jar",
+ "hive-service-0.12.0.jar",
+ "hive-common-0.12.0.jar",
+ "hive-hwi-0.12.0.jar",
+ "hive-0.12.0.jar").mkString(":")
+
+ Array(
+ "system:java.class.path=" + hiveJars,
+ "system:sun.java.command=shark.SharkServer2")
+ }
+ else {
+ Array(s"$k=${context.getOption(k).getOrElse("")}")
+ }
// Query all key-value pairs that are set in the SQLConf of the context.
case (None, None) =>
- context.getAll
+ context.getAll.map { case (k, v) =>
+ s"$k=$v"
+ }
case _ =>
throw new IllegalArgumentException()
}
def execute(): RDD[Row] = {
- val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) }
+ val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) }
context.sparkContext.parallelize(rows, 1)
}
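
The reworked SetCommand above now emits single "key=value" strings and silently rewrites the deprecated mapred.reduce.tasks key to the Spark SQL shuffle-partitions key. A self-contained restatement of that normalization rule (sketch only; the constants mirror SQLConf but are redefined here so the snippet compiles on its own):

    object SetKeySketch extends App {
      val ShufflePartitions = "spark.sql.shuffle.partitions"
      val MapredReduceTasks = "mapred.reduce.tasks"

      // Deprecated Hive key is rewritten to the Spark SQL equivalent before being stored.
      def normalize(key: String, value: String): String =
        if (key == MapredReduceTasks) s"$ShufflePartitions=$value"
        else s"$key=$value"

      println(normalize("mapred.reduce.tasks", "10")) // spark.sql.shuffle.partitions=10
      println(normalize("some.property", "40"))       // some.property=40
    }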
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 08293f7f0ca30..1a58d73d9e7f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -54,10 +54,10 @@ class SQLConfSuite extends QueryTest {
assert(get(testKey, testVal + "_") == testVal)
assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
- sql("set mapred.reduce.tasks=20")
- assert(get("mapred.reduce.tasks", "0") == "20")
- sql("set mapred.reduce.tasks = 40")
- assert(get("mapred.reduce.tasks", "0") == "40")
+ sql("set some.property=20")
+ assert(get("some.property", "0") == "20")
+ sql("set some.property = 40")
+ assert(get("some.property", "0") == "40")
val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
@@ -70,4 +70,9 @@ class SQLConfSuite extends QueryTest {
clear()
}
+ test("deprecated property") {
+ clear()
+ sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
+ assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 6736189c96d4b..de9e8aa4f62ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -424,25 +424,25 @@ class SQLQuerySuite extends QueryTest {
sql(s"SET $testKey=$testVal")
checkAnswer(
sql("SET"),
- Seq(Seq(testKey, testVal))
+ Seq(Seq(s"$testKey=$testVal"))
)
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
checkAnswer(
sql("set"),
Seq(
- Seq(testKey, testVal),
- Seq(testKey + testKey, testVal + testVal))
+ Seq(s"$testKey=$testVal"),
+ Seq(s"${testKey + testKey}=${testVal + testVal}"))
)
// "set key"
checkAnswer(
sql(s"SET $testKey"),
- Seq(Seq(testKey, testVal))
+ Seq(Seq(s"$testKey=$testVal"))
)
checkAnswer(
sql(s"SET $nonexistentKey"),
- Seq(Seq(nonexistentKey, ""))
+ Seq(Seq(s"$nonexistentKey="))
)
clear()
}
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
new file mode 100644
index 0000000000000..7fac90fdc596d
--- /dev/null
+++ b/sql/hive-thriftserver/pom.xml
@@ -0,0 +1,82 @@
+
+
+
+
+ 4.0.0
+
+ org.apache.spark
+ spark-parent
+ 1.1.0-SNAPSHOT
+ ../../pom.xml
+
+
+ org.apache.spark
+ spark-hive-thriftserver_2.10
+ jar
+ Spark Project Hive
+ http://spark.apache.org/
+
+ hive-thriftserver
+
+
+
+
+ org.apache.spark
+ spark-hive_${scala.binary.version}
+ ${project.version}
+
+
+ org.spark-project.hive
+ hive-cli
+ ${hive.version}
+
+
+ org.spark-project.hive
+ hive-jdbc
+ ${hive.version}
+
+
+ org.spark-project.hive
+ hive-beeline
+ ${hive.version}
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ test
+
+
+
+ target/scala-${scala.binary.version}/classes
+ target/scala-${scala.binary.version}/test-classes
+
+
+ org.scalatest
+ scalatest-maven-plugin
+
+
+ org.apache.maven.plugins
+ maven-deploy-plugin
+
+ true
+
+
+
+
+
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
new file mode 100644
index 0000000000000..ddbc2a79fb512
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService
+import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor}
+
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+
+/**
+ * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLEnv` (a shared
+ * SparkContext and HiveContext) and a `HiveThriftServer2` Thrift server.
+ */
+private[hive] object HiveThriftServer2 extends Logging {
+ var LOG = LogFactory.getLog(classOf[HiveServer2])
+
+ def main(args: Array[String]) {
+ val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2")
+
+ if (!optionsProcessor.process(args)) {
+ logger.warn("Error starting HiveThriftServer2 with given arguments")
+ System.exit(-1)
+ }
+
+ val ss = new SessionState(new HiveConf(classOf[SessionState]))
+
+ // Set all properties specified via command line.
+ val hiveConf: HiveConf = ss.getConf
+ hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) =>
+ logger.debug(s"HiveConf var: $k=$v")
+ }
+
+ SessionState.start(ss)
+
+ logger.info("Starting SparkContext")
+ SparkSQLEnv.init()
+ SessionState.start(ss)
+
+ Runtime.getRuntime.addShutdownHook(
+ new Thread() {
+ override def run() {
+ SparkSQLEnv.sparkContext.stop()
+ }
+ }
+ )
+
+ try {
+ val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
+ server.init(hiveConf)
+ server.start()
+ logger.info("HiveThriftServer2 started")
+ } catch {
+ case e: Exception =>
+ logger.error("Error starting HiveThriftServer2", e)
+ System.exit(-1)
+ }
+ }
+}
+
+private[hive] class HiveThriftServer2(hiveContext: HiveContext)
+ extends HiveServer2
+ with ReflectedCompositeService {
+
+ override def init(hiveConf: HiveConf) {
+ val sparkSqlCliService = new SparkSQLCLIService(hiveContext)
+ setSuperField(this, "cliService", sparkSqlCliService)
+ addService(sparkSqlCliService)
+
+ val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService)
+ setSuperField(this, "thriftCLIService", thriftCliService)
+ addService(thriftCliService)
+
+ initCompositeService(hiveConf)
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
new file mode 100644
index 0000000000000..599294dfbb7d7
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+private[hive] object ReflectionUtils {
+ def setSuperField(obj : Object, fieldName: String, fieldValue: Object) {
+ setAncestorField(obj, 1, fieldName, fieldValue)
+ }
+
+ def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) {
+ val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next()
+ val field = ancestor.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ field.set(obj, fieldValue)
+ }
+
+ def getSuperField[T](obj: AnyRef, fieldName: String): T = {
+ getAncestorField[T](obj, 1, fieldName)
+ }
+
+ def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = {
+ val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next()
+ val field = ancestor.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ field.get(clazz).asInstanceOf[T]
+ }
+
+ def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = {
+ invoke(clazz, null, methodName, args: _*)
+ }
+
+ def invoke(
+ clazz: Class[_],
+ obj: AnyRef,
+ methodName: String,
+ args: (Class[_], AnyRef)*): AnyRef = {
+
+ val (types, values) = args.unzip
+ val method = clazz.getDeclaredMethod(methodName, types: _*)
+ method.setAccessible(true)
+ method.invoke(obj, values.toSeq: _*)
+ }
+}
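
The helpers above all follow the same pattern: walk a fixed number of superclasses up from the runtime class, then read, write, or invoke a private member reflectively. A toy illustration with made-up class names (not part of the patch):

    object ReflectionUtilsSketch extends App {
      class Base { private val secret: String = "hidden" }
      class Child extends Base

      // Same shape as getAncestorField above: `level` superclasses up, then a private field read.
      def getAncestorField[T](obj: AnyRef, level: Int, fieldName: String): T = {
        val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next()
        val field = ancestor.getDeclaredField(fieldName)
        field.setAccessible(true)
        field.get(obj).asInstanceOf[T]
      }

      // Child -> one superclass up -> Base, whose private `secret` becomes readable.
      println(getAncestorField[String](new Child, 1, "secret")) // prints "hidden"
    }

The thrift server code relies on this to graft Spark-backed services into Hive's CompositeService hierarchy without forking those classes.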
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
new file mode 100755
index 0000000000000..27268ecb923e9
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import java.io._
+import java.util.{ArrayList => JArrayList}
+
+import jline.{ConsoleReader, History}
+import org.apache.commons.lang.StringUtils
+import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor}
+import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
+import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils}
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.exec.Utilities
+import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.shims.ShimLoader
+import org.apache.thrift.transport.TSocket
+
+import org.apache.spark.sql.Logging
+
+private[hive] object SparkSQLCLIDriver {
+ private var prompt = "spark-sql"
+ private var continuedPrompt = "".padTo(prompt.length, ' ')
+ private var transport: TSocket = _
+
+ installSignalHandler()
+
+ /**
+ * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(),
+ * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while
+ * a command is being processed by the current thread.
+ */
+ def installSignalHandler() {
+ HiveInterruptUtils.add(new HiveInterruptCallback {
+ override def interrupt() {
+ // Handle remote execution mode
+ if (SparkSQLEnv.sparkContext != null) {
+ SparkSQLEnv.sparkContext.cancelAllJobs()
+ } else {
+ if (transport != null) {
+ // Force closing of TCP connection upon session termination
+ transport.getSocket.close()
+ }
+ }
+ }
+ })
+ }
+
+ def main(args: Array[String]) {
+ val oproc = new OptionsProcessor()
+ if (!oproc.process_stage1(args)) {
+ System.exit(1)
+ }
+
+ // NOTE: It is critical to do this here so that log4j is reinitialized
+ // before any of the other core hive classes are loaded
+ var logInitFailed = false
+ var logInitDetailMessage: String = null
+ try {
+ logInitDetailMessage = LogUtils.initHiveLog4j()
+ } catch {
+ case e: LogInitializationException =>
+ logInitFailed = true
+ logInitDetailMessage = e.getMessage
+ }
+
+ val sessionState = new CliSessionState(new HiveConf(classOf[SessionState]))
+
+ sessionState.in = System.in
+ try {
+ sessionState.out = new PrintStream(System.out, true, "UTF-8")
+ sessionState.info = new PrintStream(System.err, true, "UTF-8")
+ sessionState.err = new PrintStream(System.err, true, "UTF-8")
+ } catch {
+ case e: UnsupportedEncodingException => System.exit(3)
+ }
+
+ if (!oproc.process_stage2(sessionState)) {
+ System.exit(2)
+ }
+
+ if (!sessionState.getIsSilent) {
+ if (logInitFailed) System.err.println(logInitDetailMessage)
+ else SessionState.getConsole.printInfo(logInitDetailMessage)
+ }
+
+ // Set all properties specified via command line.
+ val conf: HiveConf = sessionState.getConf
+ sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] =>
+ conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String])
+ sessionState.getOverriddenConfigurations.put(
+ item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String])
+ }
+
+ SessionState.start(sessionState)
+
+ // Clean up after we exit
+ Runtime.getRuntime.addShutdownHook(
+ new Thread() {
+ override def run() {
+ SparkSQLEnv.stop()
+ }
+ }
+ )
+
+ // "-h" option has been passed, so connect to Hive thrift server.
+ if (sessionState.getHost != null) {
+ sessionState.connect()
+ if (sessionState.isRemoteMode) {
+ prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt
+ continuedPrompt = "".padTo(prompt.length, ' ')
+ }
+ }
+
+ if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) {
+ // Hadoop-20 and above - we need to augment classpath using hiveconf
+ // components.
+ // See also: code in ExecDriver.java
+ var loader = conf.getClassLoader
+ val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
+ if (StringUtils.isNotBlank(auxJars)) {
+ loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ","))
+ }
+ conf.setClassLoader(loader)
+ Thread.currentThread().setContextClassLoader(loader)
+ }
+
+ val cli = new SparkSQLCLIDriver
+ cli.setHiveVariables(oproc.getHiveVariables)
+
+ // TODO: workaround to set the log output to the console, because the HiveContext
+ // will set the output into an invalid buffer.
+ sessionState.in = System.in
+ try {
+ sessionState.out = new PrintStream(System.out, true, "UTF-8")
+ sessionState.info = new PrintStream(System.err, true, "UTF-8")
+ sessionState.err = new PrintStream(System.err, true, "UTF-8")
+ } catch {
+ case e: UnsupportedEncodingException => System.exit(3)
+ }
+
+ // Execute -i init files (always in silent mode)
+ cli.processInitFiles(sessionState)
+
+ if (sessionState.execString != null) {
+ System.exit(cli.processLine(sessionState.execString))
+ }
+
+ try {
+ if (sessionState.fileName != null) {
+ System.exit(cli.processFile(sessionState.fileName))
+ }
+ } catch {
+ case e: FileNotFoundException =>
+ System.err.println(s"Could not open input file for reading. (${e.getMessage})")
+ System.exit(3)
+ }
+
+ val reader = new ConsoleReader()
+ reader.setBellEnabled(false)
+ // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true)))
+ CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e))
+
+ val historyDirectory = System.getProperty("user.home")
+
+ try {
+ if (new File(historyDirectory).exists()) {
+ val historyFile = historyDirectory + File.separator + ".hivehistory"
+ reader.setHistory(new History(new File(historyFile)))
+ } else {
+ System.err.println("WARNING: Directory for Hive history file: " + historyDirectory +
+ " does not exist. History will not be available during this session.")
+ }
+ } catch {
+ case e: Exception =>
+ System.err.println("WARNING: Encountered an error while trying to initialize Hive's " +
+ "history file. History will not be available during this session.")
+ System.err.println(e.getMessage)
+ }
+
+ val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport")
+ clientTransportTSocketField.setAccessible(true)
+
+ transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket]
+
+ var ret = 0
+ var prefix = ""
+ val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb",
+ classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState)
+
+ def promptWithCurrentDB = s"$prompt$currentDB"
+ def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic(
+ classOf[CliDriver], "spacesForString", classOf[String] -> currentDB)
+
+ var currentPrompt = promptWithCurrentDB
+ var line = reader.readLine(currentPrompt + "> ")
+
+ while (line != null) {
+ if (prefix.nonEmpty) {
+ prefix += '\n'
+ }
+
+ if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) {
+ line = prefix + line
+ ret = cli.processLine(line, true)
+ prefix = ""
+ currentPrompt = promptWithCurrentDB
+ } else {
+ prefix = prefix + line
+ currentPrompt = continuedPromptWithDBSpaces
+ }
+
+ line = reader.readLine(currentPrompt + "> ")
+ }
+
+ sessionState.close()
+
+ System.exit(ret)
+ }
+}
+
+private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
+ private val sessionState = SessionState.get().asInstanceOf[CliSessionState]
+
+ private val LOG = LogFactory.getLog("CliDriver")
+
+ private val console = new SessionState.LogHelper(LOG)
+
+ private val conf: Configuration =
+ if (sessionState != null) sessionState.getConf else new Configuration()
+
+ // Force initializing SparkSQLEnv. This is put here rather than in object SparkSQLCLIDriver
+ // because the Hive unit tests do not go through the main() code path.
+ if (!sessionState.isRemoteMode) {
+ SparkSQLEnv.init()
+ }
+
+ override def processCmd(cmd: String): Int = {
+ val cmd_trimmed: String = cmd.trim()
+ val tokens: Array[String] = cmd_trimmed.split("\\s+")
+ val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
+ if (cmd_trimmed.toLowerCase.equals("quit") ||
+ cmd_trimmed.toLowerCase.equals("exit") ||
+ tokens(0).equalsIgnoreCase("source") ||
+ cmd_trimmed.startsWith("!") ||
+ tokens(0).toLowerCase.equals("list") ||
+ sessionState.isRemoteMode) {
+ val start = System.currentTimeMillis()
+ super.processCmd(cmd)
+ val end = System.currentTimeMillis()
+ val timeTaken: Double = (end - start) / 1000.0
+ console.printInfo(s"Time taken: $timeTaken seconds")
+ 0
+ } else {
+ var ret = 0
+ val hconf = conf.asInstanceOf[HiveConf]
+ val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf)
+
+ if (proc != null) {
+ if (proc.isInstanceOf[Driver]) {
+ val driver = new SparkSQLDriver
+
+ driver.init()
+ val out = sessionState.out
+ val start: Long = System.currentTimeMillis()
+ if (sessionState.getIsVerbose) {
+ out.println(cmd)
+ }
+
+ ret = driver.run(cmd).getResponseCode
+ if (ret != 0) {
+ driver.close()
+ return ret
+ }
+
+ val res = new JArrayList[String]()
+
+ if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) {
+ // Print the column names.
+ Option(driver.getSchema.getFieldSchemas).map { fields =>
+ out.println(fields.map(_.getName).mkString("\t"))
+ }
+ }
+
+ try {
+ while (!out.checkError() && driver.getResults(res)) {
+ res.foreach(out.println)
+ res.clear()
+ }
+ } catch {
+ case e: IOException =>
+ console.printError(
+ s"""Failed with exception ${e.getClass.getName}: ${e.getMessage}
+ |${org.apache.hadoop.util.StringUtils.stringifyException(e)}
+ """.stripMargin)
+ ret = 1
+ }
+
+ val cret = driver.close()
+ if (ret == 0) {
+ ret = cret
+ }
+
+ val end = System.currentTimeMillis()
+ if (end > start) {
+ val timeTaken: Double = (end - start) / 1000.0
+ console.printInfo(s"Time taken: $timeTaken seconds", null)
+ }
+
+ // Destroy the driver to release all the locks.
+ driver.destroy()
+ } else {
+ if (sessionState.getIsVerbose) {
+ sessionState.out.println(tokens(0) + " " + cmd_1)
+ }
+ ret = proc.run(cmd_1).getResponseCode
+ }
+ }
+ ret
+ }
+ }
+}
+
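
The processCmd override above splits work between Hive's own CliDriver and the new SparkSQLDriver. A sketch (not the patch code) restating that dispatch rule:

    object CliDispatchSketch extends App {
      // Shell-style commands and remote-mode sessions fall through to Hive's CliDriver;
      // everything else is executed by SparkSQLDriver.
      def handledByHiveCli(cmd: String, remoteMode: Boolean): Boolean = {
        val trimmed = cmd.trim
        val first = trimmed.split("\\s+")(0)
        remoteMode ||
          trimmed.equalsIgnoreCase("quit") ||
          trimmed.equalsIgnoreCase("exit") ||
          first.equalsIgnoreCase("source") ||
          first.equalsIgnoreCase("list") ||
          trimmed.startsWith("!")
      }

      println(handledByHiveCli("exit", remoteMode = false))     // true: handled by Hive's CliDriver
      println(handledByHiveCli("SELECT 1", remoteMode = false)) // false: runs via SparkSQLDriver
    }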
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
new file mode 100644
index 0000000000000..42cbf363b274f
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import java.io.IOException
+import java.util.{List => JList}
+import javax.security.auth.login.LoginException
+
+import org.apache.commons.logging.Log
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.shims.ShimLoader
+import org.apache.hive.service.Service.STATE
+import org.apache.hive.service.auth.HiveAuthFactory
+import org.apache.hive.service.cli.CLIService
+import org.apache.hive.service.{AbstractService, Service, ServiceException}
+
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+
+private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
+ extends CLIService
+ with ReflectedCompositeService {
+
+ override def init(hiveConf: HiveConf) {
+ setSuperField(this, "hiveConf", hiveConf)
+
+ val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext)
+ setSuperField(this, "sessionManager", sparkSqlSessionManager)
+ addService(sparkSqlSessionManager)
+
+ try {
+ HiveAuthFactory.loginFromKeytab(hiveConf)
+ val serverUserName = ShimLoader.getHadoopShims
+ .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf))
+ setSuperField(this, "serverUserName", serverUserName)
+ } catch {
+ case e @ (_: IOException | _: LoginException) =>
+ throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
+ }
+
+ initCompositeService(hiveConf)
+ }
+}
+
+private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
+ def initCompositeService(hiveConf: HiveConf) {
+ // Emulating `CompositeService.init(hiveConf)`
+ val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList")
+ serviceList.foreach(_.init(hiveConf))
+
+ // Emulating `AbstractService.init(hiveConf)`
+ invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED)
+ setAncestorField(this, 3, "hiveConf", hiveConf)
+ invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED)
+ getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.")
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
new file mode 100644
index 0000000000000..5202aa9903e03
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.commons.lang.exception.ExceptionUtils
+import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse
+
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
+
+private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext)
+ extends Driver with Logging {
+
+ private var tableSchema: Schema = _
+ private var hiveResponse: Seq[String] = _
+
+ override def init(): Unit = {
+ }
+
+ private def getResultSetSchema(query: context.QueryExecution): Schema = {
+ val analyzed = query.analyzed
+ logger.debug(s"Result Schema: ${analyzed.output}")
+ if (analyzed.output.size == 0) {
+ new Schema(new FieldSchema("Response code", "string", "") :: Nil, null)
+ } else {
+ val fieldSchemas = analyzed.output.map { attr =>
+ new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
+ }
+
+ new Schema(fieldSchemas, null)
+ }
+ }
+
+ override def run(command: String): CommandProcessorResponse = {
+ val execution = context.executePlan(context.hql(command).logicalPlan)
+
+ // TODO unify the error code
+ try {
+ hiveResponse = execution.stringResult()
+ tableSchema = getResultSetSchema(execution)
+ new CommandProcessorResponse(0)
+ } catch {
+ case cause: Throwable =>
+ logger.error(s"Failed in [$command]", cause)
+ new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null)
+ }
+ }
+
+ override def close(): Int = {
+ hiveResponse = null
+ tableSchema = null
+ 0
+ }
+
+ override def getSchema: Schema = tableSchema
+
+ override def getResults(res: JArrayList[String]): Boolean = {
+ if (hiveResponse == null) {
+ false
+ } else {
+ res.addAll(hiveResponse)
+ hiveResponse = null
+ true
+ }
+ }
+
+ override def destroy() {
+ super.destroy()
+ hiveResponse = null
+ tableSchema = null
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
new file mode 100644
index 0000000000000..451c3bd7b9352
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import org.apache.hadoop.hive.ql.session.SessionState
+
+import org.apache.spark.scheduler.{SplitInfo, StatsReportListener}
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.{SparkConf, SparkContext}
+
+/** A singleton object for the master program. The slaves should not access this. */
+private[hive] object SparkSQLEnv extends Logging {
+ logger.debug("Initializing SparkSQLEnv")
+
+ var hiveContext: HiveContext = _
+ var sparkContext: SparkContext = _
+
+ def init() {
+ if (hiveContext == null) {
+ sparkContext = new SparkContext(new SparkConf()
+ .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}"))
+
+ sparkContext.addSparkListener(new StatsReportListener())
+
+ hiveContext = new HiveContext(sparkContext) {
+ @transient override lazy val sessionState = SessionState.get()
+ @transient override lazy val hiveconf = sessionState.getConf
+ }
+ }
+ }
+
+ /** Cleans up and shuts down the Spark SQL environments. */
+ def stop() {
+ logger.debug("Shutting down Spark SQL Environment")
+ // Stop the SparkContext
+ if (SparkSQLEnv.sparkContext != null) {
+ sparkContext.stop()
+ sparkContext = null
+ hiveContext = null
+ }
+ }
+}
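
A hypothetical driver-side use of the singleton above, sketched under two assumptions: the snippet lives in the same org.apache.spark.sql.hive.thriftserver package (SparkSQLEnv is private[hive]) and a working Hive configuration is on the classpath. It is an illustration, not code from the patch:

    package org.apache.spark.sql.hive.thriftserver

    object SparkSQLEnvUsageSketch {
      def main(args: Array[String]): Unit = {
        SparkSQLEnv.init() // builds the shared SparkContext and HiveContext once
        try {
          SparkSQLEnv.hiveContext.hql("SHOW TABLES").collect().foreach(println)
        } finally {
          SparkSQLEnv.stop() // stops the SparkContext and clears both fields
        }
      }
    }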
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
new file mode 100644
index 0000000000000..6b3275b4eaf04
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.concurrent.Executors
+
+import org.apache.commons.logging.Log
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hive.service.cli.session.SessionManager
+
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager
+
+private[hive] class SparkSQLSessionManager(hiveContext: HiveContext)
+ extends SessionManager
+ with ReflectedCompositeService {
+
+ override def init(hiveConf: HiveConf) {
+ setSuperField(this, "hiveConf", hiveConf)
+
+ val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS)
+ setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize))
+ getAncestorField[Log](this, 3, "LOG").info(
+ s"HiveServer2: Async execution pool size $backgroundPoolSize")
+
+ val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext)
+ setSuperField(this, "operationManager", sparkSqlOperationManager)
+ addService(sparkSqlOperationManager)
+
+ initCompositeService(hiveConf)
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
new file mode 100644
index 0000000000000..a4e1f3e762e89
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver.server
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+import scala.math.{random, round}
+
+import java.sql.Timestamp
+import java.util.{Map => JMap}
+
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.metastore.api.FieldSchema
+import org.apache.hive.service.cli._
+import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
+import org.apache.hive.service.cli.session.HiveSession
+
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils
+import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
+import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow}
+
+/**
+ * Executes queries using Spark SQL, and maintains a list of handles to active queries.
+ */
+class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging {
+ val handleToOperation = ReflectionUtils
+ .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")
+
+ override def newExecuteStatementOperation(
+ parentSession: HiveSession,
+ statement: String,
+ confOverlay: JMap[String, String],
+ async: Boolean): ExecuteStatementOperation = synchronized {
+
+ val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) {
+ private var result: SchemaRDD = _
+ private var iter: Iterator[SparkRow] = _
+ private var dataTypes: Array[DataType] = _
+
+ def close(): Unit = {
+ // RDDs will be cleaned automatically upon garbage collection.
+ logger.debug("CLOSING")
+ }
+
+ def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = {
+ if (!iter.hasNext) {
+ new RowSet()
+ } else {
+ val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No.
+ var curRow = 0
+ var rowSet = new ArrayBuffer[Row](maxRows)
+
+ while (curRow < maxRows && iter.hasNext) {
+ val sparkRow = iter.next()
+ val row = new Row()
+ var curCol = 0
+
+ while (curCol < sparkRow.length) {
+ dataTypes(curCol) match {
+ case StringType =>
+ row.addString(sparkRow(curCol).asInstanceOf[String])
+ case IntegerType =>
+ row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol)))
+ case BooleanType =>
+ row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol)))
+ case DoubleType =>
+ row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol)))
+ case FloatType =>
+ row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol)))
+ case DecimalType =>
+ val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal
+ row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
+ case LongType =>
+ row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol)))
+ case ByteType =>
+ row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol)))
+ case ShortType =>
+ row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol)))
+ case TimestampType =>
+ row.addColumnValue(
+ ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp]))
+ case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+ val hiveString = result
+ .queryExecution
+ .asInstanceOf[HiveContext#QueryExecution]
+ .toHiveString((sparkRow.get(curCol), dataTypes(curCol)))
+ row.addColumnValue(ColumnValue.stringValue(hiveString))
+ }
+ curCol += 1
+ }
+ rowSet += row
+ curRow += 1
+ }
+ new RowSet(rowSet, 0)
+ }
+ }
+
+ def getResultSetSchema: TableSchema = {
+ logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}")
+ if (result.queryExecution.analyzed.output.size == 0) {
+ new TableSchema(new FieldSchema("Result", "string", "") :: Nil)
+ } else {
+ val schema = result.queryExecution.analyzed.output.map { attr =>
+ new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
+ }
+ new TableSchema(schema)
+ }
+ }
+
+ def run(): Unit = {
+ logger.info(s"Running query '$statement'")
+ setState(OperationState.RUNNING)
+ try {
+ result = hiveContext.hql(statement)
+ logger.debug(result.queryExecution.toString())
+ val groupId = round(random * 1000000).toString
+ hiveContext.sparkContext.setJobGroup(groupId, statement)
+ iter = result.queryExecution.toRdd.toLocalIterator
+ dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray
+ setHasResultSet(true)
+ } catch {
+ // We actually do need to catch Throwable here, since some failures don't inherit from
+ // Exception and HiveServer would silently swallow them.
+ case e: Throwable =>
+ logger.error("Error executing query:", e)
+ throw new HiveSQLException(e.toString)
+ }
+ setState(OperationState.FINISHED)
+ }
+ }
+
+ handleToOperation.put(operation.getHandle, operation)
+ operation
+ }
+}
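
The operation manager above is what ultimately services statements submitted over JDBC. A minimal client-side sketch, assuming the Thrift server is already listening on localhost:10000 (the default HiveServer2 port) and that a table named `some_table` exists; both are assumptions for illustration, not part of this patch:

    import java.sql.DriverManager

    object ThriftClientSketch {
      def main(args: Array[String]): Unit = {
        // The Hive JDBC driver is the same one used by bin/beeline.
        Class.forName("org.apache.hive.jdbc.HiveDriver")
        val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/", "user", "")
        val stmt = conn.createStatement()
        // Each statement executed here reaches newExecuteStatementOperation on the server,
        // where it is run through HiveContext.hql and streamed back as RowSets.
        val rs = stmt.executeQuery("SELECT count(*) FROM some_table")  // hypothetical table
        while (rs.next()) println(rs.getLong(1))
        rs.close(); stmt.close(); conn.close()
      }
    }
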
diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt
new file mode 100644
index 0000000000000..850f8014b6f05
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt
@@ -0,0 +1,5 @@
+238val_238
+86val_86
+311val_311
+27val_27
+165val_165
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
new file mode 100644
index 0000000000000..69f19f826a802
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.io.{BufferedReader, InputStreamReader, PrintWriter}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils {
+ val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli")
+ val METASTORE_PATH = TestUtils.getMetastorePath("cli")
+
+ override def beforeAll() {
+ val pb = new ProcessBuilder(
+ "../../bin/spark-sql",
+ "--master",
+ "local",
+ "--hiveconf",
+ s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true",
+ "--hiveconf",
+ "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH)
+
+ process = pb.start()
+ outputWriter = new PrintWriter(process.getOutputStream, true)
+ inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
+ errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
+ waitForOutput(inputReader, "spark-sql>")
+ }
+
+ override def afterAll() {
+ process.destroy()
+ process.waitFor()
+ }
+
+ test("simple commands") {
+ val dataFilePath = getDataFile("data/files/small_kv.txt")
+ executeQuery("create table hive_test1(key int, val string);")
+ executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;")
+ executeQuery("cache table hive_test1", "Time taken")
+ }
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
new file mode 100644
index 0000000000000..fe3403b3292ec
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent._
+
+import java.io.{BufferedReader, InputStreamReader}
+import java.net.ServerSocket
+import java.sql.{Connection, DriverManager, Statement}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+
+/**
+ * Test for the HiveThriftServer2 using JDBC.
+ */
+class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging {
+
+ val WAREHOUSE_PATH = getTempFilePath("warehouse")
+ val METASTORE_PATH = getTempFilePath("metastore")
+
+ val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver"
+ val TABLE = "test"
+ val HOST = "localhost"
+ val PORT = {
+ // Let the system choose a random available port to avoid collisions with other parallel
+ // builds.
+ val socket = new ServerSocket(0)
+ val port = socket.getLocalPort
+ socket.close()
+ port
+ }
+
+ // If verbose is true, the test program will print all output coming from the Hive Thrift server.
+ val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean
+
+ Class.forName(DRIVER_NAME)
+
+ override def beforeAll() { launchServer() }
+
+ override def afterAll() { stopServer() }
+
+ private def launchServer(args: Seq[String] = Seq.empty) {
+ // Fork a new process to start the Hive Thrift server. We do this because it is hard to
+ // clean up Hive resources entirely, so we simply start a separate process and kill it
+ // during cleanup.
+ val defaultArgs = Seq(
+ "../../sbin/start-thriftserver.sh",
+ "--master local",
+ "--hiveconf",
+ "hive.root.logger=INFO,console",
+ "--hiveconf",
+ s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true",
+ "--hiveconf",
+ s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH")
+ val pb = new ProcessBuilder(defaultArgs ++ args)
+ val environment = pb.environment()
+ environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString)
+ environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST)
+ process = pb.start()
+ inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
+ errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
+ waitForOutput(inputReader, "ThriftBinaryCLIService listening on")
+
+ // Spawn a thread to read the output from the forked process. This is necessary because
+ // in some configurations log4j blocks if its output to stderr is not consumed, which
+ // would eventually hang the entire test suite.
+ future {
+ while (true) {
+ val stdout = readFrom(inputReader)
+ val stderr = readFrom(errorReader)
+ if (VERBOSE && stdout.length > 0) {
+ println(stdout)
+ }
+ if (VERBOSE && stderr.length > 0) {
+ println(stderr)
+ }
+ Thread.sleep(50)
+ }
+ }
+ }
+
+ private def stopServer() {
+ process.destroy()
+ process.waitFor()
+ }
+
+ test("test query execution against a Hive Thrift server") {
+ Thread.sleep(5 * 1000)
+ val dataFilePath = getDataFile("data/files/small_kv.txt")
+ val stmt = createStatement()
+ stmt.execute("DROP TABLE IF EXISTS test")
+ stmt.execute("DROP TABLE IF EXISTS test_cached")
+ stmt.execute("CREATE TABLE test(key int, val string)")
+ stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
+ stmt.execute("CREATE TABLE test_cached as select * from test limit 4")
+ stmt.execute("CACHE TABLE test_cached")
+
+ var rs = stmt.executeQuery("select count(*) from test")
+ rs.next()
+ assert(rs.getInt(1) === 5)
+
+ rs = stmt.executeQuery("select count(*) from test_cached")
+ rs.next()
+ assert(rs.getInt(1) === 4)
+
+ stmt.close()
+ }
+
+ def getConnection: Connection = {
+ val connectURI = s"jdbc:hive2://localhost:$PORT/"
+ DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")
+ }
+
+ def createStatement(): Statement = getConnection.createStatement()
+}
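
The background `future` that drains the forked server's output is essential: if the stdout/stderr pipe buffers fill up, log4j can block and stall the whole server. A stand-alone sketch of that drain pattern, assuming only a running `java.lang.Process`; the function name and verbose flag are illustrative:

    import java.io.{BufferedReader, InputStreamReader}

    // Keep consuming a child process's output so it never blocks on a full pipe buffer.
    def drainInBackground(process: Process, verbose: Boolean = false): Thread = {
      val out = new BufferedReader(new InputStreamReader(process.getInputStream))
      val err = new BufferedReader(new InputStreamReader(process.getErrorStream))
      def pump(reader: BufferedReader): Unit = {
        val sb = new StringBuilder
        while (reader.ready()) sb.append(reader.read().toChar)
        if (verbose && sb.nonEmpty) println(sb.toString())
      }
      val t = new Thread(new Runnable {
        def run(): Unit = while (true) { pump(out); pump(err); Thread.sleep(50) }
      })
      t.setDaemon(true)  // don't keep the JVM alive after the suite finishes
      t.start()
      t
    }
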
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
new file mode 100644
index 0000000000000..bb2242618fbef
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.io.{BufferedReader, PrintWriter}
+import java.text.SimpleDateFormat
+import java.util.Date
+
+import org.apache.hadoop.hive.common.LogUtils
+import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
+
+object TestUtils {
+ val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss")
+
+ def getWarehousePath(prefix: String): String = {
+ System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" +
+ timestamp.format(new Date)
+ }
+
+ def getMetastorePath(prefix: String): String = {
+ System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" +
+ timestamp.format(new Date)
+ }
+
+ // Dummy function to trigger initialization of this object (and thus log4j).
+ def init() { }
+
+ // initialize log4j
+ try {
+ LogUtils.initHiveLog4j()
+ } catch {
+ case e: LogInitializationException => // Ignore the error.
+ }
+}
+
+trait TestUtils {
+ var process: Process = null
+ var outputWriter: PrintWriter = null
+ var inputReader: BufferedReader = null
+ var errorReader: BufferedReader = null
+
+ def executeQuery(
+ cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = {
+ println("Executing: " + cmd + ", expecting output: " + outputMessage)
+ outputWriter.write(cmd + "\n")
+ outputWriter.flush()
+ waitForQuery(timeout, outputMessage)
+ }
+
+ protected def waitForQuery(timeout: Long, message: String): String = {
+ if (waitForOutput(errorReader, message, timeout)) {
+ Thread.sleep(500)
+ readOutput()
+ } else {
+ assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput())
+ null
+ }
+ }
+
+ // Wait for the specified str to appear in the output.
+ protected def waitForOutput(
+ reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = {
+ val startTime = System.currentTimeMillis
+ var out = ""
+ while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) {
+ out += readFrom(reader)
+ }
+ out.contains(str)
+ }
+
+ // Read stdout output and filter out garbage collection messages.
+ protected def readOutput(): String = {
+ val output = readFrom(inputReader)
+ // Remove GC Messages
+ val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC"))
+ .mkString("\n")
+ filteredOutput
+ }
+
+ protected def readFrom(reader: BufferedReader): String = {
+ var out = ""
+ var c = 0
+ while (reader.ready) {
+ c = reader.read()
+ out += c.asInstanceOf[Char]
+ }
+ out
+ }
+
+ protected def getDataFile(name: String) = {
+ Thread.currentThread().getContextClassLoader.getResource(name)
+ }
+}
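
The trait above can drive any console-style child process, not just spark-sql. A hypothetical suite only needs to wire up the four vars and a prompt to wait for; the command, prompt, and statement below are placeholders, not part of this patch:

    import java.io.{BufferedReader, InputStreamReader, PrintWriter}

    import org.scalatest.{BeforeAndAfterAll, FunSuite}

    class InteractiveToolSuite extends FunSuite with BeforeAndAfterAll with TestUtils {
      override def beforeAll() {
        process = new ProcessBuilder("some-interactive-tool").start()  // placeholder command
        outputWriter = new PrintWriter(process.getOutputStream, true)
        inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
        errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
        waitForOutput(inputReader, "tool>")  // block until the prompt appears on stdout
      }

      override def afterAll() {
        process.destroy()
        process.waitFor()
      }

      test("round trip") {
        // Waits for "OK" on stderr (the default), then returns the buffered stdout.
        executeQuery("help;")
      }
    }
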
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 1699ffe06ce15..93d00f7c37c9b 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -32,7 +32,7 @@
Spark Project Hive
http://spark.apache.org/
- hive
+ hive
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 201c85f3d501e..84d43eaeea51d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -255,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
ShortType, DecimalType, TimestampType, BinaryType)
- protected def toHiveString(a: (Any, DataType)): String = a match {
+ protected[sql] def toHiveString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
struct.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index e6ab68b563f8d..d18ccf8167487 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -610,7 +610,7 @@ private[hive] object HiveQl {
// TOK_DESTINATION means to overwrite the table.
val resultDestination =
(intoClause orElse destClause).getOrElse(sys.error("No destination found."))
- val overwrite = if (intoClause.isEmpty) true else false
+ val overwrite = intoClause.isEmpty
nodeToDest(
resultDestination,
withLimit,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index c3942578d6b5a..82c88280d7754 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -24,6 +24,8 @@ import org.apache.hadoop.hive.ql.exec.Utilities
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.serde2.Deserializer
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
+
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}
@@ -31,13 +33,16 @@ import org.apache.spark.SerializableWritable
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast}
+import org.apache.spark.sql.catalyst.types.DataType
+
/**
* A trait for subclasses that handle table scans.
*/
private[hive] sealed trait TableReader {
- def makeRDDForTable(hiveTable: HiveTable): RDD[_]
+ def makeRDDForTable(hiveTable: HiveTable): RDD[Row]
- def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_]
+ def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row]
}
@@ -46,7 +51,10 @@ private[hive] sealed trait TableReader {
* data warehouse directory.
*/
private[hive]
-class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext)
+class HadoopTableReader(
+ @transient attributes: Seq[Attribute],
+ @transient relation: MetastoreRelation,
+ @transient sc: HiveContext)
extends TableReader {
// Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless
@@ -63,10 +71,10 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
def hiveConf = _broadcastedHiveConf.value.value
- override def makeRDDForTable(hiveTable: HiveTable): RDD[_] =
+ override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] =
makeRDDForTable(
hiveTable,
- _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]],
+ relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]],
filterOpt = None)
/**
@@ -81,14 +89,14 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
def makeRDDForTable(
hiveTable: HiveTable,
deserializerClass: Class[_ <: Deserializer],
- filterOpt: Option[PathFilter]): RDD[_] = {
+ filterOpt: Option[PathFilter]): RDD[Row] = {
assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table,
since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""")
// Create local references to member variables, so that the entire `this` object won't be
// serialized in the closure below.
- val tableDesc = _tableDesc
+ val tableDesc = relation.tableDesc
val broadcastedHiveConf = _broadcastedHiveConf
val tablePath = hiveTable.getPath
@@ -99,23 +107,20 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
.asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
+ val attrsWithIndex = attributes.zipWithIndex
+ val mutableRow = new GenericMutableRow(attrsWithIndex.length)
val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val deserializer = deserializerClass.newInstance()
deserializer.initialize(hconf, tableDesc.getProperties)
- // Deserialize each Writable to get the row value.
- iter.map {
- case v: Writable => deserializer.deserialize(v)
- case value =>
- sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}")
- }
+ HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow)
}
deserializedHadoopRDD
}
- override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = {
+ override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = {
val partitionToDeserializer = partitions.map(part =>
(part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap
makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None)
@@ -132,9 +137,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
* subdirectory of each partition being read. If None, then all files are accepted.
*/
def makeRDDForPartitionedTable(
- partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]],
- filterOpt: Option[PathFilter]): RDD[_] = {
-
+ partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]],
+ filterOpt: Option[PathFilter]): RDD[Row] = {
val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) =>
val partDesc = Utilities.getPartitionDesc(partition)
val partPath = partition.getPartitionPath
@@ -156,33 +161,42 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
}
// Create local references so that the outer object isn't serialized.
- val tableDesc = _tableDesc
+ val tableDesc = relation.tableDesc
val broadcastedHiveConf = _broadcastedHiveConf
val localDeserializer = partDeserializer
+ val mutableRow = new GenericMutableRow(attributes.length)
+
+ // Split the attributes (output schema) into two categories: (partition key, ordinal) and
+ // (normal attribute, ordinal), where the ordinal is the index of the attribute in the
+ // output Row.
+ val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => {
+ relation.partitionKeys.indexOf(attr._1) >= 0
+ })
+
+ def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = {
+ partitionKeys.foreach { case (attr, ordinal) =>
+ // Get the partition key ordinal for the given attribute.
+ val partOrdinal = relation.partitionKeys.indexOf(attr)
+ row(ordinal) = Cast(Literal(parts(partOrdinal)), attr.dataType).eval(null)
+ }
+ }
+ // Fill the partition key values into the given MutableRow object.
+ fillPartitionKeys(partValues, mutableRow)
val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
hivePartitionRDD.mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
- val rowWithPartArr = new Array[Object](2)
-
- // The update and deserializer initialization are intentionally
- // kept out of the below iter.map loop to save performance.
- rowWithPartArr.update(1, partValues)
val deserializer = localDeserializer.newInstance()
deserializer.initialize(hconf, partProps)
- // Map each tuple to a row object
- iter.map { value =>
- val deserializedRow = deserializer.deserialize(value)
- rowWithPartArr.update(0, deserializedRow)
- rowWithPartArr.asInstanceOf[Object]
- }
+ // Fill the non-partition-key attributes.
+ HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow)
}
}.toSeq
// Even if we don't use any partitions, we still need an empty RDD
if (hivePartitionRDDs.size == 0) {
- new EmptyRDD[Object](sc.sparkContext)
+ new EmptyRDD[Row](sc.sparkContext)
} else {
new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
}
@@ -225,10 +239,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
// Only take the value (skip the key) because Hive works only with values.
rdd.map(_._2)
}
-
}
-private[hive] object HadoopTableReader {
+private[hive] object HadoopTableReader extends HiveInspectors {
/**
* Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to
* instantiate a HadoopRDD.
@@ -241,4 +254,40 @@ private[hive] object HadoopTableReader {
val bufferSize = System.getProperty("spark.buffer.size", "65536")
jobConf.set("io.file.buffer.size", bufferSize)
}
+
+ /**
+ * Transforms raw data (Writable objects) from an iterable input into Row objects.
+ * @param iter iterator over the input, represented as Writable objects
+ * @param deserializer deserializer associated with the input Writable objects
+ * @param attrs row attributes paired with their zero-based positions in the MutableRow
+ * @param row reusable MutableRow object
+ *
+ * @return an iterator of Row objects transformed from the given input
+ */
+ def fillObject(
+ iter: Iterator[Writable],
+ deserializer: Deserializer,
+ attrs: Seq[(Attribute, Int)],
+ row: GenericMutableRow): Iterator[Row] = {
+ val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector]
+ // Get the field references for the required attributes (the output of the reader).
+ val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) }
+
+ // Map each tuple to a row object
+ iter.map { value =>
+ val raw = deserializer.deserialize(value)
+ var idx = 0
+ while (idx < fieldRefs.length) {
+ val fieldRef = fieldRefs(idx)._1
+ val fieldIdx = fieldRefs(idx)._2
+ val fieldValue = soi.getStructFieldData(raw, fieldRef)
+
+ row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector())
+
+ idx += 1
+ }
+
+ row: Row
+ }
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index e7016fa16eea9..8920e2a76a27f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{BooleanType, DataType}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive._
-import org.apache.spark.util.MutablePair
/**
* :: DeveloperApi ::
@@ -50,8 +49,7 @@ case class HiveTableScan(
relation: MetastoreRelation,
partitionPruningPred: Option[Expression])(
@transient val context: HiveContext)
- extends LeafNode
- with HiveInspectors {
+ extends LeafNode {
require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")
@@ -67,42 +65,7 @@ case class HiveTableScan(
}
@transient
- private[this] val hadoopReader = new HadoopTableReader(relation.tableDesc, context)
-
- /**
- * The hive object inspector for this table, which can be used to extract values from the
- * serialized row representation.
- */
- @transient
- private[this] lazy val objectInspector =
- relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector]
-
- /**
- * Functions that extract the requested attributes from the hive output. Partitioned values are
- * casted from string to its declared data type.
- */
- @transient
- protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = {
- attributes.map { a =>
- val ordinal = relation.partitionKeys.indexOf(a)
- if (ordinal >= 0) {
- val dataType = relation.partitionKeys(ordinal).dataType
- (_: Any, partitionKeys: Array[String]) => {
- castFromString(partitionKeys(ordinal), dataType)
- }
- } else {
- val ref = objectInspector.getAllStructFieldRefs
- .find(_.getFieldName == a.name)
- .getOrElse(sys.error(s"Can't find attribute $a"))
- val fieldObjectInspector = ref.getFieldObjectInspector
-
- (row: Any, _: Array[String]) => {
- val data = objectInspector.getStructFieldData(row, ref)
- unwrapData(data, fieldObjectInspector)
- }
- }
- }
- }
+ private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context)
private[this] def castFromString(value: String, dataType: DataType) = {
Cast(Literal(value), dataType).eval(null)
@@ -114,6 +77,7 @@ case class HiveTableScan(
val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",")
if (attributes.size == relation.output.size) {
+ // SQLContext#pruneFilterProject guarantees there are no duplicated values in `attributes`
ColumnProjectionUtils.setFullyReadColumns(hiveConf)
} else {
ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs)
@@ -140,12 +104,6 @@ case class HiveTableScan(
addColumnMetadataToConf(context.hiveconf)
- private def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
- hadoopReader.makeRDDForTable(relation.hiveQlTable)
- } else {
- hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
- }
-
/**
* Prunes partitions not involved in the query plan.
*
@@ -169,44 +127,10 @@ case class HiveTableScan(
}
}
- override def execute() = {
- inputRdd.mapPartitions { iterator =>
- if (iterator.isEmpty) {
- Iterator.empty
- } else {
- val mutableRow = new GenericMutableRow(attributes.length)
- val mutablePair = new MutablePair[Any, Array[String]]()
- val buffered = iterator.buffered
-
- // NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern
- // matching are avoided intentionally.
- val rowsAndPartitionKeys = buffered.head match {
- // With partition keys
- case _: Array[Any] =>
- buffered.map { case array: Array[Any] =>
- val deserializedRow = array(0)
- val partitionKeys = array(1).asInstanceOf[Array[String]]
- mutablePair.update(deserializedRow, partitionKeys)
- }
-
- // Without partition keys
- case _ =>
- val emptyPartitionKeys = Array.empty[String]
- buffered.map { deserializedRow =>
- mutablePair.update(deserializedRow, emptyPartitionKeys)
- }
- }
-
- rowsAndPartitionKeys.map { pair =>
- var i = 0
- while (i < attributes.length) {
- mutableRow(i) = attributeFunctions(i)(pair._1, pair._2)
- i += 1
- }
- mutableRow: Row
- }
- }
- }
+ override def execute() = if (!relation.hiveQlTable.isPartitioned) {
+ hadoopReader.makeRDDForTable(relation.hiveQlTable)
+ } else {
+ hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
}
override def output = attributes
diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298
new file mode 100644
index 0000000000000..f369f21e1833f
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298
@@ -0,0 +1,2 @@
+100 100 2010-01-01
+200 200 2010-01-02
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index a8623b64c656f..a022a1e2dc70e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -419,10 +419,10 @@ class HiveQuerySuite extends HiveComparisonTest {
hql(s"set $testKey=$testVal")
assert(get(testKey, testVal + "_") == testVal)
- hql("set mapred.reduce.tasks=20")
- assert(get("mapred.reduce.tasks", "0") == "20")
- hql("set mapred.reduce.tasks = 40")
- assert(get("mapred.reduce.tasks", "0") == "40")
+ hql("set some.property=20")
+ assert(get("some.property", "0") == "20")
+ hql("set some.property = 40")
+ assert(get("some.property", "0") == "40")
hql(s"set $testKey=$testVal")
assert(get(testKey, "0") == testVal)
@@ -436,63 +436,61 @@ class HiveQuerySuite extends HiveComparisonTest {
val testKey = "spark.sql.key.usedfortestonly"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
- def collectResults(rdd: SchemaRDD): Set[(String, String)] =
- rdd.collect().map { case Row(key: String, value: String) => key -> value }.toSet
clear()
// "set" itself returns all config variables currently specified in SQLConf.
assert(hql("SET").collect().size == 0)
- assertResult(Set(testKey -> testVal)) {
- collectResults(hql(s"SET $testKey=$testVal"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ hql(s"SET $testKey=$testVal").collect().map(_.getString(0))
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Set(testKey -> testVal)) {
- collectResults(hql("SET"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ hql(s"SET $testKey=$testVal").collect().map(_.getString(0))
}
hql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
- assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
- collectResults(hql("SET"))
+ assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
+ hql(s"SET").collect().map(_.getString(0))
}
// "set key"
- assertResult(Set(testKey -> testVal)) {
- collectResults(hql(s"SET $testKey"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ hql(s"SET $testKey").collect().map(_.getString(0))
}
- assertResult(Set(nonexistentKey -> "")) {
- collectResults(hql(s"SET $nonexistentKey"))
+ assertResult(Array(s"$nonexistentKey=")) {
+ hql(s"SET $nonexistentKey").collect().map(_.getString(0))
}
// Assert that sql() should have the same effects as hql() by repeating the above using sql().
clear()
assert(sql("SET").collect().size == 0)
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql(s"SET $testKey=$testVal"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ sql(s"SET $testKey=$testVal").collect().map(_.getString(0))
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql("SET"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ sql("SET").collect().map(_.getString(0))
}
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
- assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
- collectResults(sql("SET"))
+ assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
+ sql("SET").collect().map(_.getString(0))
}
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql(s"SET $testKey"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ sql(s"SET $testKey").collect().map(_.getString(0))
}
- assertResult(Set(nonexistentKey -> "")) {
- collectResults(sql(s"SET $nonexistentKey"))
+ assertResult(Array(s"$nonexistentKey=")) {
+ sql(s"SET $nonexistentKey").collect().map(_.getString(0))
}
clear()
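
The rewritten assertions above reflect a behavior change: `SET` commands now return single-column string rows of the form "key=value" rather than (key, value) pairs. A brief usage sketch, assuming a HiveContext instance named `hiveContext`; the key and value are placeholders:

    // Illustrative only.
    hiveContext.hql("SET spark.sql.key.usedfortestonly=test.val.0")
    val settings = hiveContext.hql("SET").collect().map(_.getString(0))
    // settings now contains strings such as "spark.sql.key.usedfortestonly=test.val.0"
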
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
new file mode 100644
index 0000000000000..bcb00f871d185
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.hive.test.TestHive
+
+class HiveTableScanSuite extends HiveComparisonTest {
+ // MINOR HACK: You must run a query before calling reset the first time.
+ TestHive.hql("SHOW TABLES")
+ TestHive.reset()
+
+ TestHive.hql("""CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING)
+ | ROW FORMAT SERDE
+ | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'
+ | STORED AS RCFILE
+ """.stripMargin)
+ TestHive.hql("""FROM src
+ | INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01')
+ | SELECT 100,100 LIMIT 1
+ """.stripMargin)
+ TestHive.hql("""ALTER TABLE part_scan_test SET SERDE
+ | 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe'
+ """.stripMargin)
+ TestHive.hql("""FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02')
+ | SELECT 200,200 LIMIT 1
+ """.stripMargin)
+
+ createQueryTest("partition_based_table_scan_with_different_serde",
+ "SELECT * from part_scan_test", false)
+}
diff --git a/streaming/pom.xml b/streaming/pom.xml
index f60697ce745b7..b99f306b8f2cc 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming_2.10
- streaming
+ streaming
jar
Spark Project Streaming
diff --git a/tools/pom.xml b/tools/pom.xml
index c0ee8faa7a615..97abb6b2b63e0 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -27,7 +27,7 @@
org.apache.spark
spark-tools_2.10
- tools
+ tools
jar
Spark Project Tools
diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml
index 5b13a1f002d6e..51744ece0412d 100644
--- a/yarn/alpha/pom.xml
+++ b/yarn/alpha/pom.xml
@@ -24,7 +24,7 @@
../pom.xml
- yarn-alpha
+ yarn-alpha
org.apache.spark
diff --git a/yarn/pom.xml b/yarn/pom.xml
index efb473aa1b261..3faaf053634d6 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -29,7 +29,7 @@
pom
Spark Project YARN Parent POM
- yarn
+ yarn
diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml
index ceaf9f9d71001..b6c8456d06684 100644
--- a/yarn/stable/pom.xml
+++ b/yarn/stable/pom.xml
@@ -24,7 +24,7 @@
../pom.xml
- yarn-stable
+ yarn-stable
org.apache.spark