Skip to content

Commit 9e12c94

Browse files
peter-toth authored and wangyum committed
[SPARK-29359][SQL][TESTS] Better exception handling in (SQL|ThriftServer)QueryTestSuite
### What changes were proposed in this pull request? This PR adds 2 changes regarding exception handling in `SQLQueryTestSuite` and `ThriftServerQueryTestSuite` - fixes an expected output sorting issue in `ThriftServerQueryTestSuite` as if there is an exception then there is no need for sort - introduces common exception handling in those 2 suites with a new `handleExceptions` method ### Why are the changes needed? Currently `ThriftServerQueryTestSuite` passes on master, but it fails on one of my PRs (apache#23531) with this error (https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/111651/testReport/org.apache.spark.sql.hive.thriftserver/ThriftServerQueryTestSuite/sql_3/): ``` org.scalatest.exceptions.TestFailedException: Expected " [Recursion level limit 100 reached but query has not exhausted, try increasing spark.sql.cte.recursion.level.limit org.apache.spark.SparkException] ", but got " [org.apache.spark.SparkException Recursion level limit 100 reached but query has not exhausted, try increasing spark.sql.cte.recursion.level.limit] " Result did not match for query #4 WITH RECURSIVE r(level) AS ( VALUES (0) UNION ALL SELECT level + 1 FROM r ) SELECT * FROM r ``` The unexpected reversed order of expected output (error message comes first, then the exception class) is due to this line: https://github.com/apache/spark/pull/26028/files#diff-b3ea3021602a88056e52bf83d8782de8L146. It should not sort the expected output if there was an error during execution. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing UTs. Closes apache#26028 from peter-toth/SPARK-29359-better-exception-handling. Authored-by: Peter Toth <[email protected]> Signed-off-by: Yuming Wang <[email protected]>
1 parent abba53e commit 9e12c94

File tree

2 files changed

+66
-56
lines changed

2 files changed

+66
-56
lines changed

sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
135135
private val notIncludedMsg = "[not included in comparison]"
136136
private val clsName = this.getClass.getCanonicalName
137137

138+
protected val emptySchema = StructType(Seq.empty).catalogString
139+
138140
protected override def sparkConf: SparkConf = super.sparkConf
139141
// Fewer shuffle partitions to speed up testing.
140142
.set(SQLConf.SHUFFLE_PARTITIONS, 4)
@@ -323,11 +325,11 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
323325
}
324326
// Run the SQL queries preparing them for comparison.
325327
val outputs: Seq[QueryOutput] = queries.map { sql =>
326-
val (schema, output) = getNormalizedResult(localSparkSession, sql)
328+
val (schema, output) = handleExceptions(getNormalizedResult(localSparkSession, sql))
327329
// We might need to do some query canonicalization in the future.
328330
QueryOutput(
329331
sql = sql,
330-
schema = schema.catalogString,
332+
schema = schema,
331333
output = output.mkString("\n").replaceAll("\\s+$", ""))
332334
}
333335

@@ -388,49 +390,58 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
388390
}
389391
}
390392

391-
/** Executes a query and returns the result as (schema of the output, normalized output). */
392-
private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = {
393-
// Returns true if the plan is supposed to be sorted.
394-
def isSorted(plan: LogicalPlan): Boolean = plan match {
395-
case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
396-
case _: DescribeCommandBase
397-
| _: DescribeColumnCommand
398-
| _: DescribeTableStatement
399-
| _: DescribeColumnStatement => true
400-
case PhysicalOperation(_, _, Sort(_, true, _)) => true
401-
case _ => plan.children.iterator.exists(isSorted)
402-
}
403-
393+
/**
394+
* This method handles exceptions occurred during query execution as they may need special care
395+
* to become comparable to the expected output.
396+
*
397+
* @param result a function that returns a pair of schema and output
398+
*/
399+
protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
404400
try {
405-
val df = session.sql(sql)
406-
val schema = df.schema
407-
// Get answer, but also get rid of the #1234 expression ids that show up in explain plans
408-
val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) {
409-
hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
410-
}
411-
412-
// If the output is not pre-sorted, sort it.
413-
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
414-
401+
result
415402
} catch {
416403
case a: AnalysisException =>
417404
// Do not output the logical plan tree which contains expression IDs.
418405
// Also implement a crude way of masking expression IDs in the error message
419406
// with a generic pattern "###".
420407
val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
421-
(StructType(Seq.empty), Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
408+
(emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
422409
case s: SparkException if s.getCause != null =>
423410
// For a runtime exception, it is hard to match because its message contains
424411
// information of stage, task ID, etc.
425412
// To make result matching simpler, here we match the cause of the exception if it exists.
426413
val cause = s.getCause
427-
(StructType(Seq.empty), Seq(cause.getClass.getName, cause.getMessage))
414+
(emptySchema, Seq(cause.getClass.getName, cause.getMessage))
428415
case NonFatal(e) =>
429416
// If there is an exception, put the exception class followed by the message.
430-
(StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage))
417+
(emptySchema, Seq(e.getClass.getName, e.getMessage))
431418
}
432419
}
433420

