
Commit a21c8b5

Merge branch 'master' into SPARK-8103

2 parents: 906d626 + c4e98ff

File tree

4 files changed: +141 -50 lines changed


build/mvn

Lines changed: 10 additions & 1 deletion
@@ -112,10 +112,17 @@ install_scala() {
 # the environment
 ZINC_PORT=${ZINC_PORT:-"3030"}
 
+# Check for the `--force` flag dictating that `mvn` should be downloaded
+# regardless of whether the system already has a `mvn` install
+if [ "$1" == "--force" ]; then
+  FORCE_MVN=1
+  shift
+fi
+
 # Install Maven if necessary
 MVN_BIN="$(command -v mvn)"
 
-if [ ! "$MVN_BIN" ]; then
+if [ ! "$MVN_BIN" -o -n "$FORCE_MVN" ]; then
   install_mvn
 fi
 
@@ -139,5 +146,7 @@ fi
 # Set any `mvn` options if not already present
 export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}
 
+echo "Using \`mvn\` from path: $MVN_BIN"
+
 # Last, call the `mvn` command as usual
 ${MVN_BIN} "$@"
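
In practice this gives a one-flag escape hatch for when the system Maven is the wrong version: for example, running ./build/mvn --force clean package from the Spark source root downloads and uses the script's own Maven even though mvn is already on the PATH ("clean package" is just an illustrative pair of goals). Note the flag must be the first argument, since the script only inspects $1 before shifting it away.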

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala

Lines changed: 24 additions & 24 deletions
@@ -230,24 +230,31 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
     }
   }
 
+  private def evalElse(input: InternalRow): Any = {
+    if (branchesArr.length % 2 == 0) {
+      null
+    } else {
+      branchesArr(branchesArr.length - 1).eval(input)
+    }
+  }
+
   /** Written in imperative fashion for performance considerations. */
   override def eval(input: InternalRow): Any = {
     val evaluatedKey = key.eval(input)
-    val len = branchesArr.length
-    var i = 0
-    // If all branches fail and an elseVal is not provided, the whole statement
-    // defaults to null, according to Hive's semantics.
-    while (i < len - 1) {
-      if (threeValueEquals(evaluatedKey, branchesArr(i).eval(input))) {
-        return branchesArr(i + 1).eval(input)
+    // If key is null, we can just return the else part or null if there is no else.
+    // If key is not null but doesn't match any when part, we need to return
+    // the else part or null if there is no else, according to Hive's semantics.
+    if (evaluatedKey != null) {
+      val len = branchesArr.length
+      var i = 0
+      while (i < len - 1) {
+        if (evaluatedKey == branchesArr(i).eval(input)) {
+          return branchesArr(i + 1).eval(input)
+        }
+        i += 2
       }
-      i += 2
     }
-    var res: Any = null
-    if (i == len - 1) {
-      res = branchesArr(i).eval(input)
-    }
-    return res
+    evalElse(input)
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {

@@ -261,8 +268,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
       s"""
         if (!$got) {
           ${cond.code}
-          if (!${keyEval.isNull} && !${cond.isNull}
-            && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
+          if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
             $got = true;
             ${res.code}
             ${ev.isNull} = ${res.isNull};

@@ -290,19 +296,13 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
       boolean ${ev.isNull} = true;
       ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       ${keyEval.code}
-      $cases
+      if (!${keyEval.isNull}) {
+        $cases
+      }
       $other
     """
   }
 
-  private def threeValueEquals(l: Any, r: Any) = {
-    if (l == null || r == null) {
-      false
-    } else {
-      l == r
-    }
-  }
-
   override def toString: String = {
     s"CASE $key" + branches.sliding(2, 2).map {
      case Seq(cond, value) => s" WHEN $cond THEN $value"
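
The net effect of the eval rewrite is that a null key now skips the WHEN comparisons entirely and falls through to the ELSE branch (or to null when no ELSE exists), matching Hive's semantics; the interpreted path and the generated code path are changed the same way. A minimal standalone sketch of the same lookup logic, in plain Scala with illustrative names rather than Catalyst's Expression types:

// Model of CaseKeyWhen evaluation: branches alternate
// (whenValue, thenResult, whenValue, thenResult, ..., optional elseResult).
def caseKeyWhen(key: Any, branches: IndexedSeq[Any]): Any = {
  // An even-length branch list means there is no ELSE branch.
  def elseValue: Any = if (branches.length % 2 == 0) null else branches.last
  if (key != null) {
    var i = 0
    while (i < branches.length - 1) {
      if (key == branches(i)) return branches(i + 1)
      i += 2
    }
  }
  // Null key, or no WHEN value matched: ELSE (or null).
  elseValue
}

// CASE key WHEN 1 THEN "one" WHEN 2 THEN "two" ELSE "other" END
assert(caseKeyWhen(1, Vector(1, "one", 2, "two", "other")) == "one")
assert(caseKeyWhen(null, Vector(1, "one", 2, "two", "other")) == "other") // null key -> ELSE
assert(caseKeyWhen(null, Vector(1, "one", 2, "two")) == null)             // no ELSE -> null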

sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala

Lines changed: 29 additions & 25 deletions
@@ -34,7 +34,7 @@ import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hadoop.hive.serde.serdeConstants
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{StringType, IntegralType}
 
 /**

@@ -312,37 +312,41 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
     getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
 
-  override def getPartitionsByFilter(
-      hive: Hive,
-      table: Table,
-      predicates: Seq[Expression]): Seq[Partition] = {
+  /**
+   * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e.
+   * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...".
+   *
+   * Unsupported predicates are skipped.
+   */
+  def convertFilters(table: Table, filters: Seq[Expression]): String = {
     // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
     val varcharKeys = table.getPartitionKeys
       .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
       .map(col => col.getName).toSet
 
-    // Hive getPartitionsByFilter() takes a string that represents partition
-    // predicates like "str_key=\"value\" and int_key=1 ..."
-    val filter = predicates.flatMap { expr =>
-      expr match {
-        case op @ BinaryComparison(lhs, rhs) => {
-          lhs match {
-            case AttributeReference(_, _, _, _) => {
-              rhs.dataType match {
-                case _: IntegralType =>
-                  Some(lhs.prettyString + op.symbol + rhs.prettyString)
-                case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
-                  Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
-                case _ => None
-              }
-            }
-            case _ => None
-          }
-        }
-        case _ => None
-      }
+    filters.collect {
+      case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) =>
+        s"${a.name} ${op.symbol} $v"
+      case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) =>
+        s"$v ${op.symbol} ${a.name}"
+
+      case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType))
+          if !varcharKeys.contains(a.name) =>
+        s"""${a.name} ${op.symbol} "$v""""
+      case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute)
+          if !varcharKeys.contains(a.name) =>
+        s""""$v" ${op.symbol} ${a.name}"""
     }.mkString(" and ")
+  }
+
+  override def getPartitionsByFilter(
+      hive: Hive,
+      table: Table,
+      predicates: Seq[Expression]): Seq[Partition] = {
 
+    // Hive getPartitionsByFilter() takes a string that represents partition
+    // predicates like "str_key=\"value\" and int_key=1 ..."
+    val filter = convertFilters(table, predicates)
     val partitions =
       if (filter.isEmpty) {
         getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
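
This refactoring pulls the expression-to-string conversion out into convertFilters so it can be unit tested directly. As the new FiltersSuite below verifies, a predicate list such as Seq(Literal(1) === intcol, Literal("a") === strcol) converts to the Hive filter string 1 = intcol and "a" = strcol, while predicates the collect does not match (including comparisons on varchar partition keys) are simply skipped, as the doc comment notes.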
sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.client
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.hive.metastore.api.FieldSchema
+import org.apache.hadoop.hive.serde.serdeConstants
+
+import org.apache.spark.{Logging, SparkFunSuite}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+/**
+ * A set of tests for the filter conversion logic used when pushing partition pruning into the
+ * metastore.
+ */
+class FiltersSuite extends SparkFunSuite with Logging {
+  private val shim = new Shim_v0_13
+
+  private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
+  private val varCharCol = new FieldSchema()
+  varCharCol.setName("varchar")
+  varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
+  testTable.setPartCols(varCharCol :: Nil)
+
+  filterTest("string filter",
+    (a("stringcol", StringType) > Literal("test")) :: Nil,
+    "stringcol > \"test\"")
+
+  filterTest("string filter backwards",
+    (Literal("test") > a("stringcol", StringType)) :: Nil,
+    "\"test\" > stringcol")
+
+  filterTest("int filter",
+    (a("intcol", IntegerType) === Literal(1)) :: Nil,
+    "intcol = 1")
+
+  filterTest("int filter backwards",
+    (Literal(1) === a("intcol", IntegerType)) :: Nil,
+    "1 = intcol")
+
+  filterTest("int and string filter",
+    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil,
+    "1 = intcol and \"a\" = strcol")
+
+  filterTest("skip varchar",
+    (Literal("") === a("varchar", StringType)) :: Nil,
+    "")
+
+  private def filterTest(name: String, filters: Seq[Expression], result: String) = {
+    test(name) {
+      val converted = shim.convertFilters(testTable, filters)
+      if (converted != result) {
+        fail(
+          s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'")
+      }
+    }
+  }
+
+  private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
+}
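
To run just this suite, an invocation along the lines of build/sbt "hive/test-only org.apache.spark.sql.hive.client.FiltersSuite" should work with the sbt build of this vintage, though the exact command is an assumption that depends on the local build setup.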
