
Commit c9f8400

liancheng authored and marmbrus committed
[SPARK-3791][SQL] Provides Spark version and Hive version in HiveThriftServer2
This PR overrides the `GetInfo` Hive Thrift API to provide correct version information. Another property, `spark.sql.hive.version`, is added to reveal the underlying Hive version. These are generally useful for Spark SQL ODBC driver providers. The Spark version information is extracted from the jar manifest. Also took the chance to remove the `SET -v` hack, which was a workaround for Simba ODBC driver connectivity.

TODO

- [x] Find a general way to figure out Hive (or even any dependency) version.

  This [blog post](http://blog.soebes.de/blog/2014/01/02/version-information-into-your-appas-with-maven/) suggests several methods to inspect application version. In the case of Spark, this can be tricky because the chosen method:

  1. must apply to both Maven builds and SBT builds

     For Maven builds, we can retrieve the version information from the META-INF/maven directory within the assembly jar. But this doesn't work for SBT builds.

  2. must not rely on the original jars of dependencies to extract a specific dependency version, because Spark uses an assembly jar

     This implies we can't read the Hive version from Hive jar files, since the standard Spark distribution doesn't include them.

  3. should play well with `SPARK_PREPEND_CLASSES` to ease local testing during development

     `SPARK_PREPEND_CLASSES` prevents classes from being loaded from the assembly jar, thus we can't locate the jar file and read its manifest.

  Given these, maybe the only reliable method is to generate a source file containing version information at build time. pwendell Do you have any suggestions from the perspective of the build process?

**Update** Hive version is now retrieved from the newly introduced `HiveShim` object.

Author: Cheng Lian <[email protected]>
Author: Cheng Lian <[email protected]>

Closes #2843 from liancheng/get-info and squashes the following commits:

a873d0f [Cheng Lian] Updates test case
53f43cd [Cheng Lian] Retrieves underlying Hive verson via HiveShim
1d282b8 [Cheng Lian] Removes the Simba ODBC "SET -v" hack
f857fce [Cheng Lian] Overrides Hive GetInfo Thrift API and adds Hive version property
1 parent 495a132 commit c9f8400
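
After this change, any JDBC client can check the new Hive version property directly. Below is a minimal, illustrative sketch (not part of this patch) that assumes a HiveThriftServer2 already running at localhost:10000 and the Hive JDBC driver on the classpath; the object name, host, and port are hypothetical.

import java.sql.DriverManager

object VersionCheck {
  def main(args: Array[String]): Unit = {
    // Registers the Hive JDBC driver and connects to a hypothetical local server.
    Class.forName("org.apache.hive.jdbc.HiveDriver")
    val connection = DriverManager.getConnection(
      "jdbc:hive2://localhost:10000/", System.getProperty("user.name"), "")
    val statement = connection.createStatement()
    try {
      // Expected to print a single "key=value" row, e.g. "spark.sql.hive.version=0.12.0"
      // (the exact value depends on the Hive version Spark was built against).
      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
      while (resultSet.next()) {
        println(resultSet.getString(1))
      }
    } finally {
      statement.close()
      connection.close()
    }
  }
}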

File tree

7 files changed: +173 −112 lines


core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 8 additions & 0 deletions

@@ -20,8 +20,10 @@ package org.apache.spark.util
 import java.io._
 import java.net._
 import java.nio.ByteBuffer
+import java.util.jar.Attributes.Name
 import java.util.{Properties, Locale, Random, UUID}
 import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor}
+import java.util.jar.{Manifest => JarManifest}

 import scala.collection.JavaConversions._
 import scala.collection.Map
@@ -1759,6 +1761,12 @@ private[spark] object Utils extends Logging {
     s"$libraryPathEnvName=$libraryPath$ampersand"
   }

+  lazy val sparkVersion =
+    SparkContext.jarOfObject(this).map { path =>
+      val manifestUrl = new URL(s"jar:file:$path!/META-INF/MANIFEST.MF")
+      val manifest = new JarManifest(manifestUrl.openStream())
+      manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION)
+    }.getOrElse("Unknown")
 }

 /**
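
Note that the lookup above only works when Utils is actually loaded from the assembly jar; as the commit message explains, with SPARK_PREPEND_CLASSES set no jar (and hence no manifest) can be located, and the "Unknown" fallback is returned. For illustration, here is a self-contained sketch of the same manifest technique for an arbitrary class; the helper object and its name are hypothetical and not part of this patch.

import java.net.URL
import java.util.jar.Attributes.Name
import java.util.jar.{Manifest => JarManifest}

object ManifestVersion {
  // Reads Implementation-Version from the manifest of the jar that contains `clazz`.
  // Returns None when the class is not loaded from a jar, e.g. when build output
  // directories are prepended to the classpath during development.
  def of(clazz: Class[_]): Option[String] = {
    for {
      source <- Option(clazz.getProtectionDomain.getCodeSource)
      path = source.getLocation.getPath
      if path.endsWith(".jar")
      manifestUrl = new URL(s"jar:file:$path!/META-INF/MANIFEST.MF")
      manifest = new JarManifest(manifestUrl.openStream())
      version <- Option(manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION))
    } yield version
  }
}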

sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala

Lines changed: 27 additions & 42 deletions

@@ -84,50 +84,35 @@ case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribut
   extends LeafNode with Command with Logging {

   override protected lazy val sideEffectResult: Seq[Row] = kv match {
-    // Set value for the key.
-    case Some((key, Some(value))) =>
-      if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
-        logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+    // Configures the deprecated "mapred.reduce.tasks" property.
+    case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) =>
+      logWarning(
+        s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
           s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
-        context.setConf(SQLConf.SHUFFLE_PARTITIONS, value)
-        Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value"))
-      } else {
-        context.setConf(key, value)
-        Seq(Row(s"$key=$value"))
-      }
-
-    // Query the value bound to the key.
+      context.setConf(SQLConf.SHUFFLE_PARTITIONS, value)
+      Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value"))
+
+    // Configures a single property.
+    case Some((key, Some(value))) =>
+      context.setConf(key, value)
+      Seq(Row(s"$key=$value"))
+
+    // Queries all key-value pairs that are set in the SQLConf of the context. Notice that different
+    // from Hive, here "SET -v" is an alias of "SET". (In Hive, "SET" returns all changed properties
+    // while "SET -v" returns all properties.)
+    case Some(("-v", None)) | None =>
+      context.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq
+
+    // Queries the deprecated "mapred.reduce.tasks" property.
+    case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) =>
+      logWarning(
+        s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+          s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.")
+      Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}"))
+
+    // Queries a single property.
     case Some((key, None)) =>
-      // TODO (lian) This is just a workaround to make the Simba ODBC driver work.
-      // Should remove this once we get the ODBC driver updated.
-      if (key == "-v") {
-        val hiveJars = Seq(
-          "hive-exec-0.12.0.jar",
-          "hive-service-0.12.0.jar",
-          "hive-common-0.12.0.jar",
-          "hive-hwi-0.12.0.jar",
-          "hive-0.12.0.jar").mkString(":")
-
-        context.getAllConfs.map { case (k, v) =>
-          Row(s"$k=$v")
-        }.toSeq ++ Seq(
-          Row("system:java.class.path=" + hiveJars),
-          Row("system:sun.java.command=shark.SharkServer2"))
-      } else {
-        if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
-          logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
-            s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.")
-          Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}"))
-        } else {
-          Seq(Row(s"$key=${context.getConf(key, "<undefined>")}"))
-        }
-      }
-
-    // Query all key-value pairs that are set in the SQLConf of the context.
-    case _ =>
-      context.getAllConfs.map { case (k, v) =>
-        Row(s"$k=$v")
-      }.toSeq
+      Seq(Row(s"$key=${context.getConf(key, "<undefined>")}"))
   }

   override def otherCopyArgs = context :: Nil
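
As the new comments in the match above note, "SET -v" now behaves exactly like plain "SET". A rough sketch of the resulting behaviour, assuming an existing HiveContext named `hiveContext` (the statements and values are only illustrative, not taken from this patch):

// Illustrative only; actual rows depend on the session configuration.
hiveContext.sql("SET spark.sql.shuffle.partitions=10").collect()  // sets a single property
hiveContext.sql("SET spark.sql.shuffle.partitions").collect()     // queries a single property
hiveContext.sql("SET").collect()                                   // all key=value pairs in SQLConf
hiveContext.sql("SET -v").collect()                                // now an alias of plain SET (unlike Hive)
hiveContext.sql("SET mapred.reduce.tasks=5").collect()             // deprecated, converted to spark.sql.shuffle.partitions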

sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala

Lines changed: 13 additions & 1 deletion

@@ -17,6 +17,8 @@

 package org.apache.spark.sql.hive.thriftserver

+import java.util.jar.Attributes.Name
+
 import scala.collection.JavaConversions._

 import java.io.IOException
@@ -29,11 +31,12 @@ import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.shims.ShimLoader
 import org.apache.hive.service.Service.STATE
 import org.apache.hive.service.auth.HiveAuthFactory
-import org.apache.hive.service.cli.CLIService
+import org.apache.hive.service.cli._
 import org.apache.hive.service.{AbstractService, Service, ServiceException}

 import org.apache.spark.sql.hive.HiveContext
 import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+import org.apache.spark.util.Utils

 private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
   extends CLIService
@@ -60,6 +63,15 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext)

     initCompositeService(hiveConf)
   }
+
+  override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = {
+    getInfoType match {
+      case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL")
+      case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL")
+      case GetInfoType.CLI_DBMS_VER => new GetInfoValue(Utils.sparkVersion)
+      case _ => super.getInfo(sessionHandle, getInfoType)
+    }
+  }
 }

 private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
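
This override is what Thrift-based clients (including ODBC drivers) hit when they ask for the server name and version. A trimmed sketch of such a call, mirroring the new test added to HiveThriftServer2Suite below; the host and port are illustrative and no error handling beyond transport cleanup is shown.

import org.apache.hive.service.auth.PlainSaslHelper
import org.apache.hive.service.cli.GetInfoType
import org.apache.hive.service.cli.thrift.TCLIService.Client
import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.transport.TSocket

// Connects to a running HiveThriftServer2 and asks for server identity and version.
val rawTransport = new TSocket("localhost", 10000)
val transport = PlainSaslHelper.getPlainTransport(
  System.getProperty("user.name"), "anonymous", rawTransport)
val client = new ThriftCLIServiceClient(new Client(new TBinaryProtocol(transport)))

transport.open()
try {
  val sessionHandle = client.openSession(System.getProperty("user.name"), "")
  println(client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue)  // "Spark SQL"
  println(client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue)   // Spark version
} finally {
  transport.close()
}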

sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala

Lines changed: 7 additions & 4 deletions

@@ -17,10 +17,11 @@

 package org.apache.spark.sql.hive.thriftserver

+import scala.collection.JavaConversions._
+
 import org.apache.spark.scheduler.StatsReportListener
-import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.{HiveShim, HiveContext}
 import org.apache.spark.{Logging, SparkConf, SparkContext}
-import scala.collection.JavaConversions._

 /** A singleton object for the master program. The slaves should not access this. */
 private[hive] object SparkSQLEnv extends Logging {
@@ -31,8 +32,10 @@ private[hive] object SparkSQLEnv extends Logging {

   def init() {
     if (hiveContext == null) {
-      sparkContext = new SparkContext(new SparkConf()
-        .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}"))
+      val sparkConf = new SparkConf()
+        .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")
+        .set("spark.sql.hive.version", HiveShim.version)
+      sparkContext = new SparkContext(sparkConf)

       sparkContext.addSparkListener(new StatsReportListener())
       hiveContext = new HiveContext(sparkContext)

sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala

Lines changed: 110 additions & 34 deletions

@@ -30,42 +30,95 @@ import scala.util.Try

 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hive.jdbc.HiveDriver
+import org.apache.hive.service.auth.PlainSaslHelper
+import org.apache.hive.service.cli.GetInfoType
+import org.apache.hive.service.cli.thrift.TCLIService.Client
+import org.apache.hive.service.cli.thrift._
+import org.apache.thrift.protocol.TBinaryProtocol
+import org.apache.thrift.transport.TSocket
 import org.scalatest.FunSuite

 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.hive.HiveShim

 /**
  * Tests for the HiveThriftServer2 using JDBC.
+ *
+ * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be
+ * rebuilt after changing HiveThriftServer2 related code.
  */
 class HiveThriftServer2Suite extends FunSuite with Logging {
   Class.forName(classOf[HiveDriver].getCanonicalName)

-  def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) {
+  def randomListeningPort = {
+    // Let the system to choose a random available port to avoid collision with other parallel
+    // builds.
+    val socket = new ServerSocket(0)
+    val port = socket.getLocalPort
+    socket.close()
+    port
+  }
+
+  def withJdbcStatement(serverStartTimeout: FiniteDuration = 1.minute)(f: Statement => Unit) {
+    val port = randomListeningPort
+
+    startThriftServer(port, serverStartTimeout) {
+      val jdbcUri = s"jdbc:hive2://${"localhost"}:$port/"
+      val user = System.getProperty("user.name")
+      val connection = DriverManager.getConnection(jdbcUri, user, "")
+      val statement = connection.createStatement()
+
+      try {
+        f(statement)
+      } finally {
+        statement.close()
+        connection.close()
+      }
+    }
+  }
+
+  def withCLIServiceClient(
+      serverStartTimeout: FiniteDuration = 1.minute)(
+      f: ThriftCLIServiceClient => Unit) {
+    val port = randomListeningPort
+
+    startThriftServer(port) {
+      // Transport creation logics below mimics HiveConnection.createBinaryTransport
+      val rawTransport = new TSocket("localhost", port)
+      val user = System.getProperty("user.name")
+      val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+      val protocol = new TBinaryProtocol(transport)
+      val client = new ThriftCLIServiceClient(new Client(protocol))
+
+      transport.open()
+
+      try {
+        f(client)
+      } finally {
+        transport.close()
+      }
+    }
+  }
+
+  def startThriftServer(
+      port: Int,
+      serverStartTimeout: FiniteDuration = 1.minute)(
+      f: => Unit) {
     val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
     val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)

     val warehousePath = getTempFilePath("warehouse")
     val metastorePath = getTempFilePath("metastore")
     val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
-    val listeningHost = "localhost"
-    val listeningPort = {
-      // Let the system to choose a random available port to avoid collision with other parallel
-      // builds.
-      val socket = new ServerSocket(0)
-      val port = socket.getLocalPort
-      socket.close()
-      port
-    }
-
     val command =
       s"""$startScript
        | --master local
        | --hiveconf hive.root.logger=INFO,console
        | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
        | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
-       | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost
-       | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort
+       | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=${"localhost"}
+       | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
      """.stripMargin.split("\\s+").toSeq

     val serverRunning = Promise[Unit]()
@@ -92,31 +145,25 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
       }
     }

-    // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
-    Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger(
+    val env = Seq(
+      // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
+      "SPARK_TESTING" -> "0",
+      // Prevents loading classes out of the assembly jar. Otherwise Utils.sparkVersion can't read
+      // proper version information from the jar manifest.
+      "SPARK_PREPEND_CLASSES" -> "")
+
+    Process(command, None, env: _*).run(ProcessLogger(
       captureThriftServerOutput("stdout"),
       captureThriftServerOutput("stderr")))

-    val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/"
-    val user = System.getProperty("user.name")
-
     try {
-      Await.result(serverRunning.future, timeout)
-
-      val connection = DriverManager.getConnection(jdbcUri, user, "")
-      val statement = connection.createStatement()
-
-      try {
-        f(statement)
-      } finally {
-        statement.close()
-        connection.close()
-      }
+      Await.result(serverRunning.future, serverStartTimeout)
+      f
     } catch {
       case cause: Exception =>
         cause match {
           case _: TimeoutException =>
-            logError(s"Failed to start Hive Thrift server within $timeout", cause)
+            logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause)
           case _ =>
         }
         logError(
@@ -125,8 +172,8 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
           |HiveThriftServer2Suite failure output
           |=====================================
          |HiveThriftServer2 command line: ${command.mkString(" ")}
-          |JDBC URI: $jdbcUri
-          |User: $user
+          |Binding port: $port
+          |System user: ${System.getProperty("user.name")}
           |
          |${buffer.mkString("\n")}
           |=========================================
@@ -146,7 +193,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
   }

   test("Test JDBC query execution") {
-    startThriftServerWithin() { statement =>
+    withJdbcStatement() { statement =>
       val dataFilePath =
         Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")

@@ -168,7 +215,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
   }

   test("SPARK-3004 regression: result set containing NULL") {
-    startThriftServerWithin() { statement =>
+    withJdbcStatement() { statement =>
       val dataFilePath =
         Thread.currentThread().getContextClassLoader.getResource(
           "data/files/small_kv_with_null.txt")
@@ -191,4 +238,33 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
       assert(!resultSet.next())
     }
   }
+
+  test("GetInfo Thrift API") {
+    withCLIServiceClient() { client =>
+      val user = System.getProperty("user.name")
+      val sessionHandle = client.openSession(user, "")
+
+      assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
+        client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
+      }
+
+      assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
+        client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
+      }
+
+      assertResult(true, "Spark version shouldn't be \"Unknown\"") {
+        val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
+        logInfo(s"Spark version: $version")
+        version != "Unknown"
+      }
+    }
+  }
+
+  test("Checks Hive version") {
+    withJdbcStatement() { statement =>
+      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+      resultSet.next()
+      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+    }
+  }
 }

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 3 additions & 1 deletion

@@ -323,7 +323,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
           driver.close()
           HiveShim.processResults(results)
         case _ =>
-          sessionState.out.println(tokens(0) + " " + cmd_1)
+          if (sessionState.out != null) {
+            sessionState.out.println(tokens(0) + " " + cmd_1)
+          }
           Seq(proc.run(cmd_1).getResponseCode.toString)
       }
     } catch {
