Skip to content

Commit 5683984

Browse files
kiszk authored and cloud-fan committed
[SPARK-18016][SQL][FOLLOW-UP] Code Generation: Constant Pool Limit - reduce entries for mutable state
## What changes were proposed in this pull request?

This PR addresses additional review comments in apache#19811.

## How was this patch tested?

Existing test suites.

Author: Kazuaki Ishizaki <[email protected]>

Closes apache#20036 from kiszk/SPARK-18066-followup.
1 parent 753793b commit 5683984

File tree

8 files changed

+51
-48
lines changed

8 files changed

+51
-48
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class CodegenContext {
190190

191191
/**
192192
* Returns the reference of next available slot in current compacted array. The size of each
193-
* compacted array is controlled by the config `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
193+
* compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
194194
* Once reaching the threshold, new compacted array is created.
195195
*/
196196
def getNextSlot(): String = {
@@ -352,7 +352,7 @@ class CodegenContext {
352352
def initMutableStates(): String = {
353353
// It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
354354
// `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
355-
val initCodes = mutableStateInitCode.distinct
355+
val initCodes = mutableStateInitCode.distinct.map(_ + "\n")
356356

357357
// The generated initialization code may exceed 64kb function size limit in JVM if there are too
358358
// many mutable states, so split it into multiple functions.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,8 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
118118
if (rVal != null) {
119119
val regexStr =
120120
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
121-
// inline mutable state since not many Like operations in a task
122121
val pattern = ctx.addMutableState(patternClass, "patternLike",
123-
v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true)
122+
v => s"""$v = $patternClass.compile("$regexStr");""")
124123

125124
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
126125
val eval = left.genCode(ctx)
@@ -143,9 +142,9 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
143142
val rightStr = ctx.freshName("rightStr")
144143
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
145144
s"""
146-
String $rightStr = ${eval2}.toString();
147-
${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr));
148-
${ev.value} = $pattern.matcher(${eval1}.toString()).matches();
145+
String $rightStr = $eval2.toString();
146+
$patternClass $pattern = $patternClass.compile($escapeFunc($rightStr));
147+
${ev.value} = $pattern.matcher($eval1.toString()).matches();
149148
"""
150149
})
151150
}
@@ -194,9 +193,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
194193
if (rVal != null) {
195194
val regexStr =
196195
StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString())
197-
// inline mutable state since not many RLike operations in a task
198196
val pattern = ctx.addMutableState(patternClass, "patternRLike",
199-
v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true)
197+
v => s"""$v = $patternClass.compile("$regexStr");""")
200198

201199
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
202200
val eval = left.genCode(ctx)
@@ -219,9 +217,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
219217
val pattern = ctx.freshName("pattern")
220218
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
221219
s"""
222-
String $rightStr = ${eval2}.toString();
223-
${patternClass} $pattern = ${patternClass}.compile($rightStr);
224-
${ev.value} = $pattern.matcher(${eval1}.toString()).find(0);
220+
String $rightStr = $eval2.toString();
221+
$patternClass $pattern = $patternClass.compile($rightStr);
222+
${ev.value} = $pattern.matcher($eval1.toString()).find(0);
225223
"""
226224
})
227225
}
@@ -338,25 +336,25 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
338336

339337
nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => {
340338
s"""
341-
if (!$regexp.equals(${termLastRegex})) {
339+
if (!$regexp.equals($termLastRegex)) {
342340
// regex value changed
343-
${termLastRegex} = $regexp.clone();
344-
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
341+
$termLastRegex = $regexp.clone();
342+
$termPattern = $classNamePattern.compile($termLastRegex.toString());
345343
}
346-
if (!$rep.equals(${termLastReplacementInUTF8})) {
344+
if (!$rep.equals($termLastReplacementInUTF8)) {
347345
// replacement string changed
348-
${termLastReplacementInUTF8} = $rep.clone();
349-
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
346+
$termLastReplacementInUTF8 = $rep.clone();
347+
$termLastReplacement = $termLastReplacementInUTF8.toString();
350348
}
351-
$classNameStringBuffer ${termResult} = new $classNameStringBuffer();
352-
java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());
349+
$classNameStringBuffer $termResult = new $classNameStringBuffer();
350+
java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString());
353351

354-
while (${matcher}.find()) {
355-
${matcher}.appendReplacement(${termResult}, ${termLastReplacement});
352+
while ($matcher.find()) {
353+
$matcher.appendReplacement($termResult, $termLastReplacement);
356354
}
357-
${matcher}.appendTail(${termResult});
358-
${ev.value} = UTF8String.fromString(${termResult}.toString());
359-
${termResult} = null;
355+
$matcher.appendTail($termResult);
356+
${ev.value} = UTF8String.fromString($termResult.toString());
357+
$termResult = null;
360358
$setEvNotNull
361359
"""
362360
})
@@ -425,19 +423,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
425423

