Commit eb62fd3

[SPARK-1368] Optimized HiveTableScan
1 parent 82eadc3 commit eb62fd3

File tree

1 file changed: +36 -15 lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala

Lines changed: 36 additions & 15 deletions
@@ -94,20 +94,27 @@ case class HiveTableScan(
         (_: Any, partitionKeys: Array[String]) => {
           val value = partitionKeys(ordinal)
           val dataType = relation.partitionKeys(ordinal).dataType
-          castFromString(value, dataType)
+          unwrapHiveData(castFromString(value, dataType))
         }
       } else {
         val ref = objectInspector.getAllStructFieldRefs
           .find(_.getFieldName == a.name)
           .getOrElse(sys.error(s"Can't find attribute $a"))
         (row: Any, _: Array[String]) => {
           val data = objectInspector.getStructFieldData(row, ref)
-          unwrapData(data, ref.getFieldObjectInspector)
+          unwrapHiveData(unwrapData(data, ref.getFieldObjectInspector))
         }
       }
     }
   }
 
+  private def unwrapHiveData(value: Any) = value match {
+    case maybeNull: String if maybeNull.toLowerCase == "null" => null
+    case varchar: HiveVarchar => varchar.getValue
+    case decimal: HiveDecimal => BigDecimal(decimal.bigDecimalValue)
+    case other => other
+  }
+
   private def castFromString(value: String, dataType: DataType) = {
     Cast(Literal(value), dataType).eval(null)
   }
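
The new unwrapHiveData helper moves the Hive-type unwrapping that execute() previously performed as a second pass over every row (removed in the next hunk) into the per-attribute extraction functions, so each value is converted exactly once. A minimal standalone sketch of the same pattern, assuming hive-common is on the classpath; the object name and the sample values in main are hypothetical:

import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}

object UnwrapHiveDataSketch {
  // Converts Hive wrapper types to plain JVM values, mirroring the helper above.
  def unwrapHiveData(value: Any): Any = value match {
    case maybeNull: String if maybeNull.toLowerCase == "null" => null  // Hive renders missing partition keys as "NULL"
    case varchar: HiveVarchar => varchar.getValue                      // -> java.lang.String
    case decimal: HiveDecimal => BigDecimal(decimal.bigDecimalValue)   // -> scala.math.BigDecimal
    case other => other                                                // primitives pass through unchanged
  }

  def main(args: Array[String]): Unit = {
    println(unwrapHiveData("NULL"))  // null
    println(unwrapHiveData(42))      // 42
  }
}
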
@@ -143,20 +150,34 @@ case class HiveTableScan(
   }
 
   def execute() = {
-    inputRdd.map { row =>
-      val values = row match {
-        case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
-          attributeFunctions.map(_(deserializedRow, partitionKeys))
-        case deserializedRow: AnyRef =>
-          attributeFunctions.map(_(deserializedRow, Array.empty))
+    inputRdd.mapPartitions { iterator =>
+      if (iterator.isEmpty) {
+        Iterator.empty
+      } else {
+        val mutableRow = new GenericMutableRow(attributes.length)
+        val buffered = iterator.buffered
+
+        (buffered.head match {
+          case Array(_, _) =>
+            buffered.map { case Array(deserializedRow, partitionKeys: Array[String]) =>
+              (deserializedRow, partitionKeys)
+            }
+
+          case _ =>
+            buffered.map { deserializedRow =>
+              (deserializedRow, Array.empty[String])
+            }
+        }).map { case (deserializedRow, partitionKeys: Array[String]) =>
+          var i = 0
+
+          while (i < attributes.length) {
+            mutableRow(i) = attributeFunctions(i)(deserializedRow, partitionKeys)
+            i += 1
+          }
+
+          mutableRow: Row
+        }
       }
-      buildRow(values.map {
-        case n: String if n.toLowerCase == "null" => null
-        case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
-        case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
-          BigDecimal(decimal.bigDecimalValue)
-        case other => other
-      })
     }
   }
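
The rewritten execute() trades the old row-by-row map, which built a fresh values sequence and ran a second unwrapping pass per row, for mapPartitions: one GenericMutableRow is allocated per partition and overwritten in a tight while loop for each row, and peeking at the head of a buffered iterator decides once per partition whether rows carry partition keys. A plain-Scala sketch of the per-partition buffer-reuse idea, with no Spark dependency; extractors and the sample data are hypothetical stand-ins for attributeFunctions and the deserialized rows:

object MutableRowReuseSketch {
  def main(args: Array[String]): Unit = {
    // Stand-ins for attributeFunctions: one extractor per output column.
    val extractors: Array[Any => Any] = Array(_.toString, _.hashCode)

    // Stand-in for one partition's deserialized rows.
    val partition: Iterator[Any] = Iterator("a", "bb", "ccc")

    // One buffer per partition, overwritten for every row (as with
    // GenericMutableRow above); consumers that retain rows must copy them.
    val reused = new Array[Any](extractors.length)
    val rows = partition.map { row =>
      var i = 0
      while (i < extractors.length) {
        reused(i) = extractors(i)(row)
        i += 1
      }
      reused.mkString("[", ", ", "]")  // render the buffer before the next row overwrites it
    }
    rows.foreach(println)
  }
}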