Commit 0f56aad

liancheng authored and marmbrus committed
[SPARK-1368][SQL] Optimized HiveTableScan
JIRA issue: [SPARK-1368](https://issues.apache.org/jira/browse/SPARK-1368)

This PR introduces two major updates:

- Replaced FP style code with a `while` loop and a reusable `GenericMutableRow` object in the critical path of `HiveTableScan`.
- Using `ColumnProjectionUtils` to help optimize RCFile and ORC column pruning.

My quick micro benchmark suggests these two optimizations made the optimized version 2x and 2.5x faster when scanning a CSV table and an RCFile table respectively:

```
Original:

[info] CSV: 27676 ms, RCFile: 26415 ms
[info] CSV: 27703 ms, RCFile: 26029 ms
[info] CSV: 27511 ms, RCFile: 25962 ms

Optimized:

[info] CSV: 13820 ms, RCFile: 10402 ms
[info] CSV: 14158 ms, RCFile: 10691 ms
[info] CSV: 13606 ms, RCFile: 10346 ms
```

The micro benchmark loads a 609MB CSV file (structurally similar to the `src` test table) into a normal Hive table with `LazySimpleSerDe` and an RCFile table, then scans these tables respectively.

Preparation code:

```scala
package org.apache.spark.examples.sql.hive

import org.apache.spark.sql.hive.LocalHiveContext
import org.apache.spark.{SparkConf, SparkContext}

object HiveTableScanPrepare extends App {
  val sparkContext = new SparkContext(
    new SparkConf()
      .setMaster("local")
      .setAppName(getClass.getSimpleName.stripSuffix("$")))

  val hiveContext = new LocalHiveContext(sparkContext)

  import hiveContext._

  hql("drop table scan_csv")
  hql("drop table scan_rcfile")

  hql("""create table scan_csv (key int, value string)
        |  row format serde 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
        |  with serdeproperties ('field.delim'=',')
      """.stripMargin)

  hql(s"""load data local inpath "${args(0)}" into table scan_csv""")

  hql("""create table scan_rcfile (key int, value string)
        |  row format serde 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe'
        |stored as
        |  inputformat 'org.apache.hadoop.hive.ql.io.RCFileInputFormat'
        |  outputformat 'org.apache.hadoop.hive.ql.io.RCFileOutputFormat'
      """.stripMargin)

  hql(
    """
      |from scan_csv
      |insert overwrite table scan_rcfile
      |select scan_csv.key, scan_csv.value
    """.stripMargin)
}
```

Benchmark code:

```scala
package org.apache.spark.examples.sql.hive

import org.apache.spark.sql.hive.LocalHiveContext
import org.apache.spark.{SparkConf, SparkContext}

object HiveTableScanBenchmark extends App {
  val sparkContext = new SparkContext(
    new SparkConf()
      .setMaster("local")
      .setAppName(getClass.getSimpleName.stripSuffix("$")))

  val hiveContext = new LocalHiveContext(sparkContext)

  import hiveContext._

  val scanCsv = hql("select key from scan_csv")
  val scanRcfile = hql("select key from scan_rcfile")

  val csvDuration = benchmark(scanCsv.count())
  val rcfileDuration = benchmark(scanRcfile.count())

  println(s"CSV: $csvDuration ms, RCFile: $rcfileDuration ms")

  def benchmark(f: => Unit) = {
    val begin = System.currentTimeMillis()
    f
    val end = System.currentTimeMillis()
    end - begin
  }
}
```

@marmbrus Please help review, thanks!

Author: Cheng Lian <[email protected]>

Closes #758 from liancheng/fastHiveTableScan and squashes the following commits:

4241a19 [Cheng Lian] Distinguishes sorted and possibly not sorted operations more accurately in HiveComparisonTest
cf640d8 [Cheng Lian] More HiveTableScan optimisations:
bf0e7dc [Cheng Lian] Added SortedOperation pattern to match *some* definitely sorted operations and avoid some sorting cost in HiveComparisonTest.
6d1c642 [Cheng Lian] Using ColumnProjectionUtils to optimise RCFile and ORC column pruning
eb62fd3 [Cheng Lian] [SPARK-1368] Optimized HiveTableScan

(cherry picked from commit 8f7141f)
Signed-off-by: Michael Armbrust <[email protected]>
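
To illustrate the first optimization outside of Spark internals, here is a minimal standalone sketch of the difference between allocating a fresh row per record and reusing one mutable buffer filled by a `while` loop, which is the pattern `HiveTableScan.execute()` adopts with `GenericMutableRow` in the diff below. Everything in this sketch (`RowReuseSketch`, `RawRow`, `fieldConverters`) is a hypothetical stand-in, not code from the patch:

```scala
object RowReuseSketch extends App {
  // Hypothetical stand-in for a deserialized Hive row: one value per column.
  type RawRow = Array[Any]

  // Per-column extraction functions, analogous to `attributeFunctions` in HiveTableScan.
  val fieldConverters: Array[RawRow => Any] = Array(row => row(0), row => row(1))

  // A few fake input rows of (key, value).
  def input: Iterator[RawRow] = Iterator.tabulate(5)(i => Array[Any](i, s"value_$i"))

  // FP style: a new collection is allocated for every input row.
  def scanFunctional(rows: Iterator[RawRow]): Iterator[Seq[Any]] =
    rows.map(row => fieldConverters.map(f => f(row)).toSeq)

  // Optimized style: one mutable buffer is reused and filled with a while loop.
  def scanReusingBuffer(rows: Iterator[RawRow]): Iterator[Array[Any]] = {
    val mutableRow = new Array[Any](fieldConverters.length) // reused across rows
    rows.map { row =>
      var i = 0
      while (i < fieldConverters.length) {
        mutableRow(i) = fieldConverters(i)(row)
        i += 1
      }
      mutableRow // must be consumed before the next iteration overwrites it
    }
  }

  scanFunctional(input).foreach(r => println(r.mkString(", ")))
  scanReusingBuffer(input).foreach(r => println(r.mkString(", ")))
}
```

Reusing one buffer is only safe when each emitted row is consumed (or copied) before the next iteration overwrites it; the `println` in the sketch does exactly that.
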
1 parent 8bb9390 commit 0f56aad

File tree: 3 files changed (+96, -28 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ case class Aggregate(
    */
   @transient
   private[this] lazy val resultMap =
-    (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap
+    (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap
 
   /**
    * Substituted version of aggregateExpressions expressions which are used to compute final

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala

Lines changed: 81 additions & 16 deletions
@@ -18,15 +18,18 @@
 package org.apache.spark.sql.hive.execution
 
 import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.metastore.MetaStoreUtils
 import org.apache.hadoop.hive.ql.Context
 import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive}
 import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
-import org.apache.hadoop.hive.serde2.Serializer
+import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
 import org.apache.hadoop.hive.serde2.objectinspector._
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
+import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapred._
 
@@ -37,6 +40,7 @@ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.hive._
 import org.apache.spark.{TaskContext, SparkException}
+import org.apache.spark.util.MutablePair
 
 /* Implicits */
 import scala.collection.JavaConversions._
@@ -94,24 +98,63 @@ case class HiveTableScan(
         (_: Any, partitionKeys: Array[String]) => {
           val value = partitionKeys(ordinal)
           val dataType = relation.partitionKeys(ordinal).dataType
-          castFromString(value, dataType)
+          unwrapHiveData(castFromString(value, dataType))
         }
       } else {
         val ref = objectInspector.getAllStructFieldRefs
           .find(_.getFieldName == a.name)
           .getOrElse(sys.error(s"Can't find attribute $a"))
         (row: Any, _: Array[String]) => {
           val data = objectInspector.getStructFieldData(row, ref)
-          unwrapData(data, ref.getFieldObjectInspector)
+          unwrapHiveData(unwrapData(data, ref.getFieldObjectInspector))
         }
       }
     }
   }
 
+  private def unwrapHiveData(value: Any) = value match {
+    case maybeNull: String if maybeNull.toLowerCase == "null" => null
+    case varchar: HiveVarchar => varchar.getValue
+    case decimal: HiveDecimal => BigDecimal(decimal.bigDecimalValue)
+    case other => other
+  }
+
   private def castFromString(value: String, dataType: DataType) = {
     Cast(Literal(value), dataType).eval(null)
   }
 
+  private def addColumnMetadataToConf(hiveConf: HiveConf) {
+    // Specifies IDs and internal names of columns to be scanned.
+    val neededColumnIDs = attributes.map(a => relation.output.indexWhere(_.name == a.name): Integer)
+    val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",")
+
+    if (attributes.size == relation.output.size) {
+      ColumnProjectionUtils.setFullyReadColumns(hiveConf)
+    } else {
+      ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs)
+    }
+
+    ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name))
+
+    // Specifies types and object inspectors of columns to be scanned.
+    val structOI = ObjectInspectorUtils
+      .getStandardObjectInspector(
+        relation.tableDesc.getDeserializer.getObjectInspector,
+        ObjectInspectorCopyOption.JAVA)
+      .asInstanceOf[StructObjectInspector]
+
+    val columnTypeNames = structOI
+      .getAllStructFieldRefs
+      .map(_.getFieldObjectInspector)
+      .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName)
+      .mkString(",")
+
+    hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames)
+    hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames)
+  }
+
+  addColumnMetadataToConf(sc.hiveconf)
+
   @transient
   def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
     hadoopReader.makeRDDForTable(relation.hiveQlTable)
@@ -143,20 +186,42 @@ case class HiveTableScan(
   }
 
   def execute() = {
-    inputRdd.map { row =>
-      val values = row match {
-        case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
-          attributeFunctions.map(_(deserializedRow, partitionKeys))
-        case deserializedRow: AnyRef =>
-          attributeFunctions.map(_(deserializedRow, Array.empty))
+    inputRdd.mapPartitions { iterator =>
+      if (iterator.isEmpty) {
+        Iterator.empty
+      } else {
+        val mutableRow = new GenericMutableRow(attributes.length)
+        val mutablePair = new MutablePair[Any, Array[String]]()
+        val buffered = iterator.buffered
+
+        // NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern
+        // matching are avoided intentionally.
+        val rowsAndPartitionKeys = buffered.head match {
+          // With partition keys
+          case _: Array[Any] =>
+            buffered.map { case array: Array[Any] =>
+              val deserializedRow = array(0)
+              val partitionKeys = array(1).asInstanceOf[Array[String]]
+              mutablePair.update(deserializedRow, partitionKeys)
+            }
+
+          // Without partition keys
+          case _ =>
+            val emptyPartitionKeys = Array.empty[String]
+            buffered.map { deserializedRow =>
+              mutablePair.update(deserializedRow, emptyPartitionKeys)
+            }
+        }
+
+        rowsAndPartitionKeys.map { pair =>
+          var i = 0
+          while (i < attributes.length) {
+            mutableRow(i) = attributeFunctions(i)(pair._1, pair._2)
+            i += 1
+          }
+          mutableRow: Row
        }
       }
-      buildRow(values.map {
-        case n: String if n.toLowerCase == "null" => null
-        case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
-        case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
-          BigDecimal(decimal.bigDecimalValue)
-        case other => other
-      })
     }
   }
 
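
A quick illustration of the column-pruning half of this diff: `addColumnMetadataToConf` derives the needed column IDs as the positions of the requested attributes in the table schema and hands them to `ColumnProjectionUtils`. The standalone sketch below shows just that derivation in plain Scala; `tableColumns` and `requestedColumns` are made-up example data, and no Hive classes are used:

```scala
object ColumnPruningSketch extends App {
  // Full column list of a hypothetical Hive table, in declaration order.
  val tableColumns = Seq("key", "value", "ds")

  // Columns actually referenced by the query, e.g. `select key from scan_rcfile`.
  val requestedColumns = Seq("key")

  // Same derivation as in addColumnMetadataToConf: a column's ID is its
  // position in the table schema.
  val neededColumnIDs = requestedColumns.map(name => tableColumns.indexWhere(_ == name))

  if (requestedColumns.size == tableColumns.size) {
    // Corresponds to the ColumnProjectionUtils.setFullyReadColumns branch.
    println("All columns requested: read everything")
  } else {
    // These IDs are what the patch passes to ColumnProjectionUtils.appendReadColumnIDs,
    // so columnar inputs (RCFile, ORC) can skip the other columns on disk.
    println(s"Read only column IDs: ${neededColumnIDs.mkString(",")}")
  }
}
```
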

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala

Lines changed: 14 additions & 11 deletions
@@ -19,11 +19,12 @@ package org.apache.spark.sql.hive.execution
 
 import java.io._
 
+import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
+
 import org.apache.spark.sql.Logging
-import org.apache.spark.sql.catalyst.plans.logical.{ExplainCommand, NativeCommand}
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.Sort
-import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
 import org.apache.spark.sql.hive.test.TestHive
 
 /**
@@ -128,17 +129,19 @@ abstract class HiveComparisonTest
   protected def prepareAnswer(
       hiveQuery: TestHive.type#HiveQLQueryExecution,
       answer: Seq[String]): Seq[String] = {
+
+    def isSorted(plan: LogicalPlan): Boolean = plan match {
+      case _: Join | _: Aggregate | _: BaseRelation | _: Generate | _: Sample | _: Distinct => false
+      case PhysicalOperation(_, _, Sort(_, _)) => true
+      case _ => plan.children.iterator.map(isSorted).exists(_ == true)
+    }
+
     val orderedAnswer = hiveQuery.logical match {
       // Clean out non-deterministic time schema info.
      case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "")
      case _: ExplainCommand => answer
-      case _ =>
-        // TODO: Really we only care about the final total ordering here...
-        val isOrdered = hiveQuery.executedPlan.collect {
-          case s @ Sort(_, global, _) if global => s
-        }.nonEmpty
-        // If the query results aren't sorted, then sort them to ensure deterministic answers.
-        if (!isOrdered) answer.sorted else answer
+      case plan if isSorted(plan) => answer
+      case _ => answer.sorted
     }
     orderedAnswer.map(cleanPaths)
   }
@@ -161,7 +164,7 @@ abstract class HiveComparisonTest
     "minFileSize"
   )
   protected def nonDeterministicLine(line: String) =
-    nonDeterministicLineIndicators.map(line contains _).reduceLeft(_||_)
+    nonDeterministicLineIndicators.exists(line contains _)
 
   /**
   * Removes non-deterministic paths from `str` so cached answers will compare correctly.
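
The new `isSorted` check above can be read as a small recursion over the logical plan: operators whose output ordering is undefined (Join, Aggregate, ...) answer false, a Sort underneath only projections and filters answers true, and everything else asks its children. The toy sketch below models that with a hypothetical mini plan ADT; it is not the Catalyst `LogicalPlan` hierarchy, and it collapses `PhysicalOperation` down to a single `Project` layer:

```scala
object SortednessSketch extends App {
  // Toy plan ADT standing in for Catalyst's LogicalPlan hierarchy.
  sealed trait Plan { def children: Seq[Plan] }
  case class Sort(child: Plan) extends Plan { def children = Seq(child) }
  case class Project(child: Plan) extends Plan { def children = Seq(child) }
  case class Join(left: Plan, right: Plan) extends Plan { def children = Seq(left, right) }
  case object Relation extends Plan { def children = Nil }

  // Mirrors prepareAnswer's isSorted: ordering-destroying operators answer false,
  // a Sort reachable through row-preserving operators answers true, anything else
  // defers to its children.
  def isSorted(plan: Plan): Boolean = plan match {
    case _: Join => false
    case Sort(_) | Project(Sort(_)) => true
    case _ => plan.children.exists(isSorted)
  }

  println(isSorted(Project(Sort(Relation))))        // true: result is globally ordered
  println(isSorted(Join(Sort(Relation), Relation))) // false: the join discards ordering
}
```
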

0 commit comments
