Commit c194d5e

add metadata field to StructField and Attribute
1 parent 5044e49 commit c194d5e

3 files changed: +25 additions, -18 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 12 additions & 6 deletions

@@ -40,6 +40,7 @@ abstract class NamedExpression extends Expression {
   def name: String
   def exprId: ExprId
   def qualifiers: Seq[String]
+  def metadata: Map[String, Any] = Map.empty
 
   def toAttribute: Attribute
 
@@ -112,9 +113,13 @@ case class Alias(child: Expression, name: String)
  * qualified way. Consider the examples tableName.name, subQueryAlias.name.
  * tableName and subQueryAlias are possible qualifiers.
  */
-case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true)
-    (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
-  extends Attribute with trees.LeafNode[Expression] {
+case class AttributeReference(
+    name: String,
+    dataType: DataType,
+    nullable: Boolean = true,
+    override val metadata: Map[String, Any] = Map.empty)(
+    val exprId: ExprId = NamedExpression.newExprId,
+    val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] {
 
   override def references = AttributeSet(this :: Nil)
 
@@ -131,7 +136,8 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
     h
   }
 
-  override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
+  override def newInstance =
+    AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers)
 
   /**
    * Returns a copy of this [[AttributeReference]] with changed nullability.
@@ -140,7 +146,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
     if (nullable == newNullability) {
       this
     } else {
-      AttributeReference(name, dataType, newNullability)(exprId, qualifiers)
+      AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers)
     }
   }
 
@@ -151,7 +157,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
     if (newQualifiers == qualifiers) {
       this
     } else {
-      AttributeReference(name, dataType, nullable)(exprId, newQualifiers)
+      AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers)
     }
   }

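For context, a minimal usage sketch of the extended constructor (the import paths follow the packages of the files in this commit; the metadata key and value are purely illustrative and not defined anywhere in the change):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.types.StringType

    // Metadata is passed in the first parameter list; the second parameter list
    // (exprId, qualifiers) keeps its defaults.
    val attr = AttributeReference(
      "name", StringType, nullable = true, metadata = Map("comment" -> "customer name"))()

    // newInstance (changed above) now carries the metadata over to the copy,
    // as do the nullability and qualifier copies.
    assert(attr.newInstance.metadata == Map("comment" -> "customer name"))
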
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 9 additions & 3 deletions

@@ -296,8 +296,14 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
  * @param name The name of this field.
  * @param dataType The data type of this field.
  * @param nullable Indicates if values of this field can be `null` values.
+ * @param metadata The metadata of this field, which is a map from string to simple type that can be
+ *                 serialized to JSON automatically.
  */
-case class StructField(name: String, dataType: DataType, nullable: Boolean) {
+case class StructField(
+    name: String,
+    dataType: DataType,
+    nullable: Boolean,
+    metadata: Map[String, Any] = Map.empty) {
 
   private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
     builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n")
@@ -307,7 +313,7 @@ case class StructField(name: String, dataType: DataType, nullable: Boolean) {
 
 object StructType {
   protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
-    StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable)))
+    StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
 }
 
 case class StructType(fields: Seq[StructField]) extends DataType {
@@ -342,7 +348,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
   }
 
   protected[sql] def toAttributes =
-    fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
+    fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
 
   def treeString: String = {
     val builder = new StringBuilder

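Because the new parameter defaults to Map.empty, existing StructField call sites keep compiling unchanged. A hedged sketch, assuming the catalyst types above are in scope (the "comment" key is again only an example):

    import org.apache.spark.sql.catalyst.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),  // no metadata given, defaults to Map.empty
      StructField("name", StringType, nullable = true, metadata = Map("comment" -> "display name"))))

    // The metadata travels with the field and, through fromAttributes/toAttributes
    // above, with the corresponding AttributeReference.
    assert(schema.fields(1).metadata("comment") == "display name")
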
sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala

Lines changed: 4 additions & 9 deletions

@@ -112,18 +112,15 @@ private[sql] object JsonRDD extends Logging {
         }
       }.flatMap(field => field).toSeq
 
-      StructType(
-        (topLevelFields ++ structFields).sortBy {
-          case StructField(name, _, _) => name
-        })
+      StructType((topLevelFields ++ structFields).sortBy(_.name))
     }
 
     makeStruct(resolved.keySet.toSeq, Nil)
   }
 
   private[sql] def nullTypeToStringType(struct: StructType): StructType = {
     val fields = struct.fields.map {
-      case StructField(fieldName, dataType, nullable) => {
+      case StructField(fieldName, dataType, nullable, _) => {
         val newType = dataType match {
           case NullType => StringType
           case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
@@ -158,9 +155,7 @@ private[sql] object JsonRDD extends Logging {
             StructField(name, dataType, true)
           }
         }
-        StructType(newFields.toSeq.sortBy {
-          case StructField(name, _, _) => name
-        })
+        StructType(newFields.toSeq.sortBy(_.name))
       }
     case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
       ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
@@ -385,7 +380,7 @@ private[sql] object JsonRDD extends Logging {
     // TODO: Reuse the row instead of creating a new one for every record.
     val row = new GenericMutableRow(schema.fields.length)
     schema.fields.zipWithIndex.foreach {
-      case (StructField(name, dataType, _), i) =>
+      case (StructField(name, dataType, _, _), i) =>
         row.update(i, json.get(name).flatMap(v => Option(v)).map(
           enforceCorrectType(_, dataType)).orNull)
     }

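The JsonRDD changes are mechanical: StructField is a case class, so adding a fourth field changes its extractor arity, and every pattern match needs an extra wildcard (or can avoid the extractor entirely, as the sortBy(_.name) rewrites do). A small sketch of the adjusted pattern, with an illustrative metadata value:

    import org.apache.spark.sql.catalyst.types.{StringType, StructField}

    val field = StructField("name", StringType, nullable = true, metadata = Map("source" -> "json"))

    // The extractor now has four components, so the metadata slot must be
    // matched, here with a wildcard.
    field match {
      case StructField(n, dt, nullable, _) =>
        println(s"$n: ${dt.simpleString} (nullable = $nullable)")
    }
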