426424
nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
427425
s"""
428-
if (!$regexp.equals(${termLastRegex})) {
426+
if (!$regexp.equals($termLastRegex)) {
429427
// regex value changed
430-
${termLastRegex} = $regexp.clone();
431-
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
428+
$termLastRegex = $regexp.clone();
429+
$termPattern = $classNamePattern.compile($termLastRegex.toString());
432430
}
433-
java.util.regex.Matcher ${matcher} =
434-
${termPattern}.matcher($subject.toString());
435-
if (${matcher}.find()) {
436-
java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult();
437-
if (${matchResult}.group($idx) == null) {
431+
java.util.regex.Matcher $matcher =
432+
$termPattern.matcher($subject.toString());
433+
if ($matcher.find()) {
434+
java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
435+
if ($matchResult.group($idx) == null) {
438436
${ev.value} = UTF8String.EMPTY_UTF8;
439437
} else {
440-
${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
438+
${ev.value} = UTF8String.fromString($matchResult.group($idx));
441439
}
442440
$setEvNotNull
443441
} else {

sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ case class SortExec(
138138
// Initialize the class member variables. This includes the instance of the Sorter and
139139
// the iterator to return sorted rows.
140140
val thisPlan = ctx.addReferenceObj("plan", this)
141-
// inline mutable state since not many Sort operations in a task
141+
// Inline mutable state since not many Sort operations in a task
142142
sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter",
143143
v => s"$v = $thisPlan.createSorter();", forceInline = true)
144144
val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics",

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
283283

284284
override def doProduce(ctx: CodegenContext): String = {
285285
// Right now, InputAdapter is only used when there is one input RDD.
286-
// inline mutable state since an inputAdaptor in a task
286+
// Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen
287287
val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
288288
forceInline = true)
289289
val row = ctx.freshName("row")

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -587,31 +587,35 @@ case class HashAggregateExec(
587587
fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
588588
ctx.addInnerClass(generatedMap)
589589

590+
// Inline mutable state since not many aggregation operations in a task
590591
fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedHastHashMap",
591-
v => s"$v = new $fastHashMapClassName();")
592-
ctx.addMutableState(s"java.util.Iterator<InternalRow>", "vectorizedFastHashMapIter")
592+
v => s"$v = new $fastHashMapClassName();", forceInline = true)
593+
ctx.addMutableState(s"java.util.Iterator<InternalRow>", "vectorizedFastHashMapIter",
594+
forceInline = true)
593595
} else {
594596
val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
595597
fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
596598
ctx.addInnerClass(generatedMap)
597599

600+
// Inline mutable state since not many aggregation operations in a task
598601
fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap",
599602
v => s"$v = new $fastHashMapClassName(" +
600-
s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());")
603+
s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());",
604+
forceInline = true)
601605
ctx.addMutableState(
602606
"org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
603-
"fastHashMapIter")
607+
"fastHashMapIter", forceInline = true)
604608
}
605609
}
606610

607611
// Create a name for the iterator from the regular hash map.
608-
// inline mutable state since not many aggregation operations in a task
612+
// Inline mutable state since not many aggregation operations in a task
609613
val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
610614
"mapIter", forceInline = true)
611615
// create hashMap
612616
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
613617
hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap",
614-
v => s"$v = $thisPlan.createHashMap();")
618+
v => s"$v = $thisPlan.createHashMap();", forceInline = true)
615619
sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter",
616620
forceInline = true)
617621

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ case class SampleExec(
284284
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
285285
val initSampler = ctx.freshName("initSampler")
286286

287-
// inline mutable state since not many Sample operations in a task
287+
// Inline mutable state since not many Sample operations in a task
288288
val sampler = ctx.addMutableState(s"$samplerClass<UnsafeRow>", "sampleReplace",
289289
v => {
290290
val initSamplerFuncName = ctx.addNewFunction(initSampler,
@@ -371,7 +371,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
371371
val ev = ExprCode("", "false", value)
372372
val BigInt = classOf[java.math.BigInteger].getName
373373

374-
// inline mutable state since not many Range operations in a task
374+
// Inline mutable state since not many Range operations in a task
375375
val taskContext = ctx.addMutableState("TaskContext", "taskContext",
376376
v => s"$v = TaskContext.get();", forceInline = true)
377377
val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ case class BroadcastHashJoinExec(
139139
// At the end of the task, we update the avg hash probe.
140140
val avgHashProbe = metricTerm(ctx, "avgHashProbe")
141141

142-
// inline mutable state since not many join operations in a task
142+
// Inline mutable state since not many join operations in a task
143143
val relationTerm = ctx.addMutableState(clsName, "relation",
144144
v => s"""
145145
| $v = (($clsName) $broadcast.value()).asReadOnlyCopy();

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ case class SortMergeJoinExec(
422422
*/
423423
private def genScanner(ctx: CodegenContext): (String, String) = {
424424
// Create class member for next row from both sides.
425-
// inline mutable state since not many join operations in a task
425+
// Inline mutable state since not many join operations in a task
426426
val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true)
427427
val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true)
428428

@@ -440,8 +440,9 @@ case class SortMergeJoinExec(
440440
val spillThreshold = getSpillThreshold
441441
val inMemoryThreshold = getInMemoryThreshold
442442

443+
// Inline mutable state since not many join operations in a task
443444
val matches = ctx.addMutableState(clsName, "matches",
444-
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);")
445+
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
445446
// Copy the left keys as class members so they could be used in next function call.
446447
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
447448

@@ -576,7 +577,7 @@ case class SortMergeJoinExec(
576577
override def needCopyResult: Boolean = true
577578

578579
override def doProduce(ctx: CodegenContext): String = {
579-
// inline mutable state since not many join operations in a task
580+
// Inline mutable state since not many join operations in a task
580581
val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
581582
v => s"$v = inputs[0];", forceInline = true)
582583
val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",

0 commit comments

Comments
 (0)