diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 743782a6453e9..4568714933095 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,8 +119,7 @@ abstract class Expression extends TreeNode[Expression] { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { - val globalIsNull = ctx.freshName("globalIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) + val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull eval.isNull = globalIsNull s"$globalIsNull = $localIsNull;" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 821d784a01342..784eaf8195194 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -65,10 +65,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val countTerm = ctx.freshName("count") - val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm) - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm) + val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count") + val partitionMaskTerm = ctx.addMutableState(ctx.JAVA_LONG, "partitionMask") ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 4fa18d6b3209b..736ca37c6d54a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -43,8 +43,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override protected def evalInternal(input: InternalRow): Int = partitionId override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm) + val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId") ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1893eec22b65d..d3a8cb5804717 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,8 +602,7 @@ case 
class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.freshName("leastTmpIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull") val evals = evalChildren.map(eval => s""" |${eval.code} @@ -683,8 +682,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.freshName("greatestTmpIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull") val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 3a03a65e1af92..41a920ba3d677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -137,22 +137,63 @@ class CodegenContext { var currentVars: Seq[ExprCode] = null /** - * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a - * 3-tuple: java type, variable name, code to init it. - * As an example, ("int", "count", "count = 0;") will produce code: + * Holds expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as + * 2-tuples: (java type, variable name). + * As an example, ("int", "count") will produce code: * {{{ * private int count; * }}} - * as a member variable, and add - * {{{ - * count = 0; - * }}} - * to the constructor. + * as a member variable. * * They will be kept as member variables in generated classes like `SpecificProjection`. */ - val mutableStates: mutable.ArrayBuffer[(String, String, String)] = - mutable.ArrayBuffer.empty[(String, String, String)] + val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] = + mutable.ArrayBuffer.empty[(String, String)] + + /** + * The mapping between mutable state types and the corresponding compacted arrays. + * The keys are java type strings. The values are [[MutableStateArrays]] instances, which + * encapsulate the compacted arrays for the mutable states with the same java type. + */ + val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] = + mutable.Map.empty[String, MutableStateArrays] + + // An array holding the code that will initialize each state + val mutableStateInitCode: mutable.ArrayBuffer[String] = + mutable.ArrayBuffer.empty[String] + + /** + * This class holds the names of the arrays used to compact the mutable states of a certain + * java type, and tracks the next available slot in the current compacted array. + */ + class MutableStateArrays { + val arrayNames = mutable.ListBuffer.empty[String] + createNewArray() + + private[this] var currentIndex = 0 + + private def createNewArray() = arrayNames.append(freshName("mutableStateArray")) + + def getCurrentIndex: Int = currentIndex + + /** + * Returns a reference to the next available slot in the current compacted array. The size of + * each compacted array is controlled by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * Once that threshold is reached, a new compacted array is created.
+ */ + def getNextSlot(): String = { + if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) { + val res = s"${arrayNames.last}[$currentIndex]" + currentIndex += 1 + res + } else { + createNewArray() + currentIndex = 1 + s"${arrayNames.last}[0]" + } + } + + } /** * Add a mutable state as a field to the generated class. c.f. the comments above. @@ -163,11 +204,52 @@ class CodegenContext { * the list of default imports available. * Also, generic type arguments are accepted but ignored. * @param variableName Name of the field. - * @param initCode The statement(s) to put into the init() method to initialize this field. + * @param initFunc Function that produces the statement(s) to put into the init() method to + * initialize this field. The argument is the resolved name of the mutable state + * variable. * If it returns an empty string, the field will be default-initialized. + * @param forceInline whether to force the declaration and initialization code to be inlined + * rather than compacted. Set forceInline to `true` when one of the following holds: + * 1. the original name of the state must be preserved + * 2. the state is generated infrequently + * (e.g. there are not many sort operators in one stage) + * @param useFreshName If this is false and the mutable state ends up being inlined in the + * outer class, the name is not changed + * @return the name of the mutable state variable, which is the original name or a fresh name + * if the variable is inlined to the outer class, or an array access if the variable is + * to be stored in an array of variables of the same type. + * A variable will be inlined into the outer class when one of the following conditions + * is satisfied: + * 1. forceInline is true + * 2. its type is a primitive type and the total number of inlined mutable variables + * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * 3. its type is a multi-dimensional array + * When a variable is compacted into an array, the maximum size of each array is given + * by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. */ - def addMutableState(javaType: String, variableName: String, initCode: String = ""): Unit = { - mutableStates += ((javaType, variableName, initCode)) + def addMutableState( + javaType: String, + variableName: String, + initFunc: String => String = _ => "", + forceInline: Boolean = false, + useFreshName: Boolean = true): String = { + + // Keep primitive type variables inlined in the outer class for performance + val canInlinePrimitive = isPrimitiveType(javaType) && + (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { + val varName = if (useFreshName) freshName(variableName) else variableName + val initCode = initFunc(varName) + inlinedMutableStates += ((javaType, varName)) + mutableStateInitCode += initCode + varName + } else { + val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays) + val element = arrays.getNextSlot() + + val initCode = initFunc(element) + mutableStateInitCode += initCode + element + } }
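For intuition, here is a hedged sketch of the two shapes a caller can get back from `addMutableState` (names are illustrative; actual suffixes come from `freshName`):

{{{
// Inlined: forceInline was set, the type is primitive and under the
// threshold, or the type is a multi-dimensional array.
val flag = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")   // e.g. "initAgg_0"
// Compacted: everything else becomes a slot in a per-type array.
val row = ctx.addMutableState("UnsafeRow", "result")          // e.g. "mutableStateArray_0[3]"
// Either way, the returned string is spliced verbatim into generated code,
// so callers must use the return value rather than the name they passed in.
}}}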
/** @@ -176,8 +258,7 @@ class CodegenContext { * data types like: UTF8String, ArrayData, MapData & InternalRow. */ def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { - val value = freshName(variableName) - addMutableState(javaType(dataType), value, "") + val value = addMutableState(javaType(dataType), variableName) val code = dataType match { case StringType => s"$value = $initCode.clone();" case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" @@ -189,15 +270,37 @@ def declareMutableStates(): String = { // It's possible that we add the same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, so we should call `distinct` here to remove the duplicated ones. - mutableStates.distinct.map { case (javaType, variableName, _) => + val inlinedStates = inlinedMutableStates.distinct.map { case (javaType, variableName) => s"private $javaType $variableName;" - }.mkString("\n") + } + + val arrayStates = arrayCompactedMutableStates.flatMap { case (javaType, mutableStateArrays) => + val numArrays = mutableStateArrays.arrayNames.size + mutableStateArrays.arrayNames.zipWithIndex.map { case (arrayName, index) => + val length = if (index + 1 == numArrays) { + mutableStateArrays.getCurrentIndex + } else { + CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + } + if (javaType.contains("[]")) { + // initializer had a one-dimensional array variable + val baseType = javaType.substring(0, javaType.length - 2) + s"private $javaType[] $arrayName = new $baseType[$length][];" + } else { + // initializer had a scalar variable + s"private $javaType[] $arrayName = new $javaType[$length];" + } + } + } + + (inlinedStates ++ arrayStates).mkString("\n") } def initMutableStates(): String = { // It's possible that we add the same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, so we should call `distinct` here to remove the duplicated ones. - val initCodes = mutableStates.distinct.map(_._3 + "\n") + val initCodes = mutableStateInitCode.distinct + + // The generated initialization code may exceed the 64KB function size limit in the JVM if there // are too many mutable states, so split it into multiple functions. splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil) @@ -1011,9 +1114,9 @@ class CodegenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach { e => val expr = e.head - val fnName = freshName("evalExpr") - val isNull = s"${fnName}IsNull" - val value = s"${fnName}Value" + val fnName = freshName("subExpr") + val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + val value = addMutableState(javaType(expr.dataType), "subExprValue") // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) @@ -1039,9 +1142,6 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") - addMutableState(javaType(expr.dataType), value, - s"$value = ${defaultValue(expr.dataType)};") subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) @@ -1165,6 +1265,15 @@ object CodeGenerator extends Logging { // class. val GENERATED_CLASS_SIZE_THRESHOLD = 1000000
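Mirroring the SPARK-18016 test added later in this patch, a sketch of the declarations `declareMutableStates()` would emit for `MUTABLESTATEARRAY_SIZE_LIMIT + 10` compacted `InternalRow[]` states (illustrative names):

{{{
private InternalRow[][] mutableStateArray_0 = new InternalRow[32768][];
private InternalRow[][] mutableStateArray_1 = new InternalRow[10][];  // last array sized by getCurrentIndex
}}}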
+ // This is the threshold for the number of global variables whose types are primitive or + // complex (e.g. multi-dimensional arrays) that will be placed in the outer class + val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 + + // This is the maximum number of array elements used to keep global variables in one Java array. + // With 32768 elements, the largest index is 32767, the maximum integer value that does not + // require a constant pool entry in a Java bytecode instruction. + val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + /** * Compile the Java source code into a Java class, using Janino. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index bd8312eb8b7fe..b53c0087e7e2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -57,41 +57,37 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP case _ => true }.unzip val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) - val projectionCodes = exprVals.zip(index).map { + + // 4-tuples: (code for projection, isNull variable name, value variable name, column index) + val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) + val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") if (e.nullable) { - val isNull = s"isNull_$i" - val value = s"value_$i" - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull, s"$isNull = true;") - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"$value = ${ctx.defaultValue(e.dataType)};") - s""" - ${ev.code} - $isNull = ${ev.isNull}; - $value = ${ev.value}; - """ + val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") + (s""" + |${ev.code} + |$isNull = ${ev.isNull}; + |$value = ${ev.value}; + """.stripMargin, isNull, value, i) } else { - val value = s"value_$i" - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"$value = ${ctx.defaultValue(e.dataType)};") - s""" - ${ev.code} - $value = ${ev.value}; - """ + (s""" + |${ev.code} + |$value = ${ev.value}; + """.stripMargin, ev.isNull, value, i) } }
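A note on the hunk above: the projection can no longer reconstruct names such as `isNull_$i` by convention, so the references returned by `addMutableState` (possibly array slots) travel in the 4-tuples. A minimal sketch of the before/after contract:

{{{
// Before: downstream code rebuilt the names by convention.
val ev = ExprCode("", s"isNull_$i", s"value_$i")
// After: the references returned by addMutableState are threaded through.
val (_, isNull, value, i) = projectionCode
val ev = ExprCode("", isNull, value)
}}}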
// Evaluate all the subexpressions. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val updates = validExpr.zip(index).map { - case (e, i) => - val ev = ExprCode("", s"isNull_$i", s"value_$i") + val updates = validExpr.zip(projectionCodes).map { + case (e, (_, isNull, value, i)) => + val ev = ExprCode("", isNull, value) ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes) + val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates) val codeBody = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index b022457865d50..36ffa8dcdd2b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -73,9 +73,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro bufferHolder: String, isTopLevel: Boolean = false): String = { val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.freshName("rowWriter") - ctx.addMutableState(rowWriterClass, rowWriter, - s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder,
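The `v => ...` pattern used in these hunks exists because the final reference is only known inside `addMutableState`; initialization code is therefore passed as a `String => String` that receives the resolved reference, whether that is a field name or an array slot. A sketch:

{{{
// The callback must build the init statement from its argument, never from
// the requested name, since "rowWriter" may resolve to "mutableStateArray_0[2]".
val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
  v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});")
}}}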
@@ -186,9 +185,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid re-evaluating it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.freshName("arrayWriter") - ctx.addMutableState(arrayWriterClass, arrayWriter, - s"$arrayWriter = new $arrayWriterClass();") + val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -318,13 +316,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.freshName("result") - ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") + val result = ctx.addMutableState("UnsafeRow", "result", + v => s"$v = new UnsafeRow(${expressions.length});") - val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName - ctx.addMutableState(holderClass, holder, - s"$holder = new $holderClass($result, ${numVarLenFields * 32});") + val holder = ctx.addMutableState(holderClass, "holder", + v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 53c3b226895ec..1a9b68222a7f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -190,8 +190,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // we won't continue the computation.
val resultState = ctx.freshName("caseWhenResultState") - val tmpResult = ctx.freshName("caseWhenTmpResult") - ctx.addMutableState(ctx.javaType(dataType), tmpResult) + val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult") // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 44d54a20844a3..cfec7f82951a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -442,9 +442,9 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName - val c = ctx.freshName("cal") val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(cal, c, s"""$c = $cal.getInstance($dtu.getTimeZone("UTC"));""") + val c = ctx.addMutableState(cal, "cal", + v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); ${ev.value} = $c.get($cal.DAY_OF_WEEK); @@ -484,18 +484,17 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName - val c = ctx.freshName("cal") val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(cal, c, + val c = ctx.addMutableState(cal, "cal", v => s""" - $c = $cal.getInstance($dtu.getTimeZone("UTC")); - $c.setFirstDayOfWeek($cal.MONDAY); - $c.setMinimalDaysInFirstWeek(4); - """) + |$v = $cal.getInstance($dtu.getTimeZone("UTC")); + |$v.setFirstDayOfWeek($cal.MONDAY); + |$v.setMinimalDaysInFirstWeek(4); + """.stripMargin) s""" - $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.WEEK_OF_YEAR); - """ + |$c.setTimeInMillis($time * 1000L * 3600L * 24L); + |${ev.value} = $c.get($cal.WEEK_OF_YEAR); + """.stripMargin }) } } @@ -1014,12 +1013,12 @@ case class FromUTCTimestamp(left: Expression, right: Expression) |long ${ev.value} = 0; """.stripMargin) } else { - val tzTerm = ctx.freshName("tz") - val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") + val tzTerm = ctx.addMutableState(tzClass, "tz", + v => s"""$v = $dtu.getTimeZone("$tz");""") + val utcTerm = ctx.addMutableState(tzClass, "utc", + v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} @@ -1190,12 +1189,12 @@ case class ToUTCTimestamp(left: Expression, right: Expression) |long ${ev.value} = 0; """.stripMargin) } else { - val tzTerm = ctx.freshName("tz") - val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") + val tzTerm = ctx.addMutableState(tzClass, "tz", + v => 
s"""$v = $dtu.getTimeZone("$tz");""") + val utcTerm = ctx.addMutableState(tzClass, "utc", + v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index cd38783a731ad..1cd73a92a8635 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -199,8 +199,8 @@ case class Stack(children: Seq[Expression]) extends Generator { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Rows - we write these into an array. - val rowData = ctx.freshName("rows") - ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") + val rowData = ctx.addMutableState("InternalRow[]", "rows", + v => s"$v = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { row => @@ -217,7 +217,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState( s"$wrapperClass", ev.value, - s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);") + v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);", useFreshName = false) ev.copy(code = code, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 294cdcb2e9546..b4f895fffda38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tmpIsNull = ctx.freshName("coalesceTmpIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull") // all the evals are meant to be in a do { ... 
} while (false); loop val evals = children.map { e => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4bd395eadcf19..a59aad5be8715 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -62,15 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression { def prepareArguments(ctx: CodegenContext): (String, String, String) = { val resultIsNull = if (needNullCheck) { - val resultIsNull = ctx.freshName("resultIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, resultIsNull) + val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") resultIsNull } else { "false" } val argValues = arguments.map { e => - val argValue = ctx.freshName("argValue") - ctx.addMutableState(ctx.javaType(e.dataType), argValue) + val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") argValue } @@ -548,7 +546,7 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState(elementJavaType, loopValue) + ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -644,7 +642,7 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull) + ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -808,10 +806,11 @@ case class CatalystToExternalMap private( val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] val keyElementJavaType = ctx.javaType(mapType.keyType) - ctx.addMutableState(keyElementJavaType, keyLoopValue) + ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) val valueElementJavaType = ctx.javaType(mapType.valueType) - ctx.addMutableState(valueElementJavaType, valueLoopValue) + ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, + useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -844,7 +843,8 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull) + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, + useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { "" @@ -994,8 +994,8 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) - ctx.addMutableState(keyElementJavaType, key) - ctx.addMutableState(valueElementJavaType, value) + ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) + ctx.addMutableState(valueElementJavaType, 
value, forceInline = true, useFreshName = false) val (defineEntries, defineKeyValue) = child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => @@ -1031,14 +1031,14 @@ case class ExternalMapToCatalyst private( } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull) + ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull) + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1148,7 +1148,6 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. - val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { if (kryo) { (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) @@ -1159,14 +1158,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializerInit = s""" - if ($env == null) { - $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - } else { - $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - } - """ - ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) + val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForEncode", v => + s""" + |if ($env == null) { + | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); + |} else { + | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); + |} + """.stripMargin) // Code to serialize. val input = child.genCode(ctx) @@ -1194,7 +1193,6 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. - val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { if (kryo) { (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) @@ -1205,14 +1203,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializerInit = s""" - if ($env == null) { - $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - } else { - $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - } - """ - ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) + val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode", v => + s""" + |if ($env == null) { + | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); + |} else { + | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); + |} + """.stripMargin) // Code to deserialize. 
val input = child.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index b4aefe6cff73e..8bc936fcbfc31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -77,9 +77,8 @@ case class Rand(child: Expression) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm) + val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" @@ -112,9 +111,8 @@ case class Randn(child: Expression) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm) + val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 53d7096dd87d3..fa5425c77ebba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -112,15 +112,15 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" - val pattern = ctx.freshName("pattern") if (right.foldable) { val rVal = right.eval() if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") + // inline mutable state since not many Like operations in a task + val pattern = ctx.addMutableState(patternClass, "patternLike", + v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
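The rule of thumb behind `forceInline = true` in the Like/RLike hunks, per the new parameter doc (a sketch, not part of the patch):

{{{
// Per-operator state that appears at most a handful of times per task can
// stay a plain field:
val pattern = ctx.addMutableState(patternClass, "patternLike",
  v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true)
// Per-column or per-expression state that may number in the thousands is
// left to be compacted:
val value = ctx.addMutableState(ctx.javaType(e.dataType), "value")
}}}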
val eval = left.genCode(ctx) @@ -139,6 +139,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi """) } } else { + val pattern = ctx.freshName("pattern") val rightStr = ctx.freshName("rightStr") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" @@ -187,15 +188,15 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName - val pattern = ctx.freshName("pattern") if (right.foldable) { val rVal = right.eval() if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") + // inline mutable state since not many RLike operations in a task + val pattern = ctx.addMutableState(patternClass, "patternRLike", + v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -215,6 +216,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } } else { val rightStr = ctx.freshName("rightStr") + val pattern = ctx.freshName("pattern") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" String $rightStr = ${eval2}.toString(); @@ -316,11 +318,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def prettyName: String = "regexp_replace" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - - val termLastReplacement = ctx.freshName("lastReplacement") - val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") val termResult = ctx.freshName("termResult") val classNamePattern = classOf[Pattern].getCanonicalName @@ -328,11 +325,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val matcher = ctx.freshName("matcher") - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState("UTF8String", - termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") + val termPattern = ctx.addMutableState(classNamePattern, "pattern") + val termLastReplacement = ctx.addMutableState("String", "lastReplacement") + val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" @@ -414,14 +410,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") val classNamePattern = classOf[Pattern].getCanonicalName val matcher = ctx.freshName("matcher") val matchResult = ctx.freshName("matchResult") - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") + val 
termPattern = ctx.addMutableState(classNamePattern, "pattern") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8c4d2fd686be5..c02c41db1668e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -291,8 +291,7 @@ case class Elt(children: Seq[Expression]) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.freshName("stringVal") - ctx.addMutableState(ctx.javaType(dataType), stringVal) + val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") val assignStringValue = strings.zipWithIndex.map { case (eval, index) => s""" @@ -532,14 +531,11 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termLastMatching = ctx.freshName("lastMatching") - val termLastReplace = ctx.freshName("lastReplace") - val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") - ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") + val termLastMatching = ctx.addMutableState("UTF8String", "lastMatching") + val termLastReplace = ctx.addMutableState("UTF8String", "lastReplace") + val termDict = ctx.addMutableState(classNameDict, "dict") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { @@ -2065,15 +2061,12 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. 
val usLocale = "US" - val lastDValue = ctx.freshName("lastDValue") - val pattern = ctx.freshName("pattern") - val numberFormat = ctx.freshName("numberFormat") val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - ctx.addMutableState(ctx.JAVA_INT, lastDValue, s"$lastDValue = -100;") - ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") - ctx.addMutableState(df, numberFormat, - s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""") + val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;") + val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") + val numberFormat = ctx.addMutableState(df, "numberFormat", + v => s"""$v = new $df("", new $dfs($l.$usLocale));""") s""" if ($d >= 0) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index be638d80e45d8..6edb4348f8309 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -348,10 +348,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("SPARK-22704: Least and greatest use less global variables") { val ctx1 = new CodegenContext() Least(Seq(Literal(1), Literal(1))).genCode(ctx1) - assert(ctx1.mutableStates.size == 1) + assert(ctx1.inlinedMutableStates.size == 1) val ctx2 = new CodegenContext() Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2) - assert(ctx2.mutableStates.size == 1) + assert(ctx2.inlinedMutableStates.size == 1) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 65617be05a434..1dd040e4696a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -851,6 +851,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext cast("1", IntegerType).genCode(ctx) cast("2", LongType).genCode(ctx) - assert(ctx.mutableStates.length == 0) + assert(ctx.inlinedMutableStates.length == 0) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index a969811019161..b1a44528e64d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -385,20 +385,43 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext val schema = new StructType().add("a", IntegerType).add("b", StringType) CreateExternalRow(Seq(Literal(1), Literal("x")), schema).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } test("SPARK-22696: InitializeJavaBean should not use global variables") { val ctx = new CodegenContext InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), Map("add" -> Literal(1))).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } test("SPARK-22716: 
addReferenceObj should not add mutable states") { val ctx = new CodegenContext val foo = new Object() ctx.addReferenceObj("foo", foo) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) + } + + test("SPARK-18016: define mutable states by using an array") { + val ctx1 = new CodegenContext + for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { + ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") + } + assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + // When the number of primitive type mutable states is over the threshold, others are + // allocated into an array + assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1) + assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) + + val ctx2 = new CodegenContext + for (i <- 1 to CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) { + ctx2.addMutableState("InternalRow[]", "r", v => s"$v = new InternalRow[$i];") + } + // When the number of non-primitive type mutable states is over the threshold, others are + // allocated into a new array + assert(ctx2.inlinedMutableStates.isEmpty) + assert(ctx2.arrayCompactedMutableStates.get("InternalRow[]").get.arrayNames.size == 2) + assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex == 10) + assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 6dfca7d73a3df..84190f0bd5f7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -304,6 +304,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22693: CreateNamedStruct should not use global variables") { val ctx = new CodegenContext CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 60d84aae1fa3f..a099119732e25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -150,6 +150,6 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("SPARK-22705: case when should use less global variables") { val ctx = new CodegenContext() CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx) - assert(ctx.mutableStates.size == 1) + assert(ctx.inlinedMutableStates.size == 1) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index a23cd95632770..cc6c15cb2c909 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ 
-159,7 +159,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22705: Coalesce should use less global variables") { val ctx = new CodegenContext() Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx) - assert(ctx.mutableStates.size == 1) + assert(ctx.inlinedMutableStates.size == 1) } test("AtLeastNNonNulls should not throw 64kb exception") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 15cb0bea08f17..8a8f8e10225fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -249,7 +249,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22705: In should use less global variables") { val ctx = new CodegenContext() In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } test("INSET") { @@ -440,6 +440,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22693: InSet should not use global variables") { val ctx = new CodegenContext InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 4fa61fbaf66c2..2a0a42c65b086 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -183,8 +183,9 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx) // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8) - // are always required - assert(ctx.mutableStates.length == 4) + // are always required, which are allocated in type-based global array + assert(ctx.inlinedMutableStates.length == 0) + assert(ctx.mutableStateInitCode.length == 4) } test("RegexExtract") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 70dea4b39d55d..10e3ffd0dff97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -51,6 +51,6 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22695: ScalaUDF should not use global variables") { val ctx = new CodegenContext ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 6031bdf19e957..2c45b3b0c73d1 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -219,4 +219,31 @@ class GeneratedProjectionSuite extends SparkFunSuite { // - one is the mutableRow assert(globalVariables.length == 3) } + + test("SPARK-18016: generated projections on wider table requiring state compaction") { + val N = 6000 + val wideRow1 = new GenericInternalRow(new Array[Any](N)) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + Array.tabulate[Any](N)(i => UTF8String.fromString(i.toString))) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(nested) + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs) + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index e3675367d78e4..0d11958876ce9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -168,7 +168,7 @@ class ComplexTypesSuite extends PlanTest{ test("SPARK-22570: CreateArray should not create a lot of global variables") { val ctx = new CodegenContext CreateArray(Seq(Literal(1))).genCode(ctx) - assert(ctx.mutableStates.length == 0) + assert(ctx.inlinedMutableStates.length == 0) } test("simplify map ops") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index a9bfb634fbdea..782cec5e292ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -68,30 +68,26 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { */ // TODO: return ColumnarBatch.Rows instead override protected def doProduce(ctx: CodegenContext): String = { - val input = ctx.freshName("input") // PhysicalRDD always just has one input - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val input = ctx.addMutableState("scala.collection.Iterator", "input", + v => s"$v = inputs[0];") // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.freshName("scanTime") - ctx.addMutableState(ctx.JAVA_LONG, scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0 val columnarBatchClz = classOf[ColumnarBatch].getName - val batch = ctx.freshName("batch") - 
ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") + val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = ctx.freshName("batchIdx") - ctx.addMutableState(ctx.JAVA_INT, idx, s"$idx = 0;") - val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) + val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0 val columnVectorClzs = vectorTypes.getOrElse( - Seq.fill(colVars.size)(classOf[ColumnVector].getName)) - val columnAssigns = colVars.zip(columnVectorClzs).zipWithIndex.map { - case ((name, columnVectorClz), i) => - ctx.addMutableState(columnVectorClz, name, s"$name = null;") - s"$name = ($columnVectorClz) $batch.column($i);" - } + Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) + val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { + case (columnVectorClz, i) => + val name = ctx.addMutableState(columnVectorClz, s"colInstance$i") + (name, s"$name = ($columnVectorClz) $batch.column($i);") + }.unzip val nextBatch = ctx.freshName("nextBatch") val nextBatchFuncName = ctx.addNewFunction(nextBatch, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 747749bc72e66..4c3b1c49f703d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -110,8 +110,7 @@ case class RowDataSourceScanExec( override protected def doProduce(ctx: CodegenContext): String = { val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input - val input = ctx.freshName("input") - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => BoundReference(i, a.dataType, a.nullable) } @@ -353,8 +352,7 @@ case class FileSourceScanExec( } val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input - val input = ctx.freshName("input") - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") val row = ctx.freshName("row") ctx.INPUT_ROW = row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index c0e21343ae623..daff3c49e7517 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -133,20 +133,18 @@ case class SortExec( override def needStopCheck: Boolean = false override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.freshName("needToSort") - ctx.addMutableState(ctx.JAVA_BOOLEAN, needToSort, s"$needToSort = true;") + val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. 
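Use sites in the hunks above and below are unaffected by compaction, because each compacted array is typed by the state's own java type and the returned reference is spliced in as-is (illustrative generated code):

{{{
// inlined:   batch_0 = (ColumnarBatch) input_0.next();
// compacted: mutableStateArray_0[0] = (ColumnarBatch) input_0.next();
}}}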
val thisPlan = ctx.addReferenceObj("plan", this) - sorterVariable = ctx.freshName("sorter") - ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, - s"$sorterVariable = $thisPlan.createSorter();") - val metrics = ctx.freshName("metrics") - ctx.addMutableState(classOf[TaskMetrics].getName, metrics, - s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();") - val sortedIterator = ctx.freshName("sortedIter") - ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") + // inline mutable state since not many Sort operations in a task + sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter", + v => s"$v = $thisPlan.createSorter();", forceInline = true) + val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", + v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();", forceInline = true) + val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter", + forceInline = true) val addToSorter = ctx.freshName("addToSorter") val addToSorterFuncName = ctx.addNewFunction(addToSorter, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 7166b7771e4db..9e7008d1e0c31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -282,9 +282,10 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp } override def doProduce(ctx: CodegenContext): String = { - val input = ctx.freshName("input") // Right now, InputAdapter is only used when there is one input RDD. - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + // inline mutable state since there is only one InputAdapter in a task + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", + forceInline = true) val row = ctx.freshName("row") s""" | while ($input.hasNext() && !stopEarly()) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9cadd13999e72..b1af360d85095 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,8 +178,7 @@ case class HashAggregateExec( private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") + val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")
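The forceInline = true call sites (sorter, metrics, and sortedIter above, and the input iterator in InputAdapter) opt a state out of compaction and keep it as a dedicated field. The accompanying comments give the rationale: states that occur only once or a handful of times per task cannot threaten the constant-pool limit, and a plain field avoids the array indirection a compacted state would pay. A hedged illustration of the assumed difference in the generated Java; the array naming and layout here are assumptions, not confirmed compiler output.

  // forceInline = true is assumed to produce a dedicated field, as before:
  val inlinedDecl = "private scala.collection.Iterator sortedIter_1;"

  // The default path is assumed to fold many same-typed states into one array,
  // with addMutableState returning the slot expression instead of a plain name:
  val compactedDecl = "private long[] mutableStateArray_0 = new long[97];"
  val compactedTerm = "mutableStateArray_0[42]"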
// The generated function doesn't have input row in the code context. ctx.INPUT_ROW = null @@ -187,10 +186,8 @@ case class HashAggregateExec( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) bufVars = initExpr.map { e => - val isNull = ctx.freshName("bufIsNull") - val value = ctx.freshName("bufValue") - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) - ctx.addMutableState(ctx.javaType(e.dataType), value) + val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" @@ -568,8 +565,7 @@ case class HashAggregateExec( } private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") + val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -583,42 +579,41 @@ case class HashAggregateExec( val thisPlan = ctx.addReferenceObj("plan", this) // Create a name for the iterator from the fast hash map. - val iterTermForFastHashMap = ctx.freshName("fastHashMapIter") - if (isFastHashMapEnabled) { + val iterTermForFastHashMap = if (isFastHashMapEnabled) { // Generates the fast hash map class and creates the fast hash map term. - fastHashMapTerm = ctx.freshName("fastHashMap") val fastHashMapClassName = ctx.freshName("FastHashMap") if (isVectorizedHashMapEnabled) { val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) - ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, - s"$fastHashMapTerm = new $fastHashMapClassName();") - ctx.addMutableState(s"java.util.Iterator", iterTermForFastHashMap) + fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedFastHashMap", + v => s"$v = new $fastHashMapClassName();") + ctx.addMutableState("java.util.Iterator", "vectorizedFastHashMapIter") } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) - ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, - s"$fastHashMapTerm = new $fastHashMapClassName(" + + fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap", + v => s"$v = new $fastHashMapClassName(" + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", - iterTermForFastHashMap) + "fastHashMapIter") } }
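Two details of doProduceWithKeys above are worth noting. First, uniqueness of generated names no longer depends on the caller: the context freshens every prefix, so the buffer-variable loop in doProduceWithoutKeys can pass the same bufIsNull/bufValue prefixes on every iteration and still get distinct fields. Second, the distinct vectorizedFastHashMap and fastHashMap prefixes now exist purely to keep the generated code readable and debuggable. A usage sketch against the SketchCodegenContext from the earlier note:

  val ctx = new SketchCodegenContext
  val bufVars = (1 to 3).map { _ =>
    val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
    val value = ctx.addMutableState(ctx.JAVA_LONG, "bufValue")
    (isNull, value)
  }
  // bufVars == Vector((bufIsNull_1, bufValue_2), (bufIsNull_3, bufValue_4), ...)
  // under the sketch's naive counter; Spark's actual freshName scheme differs.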
// Create a name for the iterator from the regular hash map. - val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm) + // inline mutable state since not many aggregation operations in a task + val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, + "mapIter", forceInline = true) // create hashMap - hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") - sorterTerm = ctx.freshName("sorter") - ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm) + hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap", + v => s"$v = $thisPlan.createHashMap();") + sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter", + forceInline = true) val doAgg = ctx.freshName("doAggregateWithKeys") val peakMemory = metricTerm(ctx, "peakMemory") @@ -758,8 +753,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.freshName("fallbackCounter") - ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;") + val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 85b4529501ea8..1c613b19c4ab1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -46,10 +46,8 @@ abstract class HashMapGenerator( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) initExpr.map { e => - val isNull = ctx.freshName("bufIsNull") - val value = ctx.freshName("bufValue") - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) - ctx.addMutableState(ctx.javaType(e.dataType), value) + val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index c9a15147e30d0..78137d3f97cfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -279,29 +279,30 @@ case class SampleExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val sampler = ctx.freshName("sampler") if (withReplacement) { val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") - val initSamplerFuncName = ctx.addNewFunction(initSampler, - s""" - | private void $initSampler() { - | $sampler = new $samplerClass($upperBound - $lowerBound, false); - | java.util.Random random = new java.util.Random(${seed}L); - | long randomSeed = random.nextLong(); - |
int loopCount = 0; - | while (loopCount < partitionIndex) { - | randomSeed = random.nextLong(); - | loopCount += 1; - | } - | $sampler.setSeed(randomSeed); - | } - """.stripMargin.trim) - - ctx.addMutableState(s"$samplerClass", sampler, - s"$initSamplerFuncName();") + // inline mutable state since not many Sample operations in a task + val sampler = ctx.addMutableState(s"$samplerClass", "sampleReplace", + v => { + val initSamplerFuncName = ctx.addNewFunction(initSampler, + s""" + | private void $initSampler() { + | $v = new $samplerClass($upperBound - $lowerBound, false); + | java.util.Random random = new java.util.Random(${seed}L); + | long randomSeed = random.nextLong(); + | int loopCount = 0; + | while (loopCount < partitionIndex) { + | randomSeed = random.nextLong(); + | loopCount += 1; + | } + | $v.setSeed(randomSeed); + | } + """.stripMargin.trim) + s"$initSamplerFuncName();" + }, forceInline = true) val samplingCount = ctx.freshName("samplingCount") s""" @@ -313,10 +314,10 @@ case class SampleExec( """.stripMargin.trim } else { val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName - ctx.addMutableState(s"$samplerClass", sampler, - s""" - | $sampler = new $samplerClass($lowerBound, $upperBound, false); - | $sampler.setSeed(${seed}L + partitionIndex); + val sampler = ctx.addMutableState(s"$samplerClass", "sampler", + v => s""" + | $v = new $samplerClass($lowerBound, $upperBound, false); + | $v.setSeed(${seed}L + partitionIndex); """.stripMargin.trim) s""" @@ -363,20 +364,18 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val initTerm = ctx.freshName("initRange") - ctx.addMutableState(ctx.JAVA_BOOLEAN, initTerm, s"$initTerm = false;") - val number = ctx.freshName("number") - ctx.addMutableState(ctx.JAVA_LONG, number, s"$number = 0L;") + val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") + val number = ctx.addMutableState(ctx.JAVA_LONG, "number") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName - val taskContext = ctx.freshName("taskContext") - ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();") - val inputMetrics = ctx.freshName("inputMetrics") - ctx.addMutableState("InputMetrics", inputMetrics, - s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();") + // inline mutable state since not many Range operations in a task + val taskContext = ctx.addMutableState("TaskContext", "taskContext", + v => s"$v = TaskContext.get();", forceInline = true) + val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics", + v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true) // In order to periodically update the metrics without inflicting performance penalty, this // operator produces elements in batches. After a batch is complete, the metrics are updated @@ -386,12 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // the metrics. // Once number == batchEnd, it's time to progress to the next batch. - val batchEnd = ctx.freshName("batchEnd") - ctx.addMutableState(ctx.JAVA_LONG, batchEnd, s"$batchEnd = 0;") + val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. 
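The SampleExec and RangeExec hunks above show two subtler uses of the init closure: the closure can register a helper function through ctx.addNewFunction and return just a call to it (the closure receives the final term, so the helper body can assign it), and one state's initializer can embed another state's returned term, as inputMetrics does with taskContext. A hedged sketch of the helper-registration shape, with a hypothetical registerSampler helper and illustrative sampler bounds rather than SampleExec's actual values:

  import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

  // Hypothetical helper: not SampleExec's actual code, just the same shape.
  def registerSampler(ctx: CodegenContext, samplerClass: String, seed: Long): String =
    ctx.addMutableState(samplerClass, "sampler", v => {
      val initSampler = ctx.freshName("initSampler")
      val fn = ctx.addNewFunction(initSampler,
        s"""
           |private void $initSampler() {
           |  $v = new $samplerClass(0.0, 0.5, false);
           |  $v.setSeed(${seed}L + partitionIndex);
           |}
         """.stripMargin)
      s"$fn();" // the call becomes the state's entire initialization code
    }, forceInline = true)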
- val numElementsTodo = ctx.freshName("numElementsTodo") - ctx.addMutableState(ctx.JAVA_LONG, numElementsTodo, s"$numElementsTodo = 0L;") + val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") @@ -440,10 +437,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } """.stripMargin) - val input = ctx.freshName("input") - // Right now, Range is only used when there is one upstream. - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") - val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val range = ctx.freshName("range") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ff5dd707f0b38..4f28eeb725cbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -70,7 +70,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val ctx = newCodeGenContext() val numFields = columnTypes.size val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => - val accessorName = ctx.freshName("accessor") val accessorCls = dt match { case NullType => classOf[NullColumnAccessor].getName case BooleanType => classOf[BooleanColumnAccessor].getName @@ -89,7 +88,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case array: ArrayType => classOf[ArrayColumnAccessor].getName case t: MapType => classOf[MapColumnAccessor].getName } - ctx.addMutableState(accessorCls, accessorName) + val accessorName = ctx.addMutableState(accessorCls, "accessor") val createCode = dt match { case t if ctx.isPrimitiveType(dt) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index c96ed6ef41016..ee763e23415cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -134,19 +134,18 @@ case class BroadcastHashJoinExec( // create a name for HashedRelation val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) - val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName // At the end of the task, we update the avg hash probe. 
val avgHashProbe = metricTerm(ctx, "avgHashProbe") - val addTaskListener = genTaskListener(avgHashProbe, relationTerm) - ctx.addMutableState(clsName, relationTerm, - s""" - | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); - | incPeakExecutionMemory($relationTerm.estimatedSize()); - | $addTaskListener - """.stripMargin) + // inline mutable state since not many join operations in a task + val relationTerm = ctx.addMutableState(clsName, "relation", + v => s""" + | $v = (($clsName) $broadcast.value()).asReadOnlyCopy(); + | incPeakExecutionMemory($v.estimatedSize()); + | ${genTaskListener(avgHashProbe, v)} + """.stripMargin, forceInline = true) (broadcastRelation, relationTerm) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 554b73181116c..073730462a75f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -422,10 +422,9 @@ case class SortMergeJoinExec( */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. - val leftRow = ctx.freshName("leftRow") - ctx.addMutableState("InternalRow", leftRow) - val rightRow = ctx.freshName("rightRow") - ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") + // inline mutable state since not many join operations in a task + val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) + val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) // Create variables for join keys from both sides. val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) @@ -436,14 +435,13 @@ case class SortMergeJoinExec( val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) // A list to hold all matched rows from right side. - val matches = ctx.freshName("matches") val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold - ctx.addMutableState(clsName, matches, - s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);") + val matches = ctx.addMutableState(clsName, "matches", + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);") // Copy the left keys as class members so they could be used in next function call. 
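One behavioral detail in genScanner above: the old code initialized rightRow with an explicit $rightRow = null;, while the new call passes no init function at all. The two are equivalent only because generated Java fields default to null (and to 0/false for primitives), which is also why several explicit = 0; and = false; initializers disappear elsewhere in this patch, with comments like "init as count = 0" left behind as documentation. Against the sketch context from the first note, a declaration-only state is just an entry with empty init code:

  val ctx = new SketchCodegenContext
  val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true)
  // Declaration-only: the recorded init code is empty, and the generated field
  // relies on the JVM default value (null for objects, 0/false for primitives).
  assert(ctx.inlinedMutableStates.exists { case (_, term, init) =>
    term == leftRow && init.isEmpty
  })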
val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -578,10 +576,11 @@ case class SortMergeJoinExec( override def needCopyResult: Boolean = true override def doProduce(ctx: CodegenContext): String = { - val leftInput = ctx.freshName("leftInput") - ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") - val rightInput = ctx.freshName("rightInput") - ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];") + // inline mutable state since not many join operations in a task + val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", + v => s"$v = inputs[0];", forceInline = true) + val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", + v => s"$v = inputs[1];", forceInline = true) val (leftRow, matches) = genScanner(ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index a8556f6ba107a..cccee63bc0680 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -71,8 +71,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.freshName("stopEarly") - ctx.addMutableState(ctx.JAVA_BOOLEAN, stopEarly, s"$stopEarly = false;") + val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false ctx.addNewFunction("stopEarly", s""" @Override @@ -80,8 +79,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.freshName("count") - ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;") + val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0 s""" | if ($countTerm < $limit) { | $countTerm += 1;