@@ -26,8 +26,7 @@ import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
26
26
import org .apache .hadoop .hive .serde .serdeConstants
27
27
import org .apache .hadoop .hive .serde2 .objectinspector .ObjectInspectorUtils .ObjectInspectorCopyOption
28
28
import org .apache .hadoop .hive .serde2 .objectinspector ._
29
- import org .apache .hadoop .hive .serde2 .objectinspector .primitive .JavaHiveDecimalObjectInspector
30
- import org .apache .hadoop .hive .serde2 .objectinspector .primitive .JavaHiveVarcharObjectInspector
29
+ import org .apache .hadoop .hive .serde2 .objectinspector .primitive ._
31
30
import org .apache .hadoop .hive .serde2 .typeinfo .TypeInfoUtils
32
31
import org .apache .hadoop .hive .serde2 .{ColumnProjectionUtils , Serializer }
33
32
import org .apache .hadoop .io .Writable
@@ -95,29 +94,34 @@ case class HiveTableScan(
95
94
attributes.map { a =>
96
95
val ordinal = relation.partitionKeys.indexOf(a)
97
96
if (ordinal >= 0 ) {
97
+ val dataType = relation.partitionKeys(ordinal).dataType
98
98
(_ : Any , partitionKeys : Array [String ]) => {
99
- val value = partitionKeys(ordinal)
100
- val dataType = relation.partitionKeys(ordinal).dataType
101
- unwrapHiveData(castFromString(value, dataType))
99
+ castFromString(partitionKeys(ordinal), dataType)
102
100
}
103
101
} else {
104
102
val ref = objectInspector.getAllStructFieldRefs
105
103
.find(_.getFieldName == a.name)
106
104
.getOrElse(sys.error(s " Can't find attribute $a" ))
105
+ val fieldObjectInspector = ref.getFieldObjectInspector
106
+
107
+ val unwrapHiveData = fieldObjectInspector match {
108
+ case _ : HiveVarcharObjectInspector =>
109
+ (value : Any ) => value.asInstanceOf [HiveVarchar ].getValue
110
+ case _ : HiveDecimalObjectInspector =>
111
+ (value : Any ) => BigDecimal (value.asInstanceOf [HiveDecimal ].bigDecimalValue())
112
+ case _ =>
113
+ identity[Any ] _
114
+ }
115
+
107
116
(row : Any , _ : Array [String ]) => {
108
117
val data = objectInspector.getStructFieldData(row, ref)
109
- unwrapHiveData(unwrapData(data, ref.getFieldObjectInspector))
118
+ val hiveData = unwrapData(data, fieldObjectInspector)
119
+ if (hiveData != null ) unwrapHiveData(hiveData) else null
110
120
}
111
121
}
112
122
}
113
123
}
114
124
115
- private def unwrapHiveData (value : Any ) = value match {
116
- case varchar : HiveVarchar => varchar.getValue
117
- case decimal : HiveDecimal => BigDecimal (decimal.bigDecimalValue)
118
- case other => other
119
- }
120
-
121
125
private def castFromString (value : String , dataType : DataType ) = {
122
126
Cast (Literal (value), dataType).eval(null )
123
127
}
0 commit comments