Skip to content

Commit 1c057f5

Browse files
aokolnychyi authored and cloud-fan committed
[SPARK-42151][SQL] Align UPDATE assignments with table attributes
### What changes were proposed in this pull request?

This PR adds a rule to align UPDATE assignments with table attributes.

### Why are the changes needed?

These changes are needed so that we can rewrite UPDATE statements into executable plans for tables that support row-level operations. In particular, our row-level mutation framework assumes Spark is responsible for building an updated version of each affected row and that row is passed back to the data source.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR comes with tests.

Closes apache#40308 from aokolnychyi/spark-42151-v2.

Authored-by: aokolnychyi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 66392c4 commit 1c057f5

File tree

10 files changed

+1182
-55
lines changed

10 files changed

+1182
-55
lines changed

core/src/main/resources/error/error-classes.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,11 @@
353353
"The <functionName> does not support ordering on type <dataType>."
354354
]
355355
},
356+
"INVALID_ROW_LEVEL_OPERATION_ASSIGNMENTS" : {
357+
"message" : [
358+
"<errors>"
359+
]
360+
},
356361
"IN_SUBQUERY_DATA_TYPE_MISMATCH" : {
357362
"message" : [
358363
"The data type of one or more elements in the left hand side of an IN subquery is not compatible with the data type of the output of the subquery. Mismatched columns: [<mismatchedColumns>], left side: [<leftType>], right side: [<rightType>]."

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
320320
ResolveRandomSeed ::
321321
ResolveBinaryArithmetic ::
322322
ResolveUnion ::
323+
ResolveRowLevelCommandAssignments ::
323324
RewriteDeleteFromTable ::
324325
typeCoercionRules ++
325326
Seq(
@@ -3329,43 +3330,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
33293330
} else {
33303331
v2Write
33313332
}
3332-
3333-
case u: UpdateTable if !u.skipSchemaResolution && u.resolved =>
3334-
resolveAssignments(u)
3335-
3336-
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
3337-
resolveAssignments(m)
3338-
}
3339-
3340-
private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
3341-
p.transformExpressions {
3342-
case assignment: Assignment =>
3343-
val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
3344-
AssertNotNull(assignment.value)
3345-
} else {
3346-
assignment.value
3347-
}
3348-
val casted = if (assignment.key.dataType != nullHandled.dataType) {
3349-
val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true)
3350-
cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
3351-
cast
3352-
} else {
3353-
nullHandled
3354-
}
3355-
val rawKeyType = assignment.key.transform {
3356-
case a: AttributeReference =>
3357-
CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
3358-
}.dataType
3359-
val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
3360-
CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
3361-
} else {
3362-
casted
3363-
}
3364-
val cleanedKey = assignment.key.transform {
3365-
case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
3366-
}
3367-
Assignment(cleanedKey, finalValue)
3368-
}
33693333
}
33703334
}
33713335

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import scala.collection.mutable
21+
22+
import org.apache.spark.sql.catalyst.SQLConfHelper
23+
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
24+
import org.apache.spark.sql.catalyst.plans.logical.Assignment
25+
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
26+
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
27+
import org.apache.spark.sql.errors.QueryCompilationErrors
28+
import org.apache.spark.sql.types.{DataType, StructType}
29+
30+
object AssignmentUtils extends SQLConfHelper with CastSupport {
31+
32+
/**
 * Produces exactly one assignment per table attribute, in table attribute order.
 * <p>
 * Every attribute is paired with the expression it should be set to. An attribute without
 * a matching assignment keeps its current value, so for attributes c1, c2 and the single
 * assignment c2 = 1 the result is c1 = c1, c2 = 1. This lets Spark build the complete new
 * version of a row.
 * <p>
 * Assignments to nested fields are folded into their top-level column: the struct is rebuilt
 * with the assigned fields replaced and all other fields carried over. For attributes c1, c2
 * where c2 is a struct with fields n1 and n2, the assignment c2.n2 = 1 yields
 * c1 = c1, c2 = struct(c2.n1, 1).
 * <p>
 * Problems found while aligning are collected across all attributes and reported together
 * in a single exception.
 *
 * @param attrs table attributes
 * @param assignments assignments to align
 * @return aligned assignments that match table attributes
 */
def alignAssignments(
    attrs: Seq[Attribute],
    assignments: Seq[Assignment]): Seq[Assignment] = {

  val problems = mutable.ArrayBuffer.empty[String]
  val recordProblem: String => Unit = msg => problems += msg

  // One value expression per attribute; errors are accumulated rather than thrown eagerly
  // so that all of them can be reported at once.
  val alignedValues = attrs.map { attr =>
    applyAssignments(restoreActualType(attr), attr, assignments, recordProblem, Seq(attr.name))
  }

  if (problems.isEmpty) {
    attrs.zip(alignedValues).map(pair => Assignment(pair._1, pair._2))
  } else {
    throw QueryCompilationErrors.invalidRowLevelOperationAssignments(assignments, problems.toSeq)
  }
}
72+
73+
// Returns the attribute typed with the raw type recorded in its metadata (as reported by
// CharVarcharUtils.getRawType), or with its own data type when no raw type is recorded.
private def restoreActualType(attr: Attribute): Attribute = {
  val actualType = CharVarcharUtils.getRawType(attr.metadata) match {
    case Some(rawType) => rawType
    case None => attr.dataType
  }
  attr.withDataType(actualType)
}
76+
77+
// Computes the expression a single column (or nested field) should be set to.
//
// col      - attribute describing the target column/field (name, type, nullability)
// colExpr  - expression that reads the current value of that column/field
// addError - callback used to report problems; a problem never aborts processing here,
//            the current value (colExpr) is returned instead
// colPath  - name parts leading to this column/field, used in error messages
private def applyAssignments(
    col: Attribute,
    colExpr: Expression,
    assignments: Seq[Assignment],
    addError: String => Unit,
    colPath: Seq[String]): Expression = {

  // Assignments whose key is exactly this column/field ...
  val (direct, remaining) = assignments.partition(_.key.semanticEquals(colExpr))
  // ... and assignments whose key contains it, i.e. that target something nested inside it.
  val nested = remaining.filter(a => a.key.exists(_.semanticEquals(colExpr)))

  direct match {
    case Seq() if nested.isEmpty =>
      // Not assigned at all: keep the current value.
      TableOutputResolver.checkNullability(colExpr, col, conf, colPath)

    case Seq() =>
      // Only nested fields are assigned: rebuild the value field by field.
      applyFieldAssignments(col, colExpr, nested, addError, colPath)

    case Seq(only) if nested.isEmpty =>
      TableOutputResolver.resolveUpdate(only.value, col, conf, addError, colPath)

    case Seq(_) =>
      // The column is assigned as a whole and through one of its nested fields.
      val conflictingAssignmentsStr = (direct ++ nested).map(_.sql).mkString(", ")
      addError(s"Conflicting assignments for '${colPath.quoted}': $conflictingAssignmentsStr")
      colExpr

    case _ =>
      // More than one assignment targets exactly this column/field.
      val conflictingValuesStr = direct.map(_.value.sql).mkString(", ")
      addError(s"Multiple assignments for '${colPath.quoted}': $conflictingValuesStr")
      colExpr
  }
}
110+
111+
// Applies assignments that target fields nested inside `col`. Only struct columns are
// supported: every field is resolved recursively through applyAssignments and the struct is
// then reassembled. For any other type an error is reported and the current value is kept.
private def applyFieldAssignments(
    col: Attribute,
    colExpr: Expression,
    assignments: Seq[Assignment],
    addError: String => Unit,
    colPath: Seq[String]): Expression = {

  col.dataType match {
    case struct: StructType =>
      val rebuiltFields = struct.toAttributes.zipWithIndex.map { case (fieldAttr, ordinal) =>
        // Expression reading the field's current value out of the parent struct.
        val currentValue = GetStructField(colExpr, ordinal, Some(struct.fields(ordinal).name))
        applyAssignments(fieldAttr, currentValue, assignments, addError, colPath :+ fieldAttr.name)
      }
      toNamedStruct(struct, rebuiltFields)

    case unsupportedType =>
      addError(
        "Updating nested fields is only supported for StructType but " +
        s"'${colPath.quoted}' is of type $unsupportedType")
      colExpr
  }
}
136+
137+
// Builds a CreateNamedStruct whose children alternate between a field-name literal and the
// expression for that field, following the field order of `structType`.
private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = {
  val fieldsWithValues = structType.fields.zip(fieldExprs)
  val children = fieldsWithValues.flatMap { pair =>
    Seq[Expression](Literal(pair._1.name), pair._2)
  }
  CreateNamedStruct(children)
}
143+
144+
/**
 * Tells whether the given assignments already line up one-to-one with the table columns.
 * <p>
 * That is the case when there are as many assignments as attributes and, position by
 * position, the assignment key is an attribute whose name resolves to the table attribute's
 * name, the value type equals the attribute's type (the raw type from the metadata if one is
 * recorded) ignoring compatible nullability, and a nullable value is never assigned to
 * a non-nullable attribute.
 *
 * @param attrs table attributes
 * @param assignments assignments to check
 * @return true if the assignments are aligned
 */
def aligned(attrs: Seq[Attribute], assignments: Seq[Assignment]): Boolean = {
  def lineUp(attr: Attribute, assignment: Assignment): Boolean = {
    val expectedType = CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)
    val targetsAttr = assignment.key match {
      case key: Attribute => conf.resolver(key.name, attr.name)
      case _ => false
    }
    targetsAttr &&
      DataType.equalsIgnoreCompatibleNullability(assignment.value.dataType, expectedType) &&
      (attr.nullable || !assignment.value.nullable)
  }

  attrs.size == assignments.size &&
    attrs.zip(assignments).forall { case (attr, assignment) => lineUp(attr, assignment) }
}
167+
}
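
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala

Lines changed: 103 additions & 0 deletions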
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
21+
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
22+
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, MergeIntoTable, UpdateTable}
23+
import org.apache.spark.sql.catalyst.rules.Rule
24+
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
25+
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
26+
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
27+
import org.apache.spark.sql.errors.QueryCompilationErrors
28+
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
29+
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
30+
31+
/**
32+
* A rule that resolves assignments in row-level commands.
33+
*
34+
* Note that this rule must be run before rewriting row-level commands into executable plans.
35+
* This rule does not apply to tables that accept any schema. Such tables must inject their own
36+
* rules to resolve assignments.
37+
*/
38+
object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
39+
40+
// Resolves assignments of UPDATE and MERGE commands. Only plans that contain the COMMAND
// tree pattern are visited. The order of the cases matters: the first UPDATE case is more
// specific than the second one and must be tried first.
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
41+
_.containsPattern(COMMAND), ruleId) {
42+
// UPDATE against a table that supports row-level operations: produce one assignment per
// table column via AssignmentUtils.alignAssignments.
case u: UpdateTable if !u.skipSchemaResolution && u.resolved &&
43+
supportsRowLevelOperations(u.table) && !u.aligned =>
44+
validateStoreAssignmentPolicy()
45+
// Run CharVarcharUtils.cleanAttrMetadata over the output attributes of the relation.
val newTable = u.table.transform {
46+
case r: DataSourceV2Relation =>
47+
r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))
48+
}
49+
// Alignment is fed the output of the original table (u.table), not of newTable.
// NOTE(review): presumably because alignAssignments reads the raw type from the attribute
// metadata that newTable has cleaned - confirm.
val newAssignments = AssignmentUtils.alignAssignments(u.table.output, u.assignments)
50+
u.copy(table = newTable, assignments = newAssignments)
51+
52+
// Any other not yet aligned UPDATE: resolve each assignment in place, no alignment.
case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned =>
53+
resolveAssignments(u)
54+
55+
// MERGE: resolve each assignment in place. Unlike the UPDATE cases there is no
// `aligned` guard here.
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
56+
resolveAssignments(m)
57+
}
58+
59+
// Fails analysis when the session is configured with the LEGACY store assignment policy.
private def validateStoreAssignmentPolicy(): Unit = {
  // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2
  conf.storeAssignmentPolicy match {
    case StoreAssignmentPolicy.LEGACY =>
      throw QueryCompilationErrors.legacyStoreAssignmentPolicyError()
    case _ =>
  }
}
65+
66+
// True only when the plan, after EliminateSubqueryAliases has been applied to it, is a
// DataSourceV2Relation whose table implements SupportsRowLevelOperations.
private def supportsRowLevelOperations(table: LogicalPlan): Boolean = {
  val unaliased = EliminateSubqueryAliases(table)
  PartialFunction.cond(unaliased) {
    case DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _) => true
  }
}
72+
73+
// Rewrites every Assignment in the plan so that its value fits its key. The value is wrapped
// in this order: null check, then cast, then char/varchar length check. Finally
// CharVarcharUtils.cleanAttrMetadata is applied to the attribute references in the key.
private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
  p.transformExpressions {
    case assignment: Assignment =>
      val target = assignment.key
      val source = assignment.value

      // A nullable value written to a non-nullable key is asserted to be non-null.
      val nonNullSource =
        if (target.nullable || !source.nullable) source else AssertNotNull(source)

      // Cast only when the types differ; the cast is tagged with Cast.BY_TABLE_INSERTION.
      val typedSource = if (target.dataType == nonNullSource.dataType) {
        nonNullSource
      } else {
        val cast = Cast(nonNullSource, target.dataType, ansiEnabled = true)
        cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
        cast
      }

      // Type of the key once attribute references carry the raw type from their metadata.
      val rawTargetType = target.transform {
        case a: AttributeReference =>
          CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
      }.dataType

      val checkedSource = if (CharVarcharUtils.hasCharVarchar(rawTargetType)) {
        CharVarcharUtils.stringLengthCheck(typedSource, rawTargetType)
      } else {
        typedSource
      }

      val cleanedTarget = target.transform {
        case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
      }
      Assignment(cleanedTarget, checkedSource)
  }
}
103+
}

0 commit comments

Comments
 (0)