Skip to content

Commit 76bf6b2

Browse files
rshkvj-essejdcasale
committed
[SPARK-26626][SQL] Maximum size for repeatedly substituted aliases in SQL expressions
We have internal applications (BS and C) that are prone to OOMs when aliases are substituted repeatedly. See ticket [1] and the upstream PR [2]. [1] https://issues.apache.org/jira/browse/SPARK-26626 [2] apache#23556 Co-authored-by: j-esse <[email protected]> Co-authored-by: Josh Casale <[email protected]> Co-authored-by: Will Raschkowski <[email protected]>
1 parent 82c7955 commit 76bf6b2

File tree

5 files changed

+84
-3
lines changed

5 files changed

+84
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,8 @@ object CollapseProject extends Rule[LogicalPlan] {
709709

710710
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
711711
case p1 @ Project(_, p2: Project) =>
712-
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
712+
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
713+
hasOversizedRepeatedAliases(p1.projectList, p2.projectList)) {
713714
p1
714715
} else {
715716
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
@@ -753,6 +754,28 @@ object CollapseProject extends Rule[LogicalPlan] {
753754
}.exists(!_.deterministic))
754755
}
755756

757+
private def hasOversizedRepeatedAliases(
758+
upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
759+
val aliases = collectAliases(lower)
760+
761+
// Count how many times each alias is used in the upper Project.
762+
// If an alias is only used once, we can safely substitute it without increasing the overall
763+
// tree size
764+
val referenceCounts = AttributeMap(
765+
upper
766+
.flatMap(_.collect { case a: Attribute => a })
767+
.groupBy(identity)
768+
.mapValues(_.size).toSeq
769+
)
770+
771+
// Check for any aliases that are used more than once, and are larger than the configured
772+
// maximum size
773+
aliases.exists({ case (attribute, expression) =>
774+
referenceCounts.getOrElse(attribute, 0) > 1 &&
775+
expression.treeSize > SQLConf.get.maxRepeatedAliasSize
776+
})
777+
}
778+
756779
private def buildCleanedProjectList(
757780
upper: Seq[NamedExpression],
758781
lower: Seq[NamedExpression]): Seq[NamedExpression] = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2424
import org.apache.spark.sql.catalyst.plans._
2525
import org.apache.spark.sql.catalyst.plans.logical._
26+
import org.apache.spark.sql.internal.SQLConf
2627

2728
trait OperationHelper {
2829
type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
@@ -62,6 +63,26 @@ object PhysicalOperation extends OperationHelper with PredicateHelper {
6263
Some((fields.getOrElse(child.output), filters, child))
6364
}
6465

66+
private def hasOversizedRepeatedAliases(fields: Seq[Expression],
67+
aliases: Map[Attribute, Expression]): Boolean = {
68+
// Count how many times each alias is used in the fields.
69+
// If an alias is only used once, we can safely substitute it without increasing the overall
70+
// tree size
71+
val referenceCounts = AttributeMap(
72+
fields
73+
.flatMap(_.collect { case a: Attribute => a })
74+
.groupBy(identity)
75+
.mapValues(_.size).toSeq
76+
)
77+
78+
// Check for any aliases that are used more than once, and are larger than the configured
79+
// maximum size
80+
aliases.exists({ case (attribute, expression) =>
81+
referenceCounts.getOrElse(attribute, 0) > 1 &&
82+
expression.treeSize > SQLConf.get.maxRepeatedAliasSize
83+
})
84+
}
85+
6586
/**
6687
* Collects all deterministic projects and filters, in-lining/substituting aliases if necessary.
6788
* Here are two examples for alias in-lining/substitution.
@@ -81,8 +102,13 @@ object PhysicalOperation extends OperationHelper with PredicateHelper {
81102
plan match {
82103
case Project(fields, child) if fields.forall(_.deterministic) =>
83104
val (_, filters, other, aliases) = collectProjectsAndFilters(child)
84-
val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
85-
(Some(substitutedFields), filters, other, collectAliases(substitutedFields))
105+
if (hasOversizedRepeatedAliases(fields, aliases)) {
106+
// Skip substitution if it could overly increase the overall tree size and risk OOMs
107+
(None, Nil, plan, AttributeMap(Nil))
108+
} else {
109+
val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
110+
(Some(substitutedFields), filters, other, collectAliases(substitutedFields))
111+
}
86112

87113
case Filter(condition, child) if condition.deterministic =>
88114
val (fields, filters, other, aliases) = collectProjectsAndFilters(child)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
120120

121121
lazy val containsChild: Set[TreeNode[_]] = children.toSet
122122

123+
lazy val treeSize: Long = children.map(_.treeSize).sum + 1
124+
123125
// Copied from Scala 2.13.1
124126
// github.com/scala/scala/blob/v2.13.1/src/library/scala/util/hashing/MurmurHash3.scala#L56-L73
125127
// to prevent the issue https://github.com/scala/bug/issues/10495

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2373,6 +2373,15 @@ object SQLConf {
23732373
.booleanConf
23742374
.createWithDefault(false)
23752375

2376+
val MAX_REPEATED_ALIAS_SIZE =
2377+
buildConf("spark.sql.maxRepeatedAliasSize")
2378+
.internal()
2379+
.doc("The maximum size of alias expression that will be substituted multiple times " +
2380+
"(size defined by the number of nodes in the expression tree). " +
2381+
"Used by the CollapseProject optimizer, and PhysicalOperation.")
2382+
.intConf
2383+
.createWithDefault(100)
2384+
23762385
val SOURCES_BINARY_FILE_MAX_LENGTH = buildConf("spark.sql.sources.binaryFile.maxLength")
23772386
.doc("The max length of a file that can be read by the binary file data source. " +
23782387
"Spark will fail fast and not attempt to read the file if its length exceeds this value. " +
@@ -3202,6 +3211,8 @@ class SQLConf extends Serializable with Logging {
32023211
def setCommandRejectsSparkCoreConfs: Boolean =
32033212
getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS)
32043213

3214+
def maxRepeatedAliasSize: Int = getConf(SQLConf.MAX_REPEATED_ALIAS_SIZE)
3215+
32053216
def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING)
32063217

32073218
def ignoreDataLocality: Boolean = getConf(SQLConf.IGNORE_DATA_LOCALITY)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,4 +170,23 @@ class CollapseProjectSuite extends PlanTest {
170170
val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze
171171
comparePlans(optimized, expected)
172172
}
173+
174+
175+
test("ensure oversize aliases are not repeatedly substituted") {
176+
var query: LogicalPlan = testRelation
177+
for( a <- 1 to 100) {
178+
query = query.select(('a + 'b).as('a), ('a - 'b).as('b))
179+
}
180+
val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
181+
assert(projects.size >= 12)
182+
}
183+
184+
test("ensure oversize aliases are still substituted once") {
185+
var query: LogicalPlan = testRelation
186+
for( a <- 1 to 20) {
187+
query = query.select(('a + 'b).as('a), 'b)
188+
}
189+
val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
190+
assert(projects.size === 1)
191+
}
173192
}

0 commit comments

Comments
 (0)