421+
/** Executes a query and returns the result as (schema of the output, normalized output). */
422+
private def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
423+
// Returns true if the plan is supposed to be sorted.
424+
def isSorted(plan: LogicalPlan): Boolean = plan match {
425+
case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
426+
case _: DescribeCommandBase
427+
| _: DescribeColumnCommand
428+
| _: DescribeTableStatement
429+
| _: DescribeColumnStatement => true
430+
case PhysicalOperation(_, _, Sort(_, true, _)) => true
431+
case _ => plan.children.iterator.exists(isSorted)
432+
}
433+
434+
val df = session.sql(sql)
435+
val schema = df.schema.catalogString
436+
// Get answer, but also get rid of the #1234 expression ids that show up in explain plans
437+
val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) {
438+
hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
439+
}
440+
441+
// If the output is not pre-sorted, sort it.
442+
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
443+
}
444+
434445
protected def replaceNotIncludedMsg(line: String): String = {
435446
line.replaceAll("#\\d+", "#x")
436447
.replaceAll(

sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
2828
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
2929

3030
import org.apache.spark.{SparkConf, SparkException}
31-
import org.apache.spark.sql.{AnalysisException, SQLQueryTestSuite}
31+
import org.apache.spark.sql.SQLQueryTestSuite
3232
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
3333
import org.apache.spark.sql.catalyst.util.fileToString
3434
import org.apache.spark.sql.execution.HiveResult
@@ -123,7 +123,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
123123

124124
// Run the SQL queries preparing them for comparison.
125125
val outputs: Seq[QueryOutput] = queries.map { sql =>
126-
val output = getNormalizedResult(statement, sql)
126+
val (_, output) = handleExceptions(getNormalizedResult(statement, sql))
127127
// We might need to do some query canonicalization in the future.
128128
QueryOutput(
129129
sql = sql,
@@ -142,8 +142,9 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
142142
"Try regenerate the result files.")
143143
Seq.tabulate(outputs.size) { i =>
144144
val sql = segments(i * 3 + 1).trim
145+
val schema = segments(i * 3 + 2).trim
145146
val originalOut = segments(i * 3 + 3)
146-
val output = if (isNeedSort(sql)) {
147+
val output = if (schema != emptySchema && isNeedSort(sql)) {
147148
originalOut.split("\n").sorted.mkString("\n")
148149
} else {
149150
originalOut
@@ -254,32 +255,30 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
254255
}
255256
}
256257

257-
private def getNormalizedResult(statement: Statement, sql: String): Seq[String] = {
258-
try {
259-
val rs = statement.executeQuery(sql)
260-
val cols = rs.getMetaData.getColumnCount
261-
val buildStr = () => (for (i <- 1 to cols) yield {
262-
getHiveResult(rs.getObject(i))
263-
}).mkString("\t")
264-
265-
val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toSeq
266-
.map(replaceNotIncludedMsg)
267-
if (isNeedSort(sql)) {
268-
answer.sorted
269-
} else {
270-
answer
258+
/** ThriftServer wraps the root exception, so it needs to be extracted. */
259+
override def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
260+
super.handleExceptions {
261+
try {
262+
result
263+
} catch {
264+
case NonFatal(e) => throw ExceptionUtils.getRootCause(e)
271265
}
272-
} catch {
273-
case a: AnalysisException =>
274-
// Do not output the logical plan tree which contains expression IDs.
275-
// Also implement a crude way of masking expression IDs in the error message
276-
// with a generic pattern "###".
277-
val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
278-
Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")).sorted
279-
case NonFatal(e) =>
280-
val rootCause = ExceptionUtils.getRootCause(e)
281-
// If there is an exception, put the exception class followed by the message.
282-
Seq(rootCause.getClass.getName, rootCause.getMessage)
266+
}
267+
}
268+
269+
private def getNormalizedResult(statement: Statement, sql: String): (String, Seq[String]) = {
270+
val rs = statement.executeQuery(sql)
271+
val cols = rs.getMetaData.getColumnCount
272+
val buildStr = () => (for (i <- 1 to cols) yield {
273+
getHiveResult(rs.getObject(i))
274+
}).mkString("\t")
275+
276+
val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toSeq
277+
.map(replaceNotIncludedMsg)
278+
if (isNeedSort(sql)) {
279+
("", answer.sorted)
280+
} else {
281+
("", answer)
283282
}
284283
}
285284

0 commit comments

Comments
 (0)