|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.spark.sql.catalyst.expressions |
| 19 | + |
| 20 | +import scala.collection.Map |
| 21 | + |
| 22 | +import org.apache.spark.sql.AnalysisException |
| 23 | +import org.apache.spark.sql.catalyst.analysis._ |
| 24 | +import org.apache.spark.sql.types._ |
| 25 | + |
| 26 | +object ExtractValue { |
| 27 | + /** |
| 28 | + * Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`, |
| 29 | + * depend on the type of `child` and `extraction`. |
| 30 | + * |
| 31 | + * `child` | `extraction` | concrete `ExtractValue` |
| 32 | + * ---------------------------------------------------------------- |
| 33 | + * Struct | Literal String | GetStructField |
| 34 | + * Array[Struct] | Literal String | GetArrayStructFields |
| 35 | + * Array | Integral type | GetArrayItem |
| 36 | + * Map | Any type | GetMapValue |
| 37 | + */ |
| 38 | + def apply( |
| 39 | + child: Expression, |
| 40 | + extraction: Expression, |
| 41 | + resolver: Resolver): ExtractValue = { |
| 42 | + |
| 43 | + (child.dataType, extraction) match { |
| 44 | + case (StructType(fields), Literal(fieldName, StringType)) => |
| 45 | + val ordinal = findField(fields, fieldName.toString, resolver) |
| 46 | + GetStructField(child, fields(ordinal), ordinal) |
| 47 | + case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) => |
| 48 | + val ordinal = findField(fields, fieldName.toString, resolver) |
| 49 | + GetArrayStructFields(child, fields(ordinal), ordinal, containsNull) |
| 50 | + case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => |
| 51 | + GetArrayItem(child, extraction) |
| 52 | + case (_: MapType, _) => |
| 53 | + GetMapValue(child, extraction) |
| 54 | + case (otherType, _) => |
| 55 | + val errorMsg = otherType match { |
| 56 | + case StructType(_) | ArrayType(StructType(_), _) => |
| 57 | + s"Field name should be String Literal, but it's $extraction" |
| 58 | + case _: ArrayType => |
| 59 | + s"Array index should be integral type, but it's ${extraction.dataType}" |
| 60 | + case other => |
| 61 | + s"Can't extract value from $child" |
| 62 | + } |
| 63 | + throw new AnalysisException(errorMsg) |
| 64 | + } |
| 65 | + } |
| 66 | + |
| 67 | + def unapply(g: ExtractValue): Option[(Expression, Expression)] = { |
| 68 | + g match { |
| 69 | + case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) |
| 70 | + case _ => Some((g.child, null)) |
| 71 | + } |
| 72 | + } |
| 73 | + |
| 74 | + /** |
| 75 | + * Find the ordinal of StructField, report error if no desired field or over one |
| 76 | + * desired fields are found. |
| 77 | + */ |
| 78 | + private def findField(fields: Array[StructField], fieldName: String, resolver: Resolver): Int = { |
| 79 | + val checkField = (f: StructField) => resolver(f.name, fieldName) |
| 80 | + val ordinal = fields.indexWhere(checkField) |
| 81 | + if (ordinal == -1) { |
| 82 | + throw new AnalysisException( |
| 83 | + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") |
| 84 | + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { |
| 85 | + throw new AnalysisException( |
| 86 | + s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") |
| 87 | + } else { |
| 88 | + ordinal |
| 89 | + } |
| 90 | + } |
| 91 | +} |
| 92 | + |
| 93 | +trait ExtractValue extends UnaryExpression { |
| 94 | + self: Product => |
| 95 | + |
| 96 | + type EvaluatedType = Any |
| 97 | +} |
| 98 | + |
| 99 | +/** |
| 100 | + * Returns the value of fields in the Struct `child`. |
| 101 | + */ |
| 102 | +case class GetStructField(child: Expression, field: StructField, ordinal: Int) |
| 103 | + extends ExtractValue { |
| 104 | + |
| 105 | + override def dataType: DataType = field.dataType |
| 106 | + override def nullable: Boolean = child.nullable || field.nullable |
| 107 | + override def foldable: Boolean = child.foldable |
| 108 | + override def toString: String = s"$child.${field.name}" |
| 109 | + |
| 110 | + override def eval(input: Row): Any = { |
| 111 | + val baseValue = child.eval(input).asInstanceOf[Row] |
| 112 | + if (baseValue == null) null else baseValue(ordinal) |
| 113 | + } |
| 114 | +} |
| 115 | + |
| 116 | +/** |
| 117 | + * Returns the array of value of fields in the Array of Struct `child`. |
| 118 | + */ |
| 119 | +case class GetArrayStructFields( |
| 120 | + child: Expression, |
| 121 | + field: StructField, |
| 122 | + ordinal: Int, |
| 123 | + containsNull: Boolean) extends ExtractValue { |
| 124 | + |
| 125 | + override def dataType: DataType = ArrayType(field.dataType, containsNull) |
| 126 | + override def nullable: Boolean = child.nullable |
| 127 | + override def foldable: Boolean = child.foldable |
| 128 | + override def toString: String = s"$child.${field.name}" |
| 129 | + |
| 130 | + override def eval(input: Row): Any = { |
| 131 | + val baseValue = child.eval(input).asInstanceOf[Seq[Row]] |
| 132 | + if (baseValue == null) null else { |
| 133 | + baseValue.map { row => |
| 134 | + if (row == null) null else row(ordinal) |
| 135 | + } |
| 136 | + } |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +abstract class ExtractValueWithOrdinal extends ExtractValue { |
| 141 | + self: Product => |
| 142 | + |
| 143 | + def ordinal: Expression |
| 144 | + |
| 145 | + /** `Null` is returned for invalid ordinals. */ |
| 146 | + override def nullable: Boolean = true |
| 147 | + override def foldable: Boolean = child.foldable && ordinal.foldable |
| 148 | + override def toString: String = s"$child[$ordinal]" |
| 149 | + override def children: Seq[Expression] = child :: ordinal :: Nil |
| 150 | + |
| 151 | + override def eval(input: Row): Any = { |
| 152 | + val value = child.eval(input) |
| 153 | + if (value == null) { |
| 154 | + null |
| 155 | + } else { |
| 156 | + val o = ordinal.eval(input) |
| 157 | + if (o == null) { |
| 158 | + null |
| 159 | + } else { |
| 160 | + evalNotNull(value, o) |
| 161 | + } |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + protected def evalNotNull(value: Any, ordinal: Any): Any |
| 166 | +} |
| 167 | + |
| 168 | +/** |
| 169 | + * Returns the field at `ordinal` in the Array `child` |
| 170 | + */ |
| 171 | +case class GetArrayItem(child: Expression, ordinal: Expression) |
| 172 | + extends ExtractValueWithOrdinal { |
| 173 | + |
| 174 | + override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType |
| 175 | + |
| 176 | + override lazy val resolved = childrenResolved && |
| 177 | + child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType] |
| 178 | + |
| 179 | + protected def evalNotNull(value: Any, ordinal: Any) = { |
| 180 | + // TODO: consider using Array[_] for ArrayType child to avoid |
| 181 | + // boxing of primitives |
| 182 | + val baseValue = value.asInstanceOf[Seq[_]] |
| 183 | + val index = ordinal.asInstanceOf[Int] |
| 184 | + if (index >= baseValue.size || index < 0) { |
| 185 | + null |
| 186 | + } else { |
| 187 | + baseValue(index) |
| 188 | + } |
| 189 | + } |
| 190 | +} |
| 191 | + |
| 192 | +/** |
| 193 | + * Returns the value of key `ordinal` in Map `child` |
| 194 | + */ |
| 195 | +case class GetMapValue(child: Expression, ordinal: Expression) |
| 196 | + extends ExtractValueWithOrdinal { |
| 197 | + |
| 198 | + override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType |
| 199 | + |
| 200 | + override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType] |
| 201 | + |
| 202 | + protected def evalNotNull(value: Any, ordinal: Any) = { |
| 203 | + val baseValue = value.asInstanceOf[Map[Any, _]] |
| 204 | + baseValue.get(ordinal).orNull |
| 205 | + } |
| 206 | +} |
0 commit comments